cumulative_ops
jaxns.internals.cumulative_ops
Module Contents
- scan_associative_cumulative_op(op, init, xs, pre_op=False)
Compute a cumulative operation on an array of values using an associative (parallel prefix) scan.
- Parameters:
op (Callable[[X, X], X]) – the operation to perform, must be associative.
init (X) – the initial value.
xs (X) – the array of values.
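pre_op (bool) – if True, each output element is the accumulated value before the corresponding input is combined (an exclusive scan), so the first output is init; if False, each output includes the corresponding input (an inclusive scan).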
- Returns:
the final accumulated value, and the cumulative result at each position of the input
- Return type:
Tuple[X, X]
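As an illustration of the semantics described above, here is a minimal sketch, not the jaxns source. The name `associative_cumulative_op_sketch` is hypothetical, and the sketch assumes the result can be obtained by prepending `init` to `xs` and running `jax.lax.associative_scan`. Because `associative_scan` combines stacked slices, `op` must also accept inputs with a leading batch axis (as `jnp.add` does).

```python
import jax
import jax.numpy as jnp
from jax import lax


def associative_cumulative_op_sketch(op, init, xs, pre_op=False):
    """Hypothetical illustration of the documented semantics (not the jaxns source).

    Assumes `op` is associative and accepts arrays with a leading batch axis,
    since lax.associative_scan applies it to stacked slices.
    """
    # Prepend `init` to every leaf: [init, x0, x1, ...].
    full = jax.tree_util.tree_map(
        lambda i, x: jnp.concatenate([jnp.asarray(i, dtype=x.dtype)[None], x], axis=0),
        init,
        xs,
    )
    # Inclusive prefix results: [init, op(init, x0), op(op(init, x0), x1), ...].
    prefix = lax.associative_scan(op, full)
    # The last prefix entry has folded in every element.
    final = jax.tree_util.tree_map(lambda p: p[-1], prefix)
    if pre_op:
        # Exclusive: value before each element is combined; first entry is init.
        out = jax.tree_util.tree_map(lambda p: p[:-1], prefix)
    else:
        # Inclusive: value after each element is combined.
        out = jax.tree_util.tree_map(lambda p: p[1:], prefix)
    return final, out


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
final, out = associative_cumulative_op_sketch(jnp.add, jnp.zeros((), xs.dtype), xs)
# final is 10.0; out holds the inclusive prefix sums 1, 3, 6, 10
```

With `pre_op=True` the same call yields the exclusive prefix sums 0, 1, 3, 6, while the final value stays 10.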
- cumulative_op_static(op, init, xs, pre_op=False, unroll=1)
Compute a cumulative operation on a list of values.
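- Parameters:
op (Callable[[V, Y], V]) – the operation to perform
init (V) – the initial value
xs (Y) – the list of values
pre_op (bool) – if True, each output element is the accumulated value before the corresponding input is combined, so the first output is the initial value; if False, each output includes the corresponding input.
unroll (int) – how many loop iterations to unroll at a time
- Returns: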
the final accumulated value, and the cumulative result at each position of the input
- Return type:
Tuple[V, V]
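The same contract can be sketched with a sequential `jax.lax.scan`. This is a hypothetical illustration: `static_cumulative_op_sketch` is not a jaxns function, and it assumes `unroll` is forwarded to `lax.scan`. Because the loop is sequential, `op` can be any `(accumulator, element) -> accumulator` function, including non-associative ones such as the exponential moving average used below.

```python
import jax.numpy as jnp
from jax import lax


def static_cumulative_op_sketch(op, init, xs, pre_op=False, unroll=1):
    """Hypothetical illustration of the documented semantics (not the jaxns source)."""

    def body(acc, y):
        new_acc = op(acc, y)
        # pre_op=True emits the accumulator before `y` is combined (exclusive);
        # otherwise it emits the updated accumulator (inclusive).
        return new_acc, (acc if pre_op else new_acc)

    # Assumption: `unroll` is passed straight through to lax.scan.
    final, out = lax.scan(body, init, xs, unroll=unroll)
    return final, out


# A non-associative update: exponential moving average with weight 0.5.
ema = lambda acc, y: 0.5 * acc + 0.5 * y
xs = jnp.array([2.0, 4.0, 8.0])
final, out = static_cumulative_op_sketch(ema, jnp.zeros((), xs.dtype), xs)
# final is 5.25; out holds 1.0, 2.5, 5.25
```

Unlike the associative-scan version, this runs in `len(xs)` sequential steps, which is the price of allowing an arbitrary, non-associative `op`.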
- cumulative_op_dynamic(op, init, xs, stop_idx, pre_op=False, empty_fill=None)
Compute a cumulative operation on a list of values with a dynamic stop index.
- Parameters:
op (Callable[[V, Y], V]) – the operation to perform
init (V) – the initial value
xs (Y) – the list of values
stop_idx (jaxns.internals.types.IntArray) – how many accumulations to perform
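pre_op (bool) – if True, each output element is the accumulated value before the corresponding input is combined, so the first output is the initial value; if False, each output includes the corresponding input.
empty_fill (Optional[V]) – the value used to fill output positions that are not accumulated (at or beyond stop_idx); if None, init is used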
- Returns:
the final accumulated value, and the cumulative result at each position of the input
- Return type:
Tuple[V, V]
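The dynamic variant can be sketched with `jax.lax.while_loop`, which accepts a traced `stop_idx` while keeping the output shape static. This is a hypothetical illustration (`dynamic_cumulative_op_sketch` is not a jaxns function); it assumes the output is pre-filled with `empty_fill` (or `init`) and that positions at or beyond `stop_idx` are left untouched. Clamping `stop_idx` to the length of `xs` is a choice made here, not documented behavior.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dynamic_cumulative_op_sketch(op, init, xs, stop_idx, pre_op=False, empty_fill=None):
    """Hypothetical illustration of the documented semantics (not the jaxns source).

    Only xs[:stop_idx] is accumulated; output rows at or beyond stop_idx keep
    the fill value (`empty_fill`, or `init` when empty_fill is None).
    """
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    fill = init if empty_fill is None else empty_fill
    # Pre-allocate a static-shape output with n rows, each equal to `fill`.
    out0 = jax.tree_util.tree_map(
        lambda f: jnp.broadcast_to(jnp.asarray(f)[None], (n,) + jnp.shape(f)), fill
    )
    # Sketch's own choice: never read past the end of xs.
    stop = jnp.minimum(stop_idx, n)

    def cond(carry):
        _, i, _ = carry
        return i < stop

    def body(carry):
        acc, i, out = carry
        y = jax.tree_util.tree_map(lambda x: x[i], xs)
        new_acc = op(acc, y)
        # Exclusive (pre_op) or inclusive value written at position i.
        emit = acc if pre_op else new_acc
        out = jax.tree_util.tree_map(lambda o, e: o.at[i].set(e), out, emit)
        return new_acc, i + 1, out

    final, _, out = lax.while_loop(cond, body, (init, jnp.int32(0), out0))
    return final, out


xs = jnp.array([1.0, 2.0, 3.0, 4.0])
add = lambda acc, y: acc + y
run = jax.jit(
    lambda k: dynamic_cumulative_op_sketch(
        add, jnp.zeros((), xs.dtype), xs, k, empty_fill=jnp.asarray(-1.0, dtype=xs.dtype)
    )
)
final, out = run(2)
# final is 3.0; out holds 1.0, 3.0, -1.0, -1.0 (the last two were never accumulated)
```

Because `stop_idx` is traced under `jax.jit`, changing it does not trigger recompilation; only the number of loop iterations changes.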