maps
jaxns.internals.maps
Module Contents
- replace_index(operand, update, start_index)
Replaces an index or slice with an update. If update is too large to fit at start_index, then start_index is shifted, which can give non-intuitive results.
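The shifting behaviour can be illustrated with a minimal pure-Python sketch on lists. It assumes the shift works by clamping start_index so that the whole update fits inside operand, which is how jax.lax.dynamic_update_slice treats out-of-range starts. The helper name replace_index_sketch is hypothetical and this is not the jaxns implementation.

```python
def replace_index_sketch(operand, update, start_index):
    """Write `update` into a copy of `operand`, shifting the start if needed.

    Hypothetical 1-D list version; assumes len(update) <= len(operand).
    """
    # Clamp the start so the whole update fits inside the operand (assumption).
    max_start = len(operand) - len(update)
    start = max(0, min(start_index, max_start))
    result = list(operand)
    result[start:start + len(update)] = update
    return result


operand = [0, 0, 0, 0, 0]

# The update fits at index 1, so it is written where requested.
fits = replace_index_sketch(operand, [1, 2], 1)

# The update does not fit at index 4, so the start is shifted back to 3.
shifted = replace_index_sketch(operand, [1, 2], 4)
```

In the second call the caller asked for index 4 but the data lands at indices 3 and 4, which is the kind of non-intuitive result the description warns about.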
- prepare_func_args(f)
Takes a callable(a,b,…,z=Z) and prepares it into callable(**kwargs), such that only a,b,…,z are taken from **kwargs and the rest are ignored. This allows f(**kwargs) to work even if f() does not accept some of the keys in kwargs.
- Parameters:
f – callable(a,b,...,z=Z)
- Returns:
callable(**kwargs), where **kwargs are filtered down to the args of the original function.
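The filtering described above can be sketched with the standard library alone. This is a minimal illustration, assuming f declares only named parameters (no *args or **kwargs of its own); prepare_func_args_sketch and model are hypothetical names, and the real jaxns function may differ in detail.

```python
import inspect


def prepare_func_args_sketch(f):
    """Wrap `f` so that it accepts **kwargs and ignores unknown keys.

    Hypothetical helper; assumes `f` has only named parameters
    (no *args or **kwargs of its own).
    """
    # Names of the parameters that `f` actually declares.
    accepted = set(inspect.signature(f).parameters)

    def wrapped(**kwargs):
        # Keep only the keys that `f` accepts; defaults fill in the rest.
        filtered = {k: v for k, v in kwargs.items() if k in accepted}
        return f(**filtered)

    return wrapped


def model(a, b, z=10):
    return a + b + z


g = prepare_func_args_sketch(model)

# `unused` is silently dropped, and `z` falls back to its default.
result = g(a=1, b=2, unused=99)
```

Calling model(a=1, b=2, unused=99) directly would raise a TypeError, while the wrapped version discards the extra key before the call.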
- chunked_pmap(f, chunk_size=None, unroll=1)
A version of pmap which chunks the input into smaller pieces to avoid memory issues.
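The idea of chunking a mapped computation can be sketched in JAX as follows. This is only an illustration under stated assumptions: jax.vmap stands in for jax.pmap so the code runs on a single device, jax.lax.map provides the sequential pass over each chunk, the batch size is assumed to be divisible by chunk_size, and the unroll argument is not modelled. The name chunked_map_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def chunked_map_sketch(f, chunk_size):
    """Map `f` over the leading axis in `chunk_size` parallel lanes.

    Hypothetical helper: jax.vmap stands in for jax.pmap so the sketch
    runs on a single device. Assumes the batch size divides by chunk_size.
    """

    def wrapped(xs):
        n = xs.shape[0]
        # Split the batch axis into (chunk_size, n // chunk_size, ...).
        chunks = jnp.reshape(xs, (chunk_size, n // chunk_size) + xs.shape[1:])
        # Each lane walks through its own chunk one element at a time.
        per_lane = jax.vmap(lambda chunk: jax.lax.map(f, chunk))(chunks)
        # Merge the two leading axes back into one batch axis.
        return jnp.reshape(per_lane, (n,) + per_lane.shape[2:])

    return wrapped


xs = jnp.arange(8.0)
squares = chunked_map_sketch(lambda x: x ** 2, chunk_size=2)(xs)
```

Because each lane handles its elements sequentially, only chunk_size evaluations of f are in flight at once, instead of the whole batch.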
- remove_chunk_dim(py_tree)
Remove the chunk dimension from a pytree
- Parameters:
py_tree (T) – pytree to remove chunk dimension from
- Returns:
pytree with chunk dimension removed
- Return type:
T
- add_chunk_dim(py_tree, chunk_size)
Add a chunk dimension to a pytree
- Parameters:
py_tree (T) – pytree to add chunk dimension to
chunk_size (int) – size of chunk dimension
- Returns:
pytree with chunk dimension added
- Return type:
T
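A round trip through the two chunk-dimension helpers might look like the sketch below. It assumes the chunk dimension is a new leading axis of size chunk_size, so a leaf of shape (N, ...) becomes (chunk_size, N // chunk_size, ...), and that N is divisible by chunk_size. The source does not state the exact layout, so treat this as an assumption; the *_sketch names are hypothetical stand-ins, not the jaxns functions.

```python
import jax
import jax.numpy as jnp


def add_chunk_dim_sketch(py_tree, chunk_size):
    """Reshape every leaf from (N, ...) to (chunk_size, N // chunk_size, ...)."""

    def _add(leaf):
        n = leaf.shape[0]
        return jnp.reshape(leaf, (chunk_size, n // chunk_size) + leaf.shape[1:])

    return jax.tree_util.tree_map(_add, py_tree)


def remove_chunk_dim_sketch(py_tree):
    """Merge the two leading axes of every leaf back into one."""

    def _remove(leaf):
        merged = leaf.shape[0] * leaf.shape[1]
        return jnp.reshape(leaf, (merged,) + leaf.shape[2:])

    return jax.tree_util.tree_map(_remove, py_tree)


# A small pytree whose leaves share a leading batch axis of length 6.
tree = {"x": jnp.arange(6.0), "y": jnp.ones((6, 3))}
chunked = add_chunk_dim_sketch(tree, chunk_size=2)
restored = remove_chunk_dim_sketch(chunked)
```

Under these assumptions the two helpers are inverses of each other: restored has the same leaf shapes and values as tree.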