cumulative_ops
jaxns.internals.cumulative_ops
Module Contents
- cumulative_op_static(op, init, xs, pre_op=False, unroll=1)
Compute a cumulative operation on a list of values.
- Parameters:
- Returns:
the final accumulated value, and the result of the cumulative operation applied on input
- Return type:
Tuple[V, V]
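The behaviour described above can be illustrated with a small stand-alone sketch built on jax.lax.scan. This is not the jaxns implementation: the helper name cumulative_op_sketch is hypothetical, the unroll argument is left out, and the handling of pre_op follows one reading of the parameter description (the emitted value is the accumulator before the operation is applied).

```python
import jax
import jax.numpy as jnp


def cumulative_op_sketch(op, init, xs, pre_op=False):
    """Hypothetical stand-in that mimics the documented semantics with lax.scan."""

    def body(accumulated, x):
        next_accumulated = op(accumulated, x)
        # Assumption: with pre_op=True the emitted value is the accumulator
        # *before* op is applied, so the first emitted value equals init.
        emitted = accumulated if pre_op else next_accumulated
        return next_accumulated, emitted

    # lax.scan returns (final carry, stacked per-step outputs).
    return jax.lax.scan(body, init, xs)


xs = jnp.asarray([1.0, 2.0, 3.0, 4.0])
init = jnp.asarray(0.0)

# Running sum: final value 10.0, per-step results [1, 3, 6, 10].
final, running = cumulative_op_sketch(jnp.add, init, xs)

# Same final value, but per-step results are shifted: [0, 1, 3, 6].
final_pre, running_pre = cumulative_op_sketch(jnp.add, init, xs, pre_op=True)
```

In both calls the first element of the returned pair is the final accumulated value and the second has one entry per element of xs, matching the Tuple[V, V] return type.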
- cumulative_op_dynamic(op, init, xs, stop_idx, pre_op=False, empty_fill=None)
Compute a cumulative operation on a list of values with a dynamic stop index.
- Parameters:
op (Callable[[V, Y], V]) – the operation to perform
init (V) – the initial value
xs (Y) – the list of values
stop_idx (jaxns.internals.types.IntArray) – how many accumulations to perform
pre_op (bool) – if True, each output entry is the accumulated value before the operation is applied, so the first output value is the initial value.
empty_fill (Optional[V]) – the value used to fill output entries that are not computed because accumulation stops at stop_idx; if None, init is used instead
- Returns:
the final accumulated value, and the result of the cumulative operation applied on input
- Return type:
Tuple[V, V]
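A dynamic stop index can be sketched with jax.lax.while_loop, which keeps looping while a traced condition holds. The code below is an illustration under stated assumptions, not the jaxns implementation: the helper name cumulative_op_dynamic_sketch is hypothetical, xs is assumed to be a 1-D array with a scalar accumulator, and output entries at or beyond stop_idx are assumed to keep the fill value (empty_fill, or init when empty_fill is None).

```python
import jax
import jax.numpy as jnp


def cumulative_op_dynamic_sketch(op, init, xs, stop_idx, pre_op=False, empty_fill=None):
    """Hypothetical stand-in: accumulate over xs[:stop_idx] with a while loop."""
    fill = init if empty_fill is None else empty_fill
    # Pre-fill the output; entries that are never written keep this value.
    output = jnp.full(xs.shape, fill, dtype=xs.dtype)

    def cond(carry):
        _, i, _ = carry
        return i < stop_idx

    def body(carry):
        accumulated, i, out = carry
        next_accumulated = op(accumulated, xs[i])
        # Assumption: pre_op=True records the accumulator before op is applied.
        emitted = accumulated if pre_op else next_accumulated
        return next_accumulated, i + 1, out.at[i].set(emitted)

    start = jnp.asarray(0, dtype=stop_idx.dtype)
    final, _, output = jax.lax.while_loop(cond, body, (init, start, output))
    return final, output


xs = jnp.asarray([1.0, 2.0, 3.0, 4.0])
init = jnp.asarray(0.0)
stop_idx = jnp.asarray(3)

# Only the first three values are accumulated; the last slot holds empty_fill.
final, output = cumulative_op_dynamic_sketch(
    jnp.add, init, xs, stop_idx, empty_fill=jnp.asarray(-1.0)
)

# Without empty_fill, the untouched slot falls back to init.
_, output_default = cumulative_op_dynamic_sketch(jnp.add, init, xs, stop_idx)
```

Because the output shape is fixed by xs rather than by stop_idx, the result keeps a static shape even though the number of accumulations is only known at run time.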