interp_utils

jaxns.internals.interp_utils

Module Contents

apply_interp(x, i0, alpha0, i1, alpha1, axis=0)

Apply precomputed interpolation indices and weights to x along the given axis.

Parameters:
  • x (jax.Array) – nd-array

  • i0 (jax.Array) – [N] or scalar

  • alpha0 (jax.Array) – [N] or scalar

  • i1 (jax.Array) – [N] or scalar

  • alpha1 (jax.Array) – [N] or scalar

  • axis (int) – axis to take along

Returns:

[N] or scalar interpolated along axis
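As a rough illustration of the parameters above, the following sketch combines two weighted takes along an axis. It is an assumption about the behaviour, not the jaxns source: the name apply_interp_sketch is hypothetical, and the sketch only handles a non-negative axis.

```python
import jax.numpy as jnp


def apply_interp_sketch(x, i0, alpha0, i1, alpha1, axis=0):
    """Hypothetical stand-in: alpha0 * x[i0] + alpha1 * x[i1] along `axis`."""

    def weighted_take(i, alpha):
        taken = jnp.take(x, i, axis=axis)
        alpha = jnp.asarray(alpha)
        # Right-pad the weights with singleton axes so they align with `axis`.
        trailing = taken.ndim - axis - alpha.ndim
        return taken * jnp.reshape(alpha, alpha.shape + (1,) * trailing)

    return weighted_take(i0, alpha0) + weighted_take(i1, alpha1)


# Three samples along axis 0, each with two channels.
x = jnp.array([[0.0, 10.0], [1.0, 20.0], [2.0, 30.0]])
i0 = jnp.array([0, 1])
alpha0 = jnp.array([0.5, 0.25])
i1 = jnp.array([1, 2])
alpha1 = jnp.array([0.5, 0.75])

result = apply_interp_sketch(x, i0, alpha0, i1, alpha1, axis=0)
```

Here each output row is a weighted mix of two neighbouring rows of x, so result has shape (2, 2).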

left_broadcast_multiply(x, y, axis=0)

Left-broadcast multiply of two arrays. Equivalent to right-padding the shape of y with singleton dimensions before multiplying.

Parameters:
  • x – […, a, b, c, …]

  • y – [a, b]

  • axis (int)

Returns:

[…, a, b, c, …]
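The right-padding described above can be sketched as follows. This is a minimal illustration under the assumption that y is padded with trailing singleton dimensions until it matches the dimensions of x from axis onward; the name left_broadcast_multiply_sketch is hypothetical.

```python
import jax.numpy as jnp


def left_broadcast_multiply_sketch(x, y, axis=0):
    """Hypothetical stand-in: multiply x by y, aligning y with x starting at `axis`."""
    x = jnp.asarray(x)
    y = jnp.asarray(y)
    # Number of dimensions of x from `axis` to the end.
    needed = len(x.shape[axis:])
    # Append singleton dimensions so y lines up with x[axis:].
    extra = (1,) * (needed - y.ndim)
    return x * jnp.reshape(y, y.shape + extra)


x = jnp.ones((2, 3, 4))
y = jnp.arange(6.0).reshape(2, 3)
out = left_broadcast_multiply_sketch(x, y, axis=0)
```

Plain broadcasting would try to align y with the last two dimensions of x; padding on the right aligns it with the leading ones instead.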

get_interp_indices_and_weights(x, xp, regular_grid=False)

Compute the indices and weights for one-dimensional linear interpolation. Outside the bounds of xp, the result extrapolates linearly from the nearest two points.

Parameters:
  • x – the x-coordinates at which to evaluate the interpolated values

  • xp – the x-coordinates of the data points, must be increasing

  • regular_grid (bool)

Returns:

a pair of (index, weight) tuples, one for each of the two neighbouring data points; indices and weights have the same shape as x

Return type:

Tuple[Tuple[Union[int, jax.Array, float, jax.Array]], Tuple[Union[int, jax.Array, float, jax.Array]]]
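To make the idea of interpolation indices and weights concrete, here is a minimal sketch for an irregular, increasing grid. It assumes the result is laid out as ((i0, alpha0), (i1, alpha1)); the function name is hypothetical and the regular_grid option is not modelled.

```python
import jax.numpy as jnp


def interp_indices_and_weights_sketch(x, xp):
    """Hypothetical stand-in returning ((i0, alpha0), (i1, alpha1))."""
    x = jnp.asarray(x)
    xp = jnp.asarray(xp)
    # Right-hand neighbour, clipped so that both neighbours are valid grid points.
    i1 = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, xp.shape[0] - 1)
    i0 = i1 - 1
    # Fractional position in the bracketing interval; outside [0, 1] when extrapolating.
    alpha1 = (x - xp[i0]) / (xp[i1] - xp[i0])
    alpha0 = 1.0 - alpha1
    return (i0, alpha0), (i1, alpha1)


xp = jnp.array([0.0, 1.0, 3.0])
yp = jnp.array([0.0, 10.0, 30.0])
x = jnp.array([0.5, 2.0, 4.0])  # the last point lies beyond the grid

(i0, alpha0), (i1, alpha1) = interp_indices_and_weights_sketch(x, xp)
values = alpha0 * yp[i0] + alpha1 * yp[i1]
```

Because the weights depend only on x and xp, they can be computed once and reused for any number of arrays sampled on the same grid.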

class InterpolatedArray
x: jax.Array
values: jax.Array
axis: int = 0
regular_grid: bool = False
__post_init__()
property shape
__call__(time)

Interpolate at time based on input times.

Parameters:

time (jax.Array) – time to evaluate at.

Returns:

value at given time

Return type:

jax.Array
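The sketch below shows how a callable interpolation container of this kind might behave. It is a hypothetical stand-in rather than the jaxns class: InterpolatedArraySketch only supports a scalar time, assumes x is increasing, and omits regular_grid, __post_init__ and shape.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(eq=False)
class InterpolatedArraySketch:
    """Hypothetical stand-in: linearly interpolates `values` along `axis`."""

    x: jax.Array       # [N] sample times, assumed increasing
    values: jax.Array  # samples, with length N along `axis`
    axis: int = 0

    def __call__(self, time):
        time = jnp.asarray(time)  # scalar time only in this sketch
        n = self.x.shape[0]
        # Bracketing indices, clipped so that out-of-range times extrapolate.
        i1 = jnp.clip(jnp.searchsorted(self.x, time, side="right"), 1, n - 1)
        i0 = i1 - 1
        alpha1 = (time - self.x[i0]) / (self.x[i1] - self.x[i0])
        v0 = jnp.take(self.values, i0, axis=self.axis)
        v1 = jnp.take(self.values, i1, axis=self.axis)
        return (1.0 - alpha1) * v0 + alpha1 * v1


times = jnp.array([0.0, 1.0, 2.0])
samples = jnp.array([[0.0, 0.0], [10.0, 100.0], [20.0, 400.0]])
interp = InterpolatedArraySketch(x=times, values=samples, axis=0)
value = interp(1.5)
```

Calling the object at a time between two samples returns the weighted mix of the two neighbouring slices of values, with the interpolated axis removed.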