maps
jaxns.internals.maps
Module Contents
- replace_index(operand, update, start_index)
Replaces an index or slice of operand with update. If update is too large to fit at start_index, start_index is shifted so that the update fits, which can give non-intuitive results.
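The shifting described above is how jax.lax.dynamic_update_slice behaves: it clamps start indices so the update always fits inside the operand. Below is a minimal sketch, assuming the function writes along the leading axis with dynamic_update_slice. The name replace_index_sketch is hypothetical, and the library's actual code may differ.

```python
import jax.numpy as jnp
from jax import lax


def replace_index_sketch(operand, update, start_index):
    """Hypothetical re-implementation: write `update` into `operand` along axis 0."""
    update = jnp.asarray(update)
    # A single element (one fewer dimension than operand) becomes a length-1 slice
    if update.ndim == operand.ndim - 1:
        update = update[None]
    # Only the leading axis is offset; all other axes start at 0
    start_indices = (start_index,) + (0,) * (operand.ndim - 1)
    # dynamic_update_slice clamps start_indices so that the update fits inside operand
    return lax.dynamic_update_slice(operand, update.astype(operand.dtype), start_indices)


x = jnp.zeros(5)
print(replace_index_sketch(x, 7.0, 2))                     # [0. 0. 7. 0. 0.]
print(replace_index_sketch(x, jnp.array([1.0, 2.0]), 1))   # [0. 1. 2. 0. 0.]
# start_index=4 cannot hold a length-2 update, so it is shifted to 3
print(replace_index_sketch(x, jnp.array([1.0, 2.0]), 4))   # [0. 0. 0. 1. 2.]
```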
- prepare_func_args(f)[source]
Takes a callable(a, b, …, z=Z) and prepares it into a callable(**kwargs), such that only a, b, …, z are taken from **kwargs and the rest are ignored. This allows f(**kwargs) to work even if f() does not accept some of the keys in kwargs.
- Parameters:
f – callable(a, b, ..., z=Z)
- Returns:
callable(**kwargs), where **kwargs is filtered down to the arguments of the original function.
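The filtering described above can be sketched with the standard inspect module. This is an illustration under that assumption only: prepare_func_args_sketch and model are hypothetical names, and the library's own implementation may introspect the function differently.

```python
import inspect
from typing import Any, Callable


def prepare_func_args_sketch(f: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical: wrap f so it can be called with a superset of its arguments."""
    params = inspect.signature(f).parameters
    # Names that f can receive as keyword arguments
    accepted = {
        name
        for name, p in params.items()
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }

    def wrapped(**kwargs: Any) -> Any:
        # Drop keys f does not accept; omitted keys fall back to f's defaults
        return f(**{k: v for k, v in kwargs.items() if k in accepted})

    return wrapped


def model(a, b, z=10):
    return a + b + z


g = prepare_func_args_sketch(model)
print(g(a=1, b=2, c=100))     # 13 (c is ignored, z keeps its default)
print(g(a=1, b=2, z=0, d=5))  # 3
```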
- chunked_pmap(f, chunk_size=None, unroll=1)
A version of pmap that splits the input into smaller chunks to avoid running out of memory.
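The idea of chunking a pmap can be sketched as follows: split the leading axis across devices, then have each device loop over its own chunk with jax.lax.map (a sequential scan) rather than holding every element's work at once. This is an illustration, not the library's code; chunked_pmap_sketch and its num_devices parameter are hypothetical, and padding by repeating the last row is an assumption made so the split is always even.

```python
import jax
import jax.numpy as jnp


def chunked_pmap_sketch(f, num_devices=None):
    """Hypothetical: map f over the leading axis of x, spreading chunks over devices."""
    if num_devices is None:
        num_devices = jax.local_device_count()

    def mapped(x):
        n = x.shape[0]
        pad = (-n) % num_devices
        if pad:
            # Repeat the last row so the leading axis divides evenly across devices
            x = jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0)
        # Shape (num_devices, chunk, ...): one chunk per device
        chunks = x.reshape((num_devices, -1) + x.shape[1:])
        # Each device applies f to its chunk one element at a time (lax.map is a scan)
        out = jax.pmap(lambda chunk: jax.lax.map(f, chunk))(chunks)
        # Merge the device and chunk axes, then drop the padding
        return out.reshape((-1,) + out.shape[2:])[:n]

    return mapped


x = jnp.arange(5.0)
y = chunked_pmap_sketch(lambda v: v ** 2)(x)  # y holds 0, 1, 4, 9, 16
```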
- remove_chunk_dim(py_tree)
Remove the chunk dimension from a pytree.
- Parameters:
py_tree (T) – pytree to remove chunk dimension from
- Returns:
pytree with chunk dimension removed
- Return type:
T
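A minimal sketch of removing a chunk dimension, assuming every leaf is laid out as (num_chunks, per_chunk, ...) so that removal merges the first two axes. The layout and the name remove_chunk_dim_sketch are assumptions for illustration, not taken from the library.

```python
import jax
import jax.numpy as jnp


def remove_chunk_dim_sketch(py_tree):
    """Hypothetical: merge the leading (chunk) axis of every leaf into the next axis."""
    def _merge(leaf):
        # (num_chunks, per_chunk, ...) -> (num_chunks * per_chunk, ...)
        return jnp.reshape(leaf, (leaf.shape[0] * leaf.shape[1],) + leaf.shape[2:])

    return jax.tree_util.tree_map(_merge, py_tree)


tree = {"x": jnp.zeros((2, 3)), "y": jnp.zeros((2, 3, 4))}
merged = remove_chunk_dim_sketch(tree)
print(jax.tree_util.tree_map(jnp.shape, merged))  # {'x': (6,), 'y': (6, 4)}
```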
- add_chunk_dim(py_tree, chunk_size)
Add a chunk dimension to a pytree.
- Parameters:
py_tree (T) – pytree to add chunk dimension to
chunk_size (int) – size of chunk dimension
- Returns:
pytree with chunk dimension added
- Return type:
T
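A minimal sketch of adding a chunk dimension, assuming the new leading axis has length chunk_size and the original leading axis must be divisible by it. Both the layout and the name add_chunk_dim_sketch are assumptions; the library may order the axes differently.

```python
import jax
import jax.numpy as jnp


def add_chunk_dim_sketch(py_tree, chunk_size):
    """Hypothetical: split the leading axis of every leaf into (chunk_size, -1)."""
    def _split(leaf):
        if leaf.shape[0] % chunk_size != 0:
            raise ValueError(f"leading axis {leaf.shape[0]} not divisible by chunk_size={chunk_size}")
        # (n, ...) -> (chunk_size, n // chunk_size, ...)
        return jnp.reshape(leaf, (chunk_size, leaf.shape[0] // chunk_size) + leaf.shape[1:])

    return jax.tree_util.tree_map(_split, py_tree)


tree = {"x": jnp.arange(6.0), "y": jnp.zeros((6, 4))}
chunked = add_chunk_dim_sketch(tree, chunk_size=2)
print(jax.tree_util.tree_map(jnp.shape, chunked))  # {'x': (2, 3), 'y': (2, 3, 4)}
```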
- chunked_vmap(f, chunk_size=None, unroll=1)
A version of vmap that splits the input into smaller chunks to avoid running out of memory.
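The chunking idea for vmap can be sketched as a sequential loop over chunks (jax.lax.map) with jax.vmap applied inside each chunk, so only one chunk is vectorized at a time. This is an illustration under assumptions: chunked_vmap_sketch is hypothetical, it handles a single array argument, and it pads by repeating the last row.

```python
import jax
import jax.numpy as jnp


def chunked_vmap_sketch(f, chunk_size):
    """Hypothetical: vmap f over the leading axis, chunk_size elements at a time."""
    def mapped(x):
        n = x.shape[0]
        pad = (-n) % chunk_size
        if pad:
            # Repeat the last row so the leading axis is a multiple of chunk_size
            x = jnp.concatenate([x, jnp.repeat(x[-1:], pad, axis=0)], axis=0)
        chunks = x.reshape((-1, chunk_size) + x.shape[1:])
        # lax.map loops over chunks sequentially; vmap vectorizes within each chunk
        out = jax.lax.map(jax.vmap(f), chunks)
        # Merge the chunk axes back together and drop the padding
        return out.reshape((-1,) + out.shape[2:])[:n]

    return mapped


def f(v):
    return jnp.sin(v) * v


x = jnp.arange(10.0)
y = chunked_vmap_sketch(f, chunk_size=4)(x)
```

Recent JAX releases also expose this pattern directly through the batch_size argument of jax.lax.map.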
- pytree_unravel(example_tree)
Returns functions to ravel and unravel a pytree.
- Parameters:
example_tree (PT) – an example pytree whose structure the returned functions ravel and unravel
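- Returns:
ravel_fun – a function to ravel a pytree (structured like example_tree) into a flat array.
unravel_fun – a function to unravel a flat array back into such a pytree.
- Return type:
tuple (ravel_fun, unravel_fun)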
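A minimal sketch of how such a ravel/unravel pair can be built from an example pytree: record each leaf's shape, concatenate flattened leaves to ravel, and slice the flat vector back into those shapes to unravel. The name pytree_unravel_sketch is hypothetical, and unlike a full implementation it does not restore per-leaf dtypes.

```python
import jax
import jax.numpy as jnp
import numpy as np


def pytree_unravel_sketch(example_tree):
    """Hypothetical: build ravel/unravel functions from an example pytree."""
    leaves, treedef = jax.tree_util.tree_flatten(example_tree)
    shapes = [jnp.shape(leaf) for leaf in leaves]
    sizes = [int(np.prod(shape)) for shape in shapes]

    def ravel_fun(tree):
        # Flatten each leaf to 1-D and concatenate them in tree_flatten order
        return jnp.concatenate([jnp.ravel(leaf) for leaf in jax.tree_util.tree_leaves(tree)])

    def unravel_fun(flat):
        # Slice the flat vector back into leaves with the example's shapes
        out, start = [], 0
        for shape, size in zip(shapes, sizes):
            out.append(jnp.reshape(flat[start:start + size], shape))
            start += size
        return jax.tree_util.tree_unflatten(treedef, out)

    return ravel_fun, unravel_fun


example = {"w": jnp.ones((2, 3)), "b": jnp.array(0.5)}
ravel_fun, unravel_fun = pytree_unravel_sketch(example)
flat = ravel_fun(example)     # shape (7,)
restored = unravel_fun(flat)  # same structure and shapes as `example`
```

For a single tree, JAX's built-in jax.flatten_util.ravel_pytree returns the flat array together with an unravel function.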
- create_mesh(shape, axis_names, devices=None)
Create a mesh from a shape and axis names.
- Parameters:
shape – the shape of the mesh; its total size must evenly divide the number of devices.
axis_names – the axis names of the mesh.
devices – the devices to use; if None, all devices are used.
- Returns:
the mesh
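A minimal sketch of building such a mesh with jax.sharding.Mesh, assuming the first prod(shape) devices are laid out row-major. The name create_mesh_sketch is hypothetical; for topology-aware layouts, jax.experimental.mesh_utils.create_device_mesh is JAX's own helper.

```python
import jax
import numpy as np
from jax.sharding import Mesh


def create_mesh_sketch(shape, axis_names, devices=None):
    """Hypothetical: arrange devices into a named mesh of the given shape."""
    if devices is None:
        devices = jax.devices()
    size = int(np.prod(shape))
    if len(devices) % size != 0:
        raise ValueError(f"mesh size {size} does not evenly divide {len(devices)} devices")
    # Use the first `size` devices, laid out row-major in the requested shape
    device_grid = np.array(devices[:size], dtype=object).reshape(shape)
    return Mesh(device_grid, axis_names)


mesh = create_mesh_sketch((jax.device_count(),), ("data",))
print(mesh.axis_names, mesh.devices.shape)
```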
- tree_device_put(tree, mesh, axis_names)
Put a pytree on the devices of a mesh.
- Parameters:
tree (SPT) – the pytree to put on the mesh.
mesh (jax.sharding.Mesh) – the mesh to put the pytree on.
axis_names (Tuple[Union[str, None], ...]) – the axis names of the mesh.
- Returns:
the pytree placed on the mesh devices.
- Return type:
SPT
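A minimal sketch of placing a pytree on a mesh, assuming axis_names is read like a jax.sharding.PartitionSpec (entry i names the mesh axis that array dimension i is split over, None meaning not split). The name tree_device_put_sketch is hypothetical; jax.device_put broadcasts a single sharding to every leaf.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def tree_device_put_sketch(tree, mesh, axis_names):
    """Hypothetical: place every leaf of `tree` on `mesh`, sharded per axis_names."""
    # Entry i of axis_names names the mesh axis that array dimension i is split over
    sharding = NamedSharding(mesh, PartitionSpec(*axis_names))
    # device_put applies a single sharding to every leaf of the pytree
    return jax.device_put(tree, sharding)


mesh = Mesh(np.array(jax.devices(), dtype=object), ("data",))
n = 4 * jax.device_count()  # leading axis divisible by the number of devices
tree = {"x": jnp.arange(float(n)), "y": jnp.ones((n, 3))}
placed = tree_device_put_sketch(tree, mesh, ("data",))
```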