prefix_sum

jaxns.internals.prefix_sum

Module Contents

X
scan_associative(fn, elems, axis=0)

Perform a scan with an associative binary operation, in parallel.

Suitable for fn: (X, X) -> X where fn is associative, i.e. fn(fn(a, b), c) == fn(a, fn(b, c)) for all a, b, c.

The associative scan operation computes the cumulative sum, or [all-prefix sum](https://en.wikipedia.org/wiki/Prefix_sum), of a sequence of elements under an associative binary operation [1]. For example, using the ordinary addition operator fn = lambda a, b: a + b, this is equivalent to the ordinary cumulative sum.
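To make the semantics concrete, here is a minimal sequential reference in plain Python. It only illustrates what an inclusive scan returns; it is not the parallel implementation, and the names `sequential_scan` and `compose` are hypothetical helpers introduced for this sketch, not part of jaxns.

```python
from itertools import accumulate
import operator


def sequential_scan(fn, elems):
    """Sequential reference: out[i] = fn(...fn(fn(e0, e1), e2)..., ei)."""
    return list(accumulate(elems, fn))


# Ordinary addition gives the ordinary cumulative sum.
cumsum = sequential_scan(operator.add, [1, 2, 3, 4])

# Any associative operation works, for example a running maximum.
running_max = sequential_scan(max, [3, 1, 4, 1, 5])


def compose(f, g):
    """Compose affine maps x -> a * x + b stored as (a, b): apply f, then g.

    Composition is associative but not commutative, so element order matters.
    """
    a1, b1 = f
    a2, b2 = g
    return (a2 * a1, a2 * b1 + b2)


composed = sequential_scan(compose, [(2, 1), (3, 0), (1, -4)])

print(cumsum)       # [1, 3, 6, 10]
print(running_max)  # [3, 3, 4, 4, 5]
print(composed)     # [(2, 1), (6, 3), (6, -1)]
```

The last case shows why associativity, rather than commutativity, is the requirement: the elements must stay in order, but the grouping of the calls is free to change.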

The associative structure allows the computation to be decomposed and executed by parallel reduction. Where a naive sequential implementation loops over all N elements, this method requires only a logarithmic number of sequential steps, 2 * ceil(log_2 N), and can thus yield substantial speedups from hardware-accelerated vectorization. The total number of invocations of the binary operation (including those performed in parallel) is 2 * (N / 2 + N / 4 + … + 1) = 2N - 2, i.e., approximately twice the N - 1 invocations of a naive approach.
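The decomposition can be sketched in plain Python as a recursive pairwise reduction. This is an illustrative sketch under stated assumptions, not the jaxns implementation: `scan_sketch` and the `stats` dictionary are hypothetical names, and the list comprehensions stand in for batches of calls that a vectorized backend could run in parallel. The sketch tallies both the invocations of `fn` and the number of such batches, so the bounds quoted above can be checked for a given N.

```python
import math
import operator
from itertools import accumulate


def scan_sketch(fn, elems, stats):
    """Inclusive scan via recursive pairwise reduction (illustrative only).

    stats["calls"] tallies invocations of fn; stats["steps"] tallies batches
    of mutually independent calls, i.e. steps that could each run in parallel.
    """
    n = len(elems)
    if n < 2:
        return list(elems)

    # Step 1: combine adjacent pairs. These calls do not depend on each other.
    pairs = [fn(elems[i], elems[i + 1]) for i in range(0, n - 1, 2)]
    stats["calls"] += len(pairs)
    stats["steps"] += 1

    # Step 2: scan the half-length sequence; this yields the odd positions.
    odd = scan_sketch(fn, pairs, stats)

    # Step 3: fill in the even positions from the odd results.
    # These calls are also independent of one another.
    out = [elems[0]]
    fill_ins = 0
    for k, prefix in enumerate(odd):
        out.append(prefix)  # position 2k + 1
        if 2 * k + 2 < n:
            out.append(fn(prefix, elems[2 * k + 2]))  # position 2k + 2
            fill_ins += 1
    stats["calls"] += fill_ins
    if fill_ins:
        stats["steps"] += 1
    return out


N = 16
data = list(range(1, N + 1))
stats = {"calls": 0, "steps": 0}
result = scan_sketch(operator.add, data, stats)

# The result matches a sequential cumulative sum, and the tallies stay
# within the bounds quoted in the text.
assert result == list(accumulate(data))
assert stats["calls"] <= 2 * N - 2
assert stats["steps"] <= 2 * math.ceil(math.log2(N))
```

In this sketch the fill-in pass skips the first position at every level, so the tally comes in slightly under 2N - 2. The trade-off is the same as described above: roughly twice the work of a sequential loop, in exchange for a logarithmic number of dependent steps.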

[1] Blelloch, Guy E. [Prefix sums and their applications](https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf). Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University, 1990.

Parameters:
  • fn (Callable[[X, X], X]) – the associative binary operation to perform.

  • elems (X) – […, N, …] the input elements to scan over.

  • axis (int) – the axis to scan over.

Returns:

[…, N, …] the cumulative result of applying fn along axis, with the same shape as the input.

Return type:

X