cumulative_ops

jaxns.internals.cumulative_ops

Module Contents

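V – type variable for the accumulator (the type of init and of the value returned by op)
Y – type variable for the input sequence xs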
cumulative_op_static(op, init, xs, pre_op=False, unroll=1)

Compute a cumulative operation over a sequence of values, consuming xs along its leading axis.

Parameters:
  • op (Callable[[V, Y], V]) – the operation to perform

  • init (V) – the initial value

  • xs (Y) – the list of values

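  • pre_op (bool) – if True, each output entry holds the accumulated value before op is applied to the corresponding element of xs, so the first output entry is init. If False, each entry already includes the corresponding element.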

  • unroll (int) – how many iterations to unroll the loop at a time

Returns:

the final accumulated value, and the cumulative results (one per element of xs) stacked along a new leading axis

Return type:

Tuple[V, V]
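The semantics of the static variant can be illustrated with jax.lax.scan. The following is a minimal sketch, not the jaxns source: cumulative_op_static_sketch is a hypothetical name, and it assumes xs is an array (or pytree of arrays) consumed along its leading axis.

```python
import jax.numpy as jnp
from jax import lax


def cumulative_op_static_sketch(op, init, xs, pre_op=False, unroll=1):
    """Hypothetical re-implementation illustrating the documented semantics."""

    def body(acc, y):
        next_acc = op(acc, y)
        # pre_op=True emits the accumulator *before* consuming y,
        # so the first emitted value is `init`; otherwise emit the updated value.
        return next_acc, (acc if pre_op else next_acc)

    final, results = lax.scan(body, init, xs, unroll=unroll)
    return final, results


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
add = lambda acc, y: acc + y
init = jnp.zeros((), xs.dtype)

final, inclusive = cumulative_op_static_sketch(add, init, xs)
# final == 10.0, inclusive == [1., 3., 6., 10.]

_, exclusive = cumulative_op_static_sketch(add, init, xs, pre_op=True)
# exclusive == [0., 1., 3., 6.]

# The accumulator may be a pytree, e.g. a (running_sum, running_max) pair;
# each leaf of the results is stacked along a new leading axis.
pair_op = lambda acc, y: (acc[0] + y, jnp.maximum(acc[1], y))
pair_init = (init, jnp.full((), -jnp.inf, dtype=xs.dtype))
(total, peak), (sums, peaks) = cumulative_op_static_sketch(pair_op, pair_init, xs)
# sums == [1., 3., 6., 10.], peaks == [1., 2., 3., 4.]
```

Setting pre_op=True shifts the results right by one position with init in front, which gives the exclusive form of the same cumulative operation without changing op.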

cumulative_op_dynamic(op, init, xs, stop_idx, pre_op=False, empty_fill=None)

Compute a cumulative operation over a sequence of values, consuming xs along its leading axis and stopping at a dynamic index.

Parameters:
  • op (Callable[[V, Y], V]) – the operation to perform

  • init (V) – the initial value

  • xs (Y) – the list of values

  • stop_idx (jaxns.internals.types.IntArray) – the number of accumulations to perform; only the first stop_idx elements of xs are consumed

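  • pre_op (bool) – if True, each output entry holds the accumulated value before op is applied to the corresponding element of xs, so the first output entry is init. If False, each entry already includes the corresponding element.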

  • empty_fill (Optional[V]) – the value placed in output entries that are never reached, i.e. those at index stop_idx and beyond; if None, init is used

Returns:

the final accumulated value, and the cumulative results stacked along a new leading axis of the same length as xs; entries at index stop_idx and beyond hold empty_fill (or init)

Return type:

Tuple[V, V]
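For the dynamic variant, pre-allocating a full-length output keeps array shapes static while the number of loop iterations depends on the runtime value of stop_idx. The sketch below uses jax.lax.while_loop under stated assumptions: cumulative_op_dynamic_sketch is a hypothetical name, xs is a single array rather than a general pytree, and 0 <= stop_idx <= len(xs). It is not the jaxns implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumulative_op_dynamic_sketch(op, init, xs, stop_idx, pre_op=False, empty_fill=None):
    """Hypothetical sketch: accumulate op over xs[:stop_idx] with a fixed-size output.

    Assumes xs is a single array and 0 <= stop_idx <= xs.shape[0].
    """
    n = xs.shape[0]
    fill = init if empty_fill is None else empty_fill
    # Pre-fill every output slot; slots at index >= stop_idx are never overwritten.
    output = jnp.full((n,) + jnp.shape(init), fill, dtype=jnp.result_type(init))

    def cond(carry):
        _, i, _ = carry
        return i < stop_idx

    def body(carry):
        acc, i, out = carry
        next_acc = op(acc, xs[i])  # dynamic index into xs
        # pre_op=True records the accumulator before xs[i] is consumed.
        out = out.at[i].set(acc if pre_op else next_acc)
        return next_acc, i + 1, out

    i0 = jnp.zeros((), dtype=jnp.int32)
    final, _, output = lax.while_loop(cond, body, (init, i0, output))
    return final, output


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
add = lambda acc, y: acc + y
init = jnp.zeros((), xs.dtype)

final_a, out_a = cumulative_op_dynamic_sketch(add, init, xs, jnp.asarray(2))
# final_a == 3.0, out_a == [1., 3., 0., 0.]  (unreached entries filled with init)

final_b, out_b = cumulative_op_dynamic_sketch(
    add, init, xs, jnp.asarray(3), pre_op=True, empty_fill=jnp.asarray(-1.0)
)
# final_b == 6.0, out_b == [0., 1., 3., -1.]

# stop_idx may be traced under jit because it only drives the loop condition.
jitted = jax.jit(lambda s: cumulative_op_dynamic_sketch(add, init, xs, s))
final_c, out_c = jitted(jnp.asarray(2))
```

Because stop_idx only enters the loop condition, it can be a traced value under jax.jit, as in the last call; a fixed-length scan would always run over every element of xs.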