prefix_sum
====================

.. py:module:: jaxns.internals.prefix_sum


.. rubric:: :code:`jaxns.internals.prefix_sum`


.. rubric:: Module Contents

.. py:data:: X


.. py:function:: scan_associative(fn, elems, axis = 0)

   Perform a scan with an associative binary operation, in parallel.

   Suitable for any ``fn: (X, X) -> X`` that is associative, i.e.
   ``fn(fn(a, b), c) == fn(a, fn(b, c))`` for all ``a``, ``b`` and ``c``.

   The associative scan operation computes the cumulative sum, or
   `all-prefix sum <https://en.wikipedia.org/wiki/Prefix_sum>`_, of a set of
   elements under an associative binary operation [1]_. For example, with the
   ordinary addition operator ``fn = lambda a, b: a + b``, it is equivalent to
   the ordinary cumulative sum.

   The associative structure allows the computation to be decomposed and
   executed by parallel reduction. Where a naive sequential implementation
   would loop over all ``N`` elements, this method requires only a logarithmic
   number (``2 * ceil(log_2 N)``) of sequential steps, and can thus yield
   substantial speedups from hardware-accelerated vectorization. The total
   number of invocations of the binary operation (including those performed
   in parallel) is at most ``2 * (N / 2 + N / 4 + ... + 1) = 2N - 2``, i.e.
   roughly twice the ``N - 1`` invocations of a naive sequential approach.

   .. [1] Blelloch, Guy E. `Prefix sums and their applications
      <https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf>`_. Technical Report
      CMU-CS-90-190, School of Computer Science, Carnegie Mellon University,
      1990.

   :param fn: the associative binary operation to perform.
   :param elems: ``[..., N, ...]`` the input elements to scan over.
   :param axis: the axis to scan over.
   :returns: ``[..., N, ...]`` the cumulative result of ``fn`` applied along
      ``axis``, with the same shape as ``elems``.
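
One way to realize the parallel decomposition described for ``scan_associative`` is the recursive odd/even scheme: combine adjacent pairs, recursively scan the half-length sequence, then fill in the remaining positions. The sketch below is a minimal pure-Python illustration under stated assumptions, not the jaxns implementation. ``associative_scan`` is a hypothetical helper that handles a 1-D Python list only (no ``axis`` argument), runs every step sequentially, and counts calls to ``fn`` so the total can be compared with the ``2N - 2`` bound.

.. code-block:: python

   from typing import Callable, List, Sequence, Tuple, TypeVar

   T = TypeVar("T")


   def associative_scan(fn: Callable[[T, T], T], elems: Sequence[T]) -> Tuple[List[T], int]:
       """Hypothetical helper: inclusive scan of a 1-D sequence under an associative fn.

       Returns the scanned list and the number of fn calls made.
       """
       calls = 0

       def combine(a: T, b: T) -> T:
           # Wrap fn so every invocation is counted.
           nonlocal calls
           calls += 1
           return fn(a, b)

       def scan(xs: List[T]) -> List[T]:
           n = len(xs)
           if n < 2:
               return list(xs)
           # Step 1: combine adjacent pairs (x0, x1), (x2, x3), ...
           # These calls are independent, so they could run as one parallel step.
           pairs = [combine(xs[i], xs[i + 1]) for i in range(0, n - 1, 2)]
           # Step 2: recursively scan the half-length sequence;
           # odd[k] is the prefix result x0 . x1 . ... . x(2k+1).
           odd = scan(pairs)
           # Step 3: even positions 2k (k >= 1) are prefix(2k-1) . x(2k);
           # position 0 is just x0. Also independent, so one parallel step.
           even = [xs[0]] + [combine(odd[k - 1], xs[2 * k]) for k in range(1, (n + 1) // 2)]
           # Interleave the even- and odd-indexed results back into order.
           return [even[k // 2] if k % 2 == 0 else odd[k // 2] for k in range(n)]

       return scan(list(elems)), calls


   result, calls = associative_scan(lambda a, b: a + b, [3, 1, 4, 1, 5, 9, 2, 6])
   # result == [3, 4, 8, 9, 14, 23, 25, 31]; calls == 11, within 2 * 8 - 2 == 14

Each recursion level issues two batches of mutually independent ``fn`` calls, and there are ``ceil(log_2 N)`` levels, which is where the logarithmic number of sequential steps comes from. Because ``fn`` is always called with its arguments in their original left-to-right order, non-commutative associative operators (for example string concatenation) are also scanned correctly.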