maps

jaxns.internals.maps

Module Contents

logger
replace_index(operand, update, start_index)

Replaces the slice of operand that begins at start_index with update. If update is too large to fit in operand starting at start_index, start_index is shifted so that update fits, which can give non-intuitive results.
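The shifting behaviour can be illustrated with jax.lax.dynamic_update_slice_in_dim, which clamps the start index so the whole update fits inside the operand. This is a minimal sketch that assumes replace_index acts along the leading axis with the same clamping rule; replace_index_sketch is a hypothetical name used for illustration, not the jaxns implementation.

```python
import jax.numpy as jnp
from jax import lax


def replace_index_sketch(operand, update, start_index):
    """Hypothetical stand-in: write `update` into `operand` along axis 0 at `start_index`."""
    # dynamic_update_slice clamps start_index to [0, len(operand) - len(update)]
    # so that the whole update always fits inside operand.
    return lax.dynamic_update_slice_in_dim(operand, update, start_index, axis=0)


operand = jnp.zeros(5)
update = jnp.ones(3)
print(replace_index_sketch(operand, update, 1))  # [0. 1. 1. 1. 0.]
print(replace_index_sketch(operand, update, 4))  # start shifted to 2: [0. 0. 1. 1. 1.]
```

In the second call the requested start of 4 would overrun the 5-element operand, so the update silently lands at index 2 instead; this is the non-intuitive case the description warns about.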

get_index(operand, start_index, length)
prepare_func_args(f)

Takes a callable(a,b,…,z=Z) and wraps it as a callable(**kwargs) that takes only a,b,…,z from **kwargs and ignores every other key.

This allows f(**kwargs) to be called even when kwargs contains keys that f does not accept.

Parameters:

f – callable(a,b,...,z=Z)

Returns:

callable(**kwargs) that forwards to f only the keyword arguments matching f's parameters.
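One way to build such a wrapper is with inspect.signature. The sketch below is an assumption about the approach, not the jaxns source; prepare_func_args_sketch and the example log_likelihood are hypothetical. It only filters out extra keys, so a missing required argument still raises TypeError.

```python
import inspect
from typing import Any, Callable


def prepare_func_args_sketch(f: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical sketch: wrap f so it can be called with arbitrary **kwargs."""
    params = inspect.signature(f).parameters
    # If f already accepts **kwargs, no filtering is needed.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return f
    # Names that can be passed by keyword.
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }

    def wrapped(**kwargs: Any) -> Any:
        # Drop keys f does not declare; missing required keys still raise TypeError.
        return f(**{k: v for k, v in kwargs.items() if k in accepted})

    return wrapped


# Hypothetical example function.
def log_likelihood(x, y, sigma=1.0):
    return -0.5 * ((x - y) / sigma) ** 2


safe = prepare_func_args_sketch(log_likelihood)
print(safe(x=1.0, y=3.0, unused=42))  # -2.0
```

Defaults such as sigma=1.0 keep working because the wrapper simply omits keys it does not receive, leaving Python's normal default handling in place.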

F
FV
chunked_pmap(f, chunk_size=None, unroll=1)

A version of pmap which chunks the input into smaller pieces to avoid memory issues.

Parameters:
  • f (Callable[Ellipsis, FV]) – callable

  • chunk_size (Optional[int]) – the size of the chunks. Default is len(devices())

  • unroll (int) – the number of times to unroll the computation

Returns:

a chunked version of f

Return type:

Callable[Ellipsis, FV]
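A simplified sketch of the chunking idea, assuming a single array input mapped over its leading axis: the input is zero-padded to a multiple of chunk_size, each block of chunk_size rows is passed to jax.pmap, and the results are concatenated. chunked_pmap_sketch is hypothetical and ignores unroll (it uses a plain Python loop rather than a compiled loop); chunk_size must not exceed len(jax.devices()), because pmap maps each row onto its own device.

```python
import jax
import jax.numpy as jnp


def chunked_pmap_sketch(f, chunk_size=None):
    """Hypothetical sketch: pmap f over the leading axis, chunk_size rows at a time."""
    if chunk_size is None:
        chunk_size = len(jax.devices())  # one row per device per call
    pf = jax.pmap(f)

    def mapped(x):
        n = x.shape[0]
        pad = (-n) % chunk_size  # rows needed to reach a multiple of chunk_size
        x_pad = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        # Each pmap call maps exactly chunk_size rows, one per device used.
        outs = [pf(x_pad[i:i + chunk_size]) for i in range(0, x_pad.shape[0], chunk_size)]
        return jnp.concatenate(outs, axis=0)[:n]  # drop padded rows

    return mapped


x = jnp.arange(6.0)
doubled = chunked_pmap_sketch(lambda v: 2.0 * v)(x)
```

Padding with zeros means f is also evaluated on a few dummy rows; their outputs are discarded by the final slice.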

prepad(a, chunksize)
Parameters:

chunksize (int) –

T
remove_chunk_dim(py_tree)

Remove the chunk dimension from a pytree

Parameters:

py_tree (T) – pytree to remove chunk dimension from

Returns:

pytree with chunk dimension removed

Return type:

T
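A sketch assuming every leaf carries the chunk dimension as its leading axis, i.e. has shape (chunk_size, m, ...), so removing it merges the first two axes with a row-major reshape. Both this layout and the name remove_chunk_dim_sketch are assumptions, not taken from the jaxns source.

```python
import jax
import jax.numpy as jnp


def remove_chunk_dim_sketch(py_tree):
    """Hypothetical sketch: merge the leading chunk axis into the next axis of every leaf."""

    def _merge(a):
        # (chunk_size, m, ...) -> (chunk_size * m, ...)
        return jnp.reshape(a, (a.shape[0] * a.shape[1],) + a.shape[2:])

    return jax.tree_util.tree_map(_merge, py_tree)


tree = {"x": jnp.arange(12.0).reshape(2, 3, 2), "y": jnp.ones((2, 3))}
flat = remove_chunk_dim_sketch(tree)
print(flat["x"].shape, flat["y"].shape)  # (6, 2) (6,)
```

Because jax.tree_util.tree_map visits every leaf, the same reshape applies uniformly to dicts, tuples, or any other registered pytree container.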

add_chunk_dim(py_tree, chunk_size)

Add a chunk dimension to a pytree

Parameters:
  • py_tree (T) – pytree to add chunk dimension to

  • chunk_size (int) – size of chunk dimension

Returns:

pytree with chunk dimension added

Return type:

T
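Under the same assumed layout (chunk dimension as the leading axis), adding the dimension reshapes each leaf of shape (n, ...) to (chunk_size, n // chunk_size, ...). add_chunk_dim_sketch is a hypothetical name, and rejecting a leading axis that is not divisible by chunk_size is a choice made for this sketch.

```python
import jax
import jax.numpy as jnp


def add_chunk_dim_sketch(py_tree, chunk_size):
    """Hypothetical sketch: split the leading axis of every leaf into (chunk_size, n // chunk_size)."""

    def _split(a):
        if a.shape[0] % chunk_size != 0:
            raise ValueError("leading axis must be divisible by chunk_size")
        # (n, ...) -> (chunk_size, n // chunk_size, ...)
        return jnp.reshape(a, (chunk_size, a.shape[0] // chunk_size) + a.shape[1:])

    return jax.tree_util.tree_map(_split, py_tree)


tree = {"x": jnp.arange(12.0).reshape(6, 2)}
chunked = add_chunk_dim_sketch(tree, chunk_size=2)
print(chunked["x"].shape)  # (2, 3, 2)
```

Merging the first two axes again recovers the original leaf exactly, so under this layout the two operations are inverses.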

chunked_vmap(f, chunk_size=None, unroll=1)

A version of vmap which chunks the input into smaller pieces to avoid memory issues.
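A sketch of the chunked-vmap idea for a single array argument: zero-pad the leading axis to a multiple of chunk_size, run jax.vmap over one chunk at a time inside jax.lax.scan, then flatten and trim the padding. chunked_vmap_sketch is hypothetical; forwarding unroll to lax.scan is an assumption about what that parameter controls, and chunk_size is required here rather than defaulting to len(devices()). Since scan processes chunks sequentially, intermediates are roughly only materialised for one chunk at a time.

```python
import jax
import jax.numpy as jnp
from jax import lax


def chunked_vmap_sketch(f, chunk_size, unroll=1):
    """Hypothetical sketch: vmap f over the leading axis, chunk_size rows at a time."""

    def mapped(x):
        n = x.shape[0]
        pad = (-n) % chunk_size  # rows needed to reach a multiple of chunk_size
        x_pad = jnp.pad(x, [(0, pad)] + [(0, 0)] * (x.ndim - 1))
        chunks = x_pad.reshape((-1, chunk_size) + x.shape[1:])  # (num_chunks, chunk_size, ...)

        def body(carry, chunk):
            # vmap only over the rows of the current chunk.
            return carry, jax.vmap(f)(chunk)

        _, out = lax.scan(body, None, chunks, unroll=unroll)
        out = out.reshape((-1,) + out.shape[2:])  # merge (num_chunks, chunk_size)
        return out[:n]  # drop padded rows

    return mapped


def sq_norm(v):
    return jnp.sum(v ** 2)


x = jnp.arange(10.0).reshape(5, 2)
out = chunked_vmap_sketch(sq_norm, chunk_size=2)(x)
```

The result matches a plain jax.vmap(sq_norm)(x); only the peak size of the batched computation changes.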

Parameters:
  • f – the function to be mapped

  • chunk_size (Optional[int]) – the size of the chunks. Default is len(devices())

  • unroll (int) – the number of times to unroll the computation

Returns: