maps

jaxns.internals.maps

Module Contents

replace_index(operand, update, start_index)[source]

Replaces the element or slice of operand starting at start_index with update. If update is too large to fit at start_index, then start_index is shifted, which can give non-intuitive results.
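The shifting behaviour can be modelled in plain Python. This is a minimal sketch, not the jaxns implementation: it assumes the shift is a clamp of the start index so that the update fits, which is how JAX's jax.lax.dynamic_update_slice treats out-of-range starts. The name replace_slice_clamped is hypothetical.

```python
def replace_slice_clamped(operand, update, start_index):
    """Pure-Python model: write `update` into a copy of `operand`, clamping the start."""
    if len(update) > len(operand):
        raise ValueError("update must not be longer than operand")
    # Largest start position at which the whole update still fits.
    max_start = len(operand) - len(update)
    # Clamp the requested start into [0, max_start]; this is the assumed "shift".
    start = min(max(start_index, 0), max_start)
    result = list(operand)
    result[start:start + len(update)] = update
    return result


operand = [0, 0, 0, 0, 0]
in_range = replace_slice_clamped(operand, [1, 2], 1)  # fits: written at index 1
shifted = replace_slice_clamped(operand, [1, 2], 4)   # does not fit: written at index 3
print(in_range)
print(shifted)
```

In the second call the update lands one position earlier than requested, which is the kind of surprise the description warns about.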

get_index(operand, start_index, length)[source]
prepare_func_args(f)[source]

Takes a callable(a,b,…,z=Z) and prepares it into callable(**kwargs), such that only a,b,…,z are taken from **kwargs and the rest are ignored.

This allows f(**kwargs) to work even if kwargs contains keys that f does not accept.

Parameters:

f – callable(a,b,...,z=Z)

Returns:

callable(**kwargs), where **kwargs is filtered down to the arguments of the original function.
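One way to get this filtering behaviour is with the standard inspect module. The sketch below illustrates the idea only and may differ from what prepare_func_args does internally. The names filter_kwargs_for and log_likelihood are hypothetical.

```python
import inspect


def filter_kwargs_for(f):
    """Wrap `f` so it can be called with a superset of its keyword arguments."""
    accepted = set(inspect.signature(f).parameters)

    def wrapped(**kwargs):
        # Keep only the keys that `f` declares; silently drop the rest.
        return f(**{k: v for k, v in kwargs.items() if k in accepted})

    return wrapped


def log_likelihood(a, b, z=1.0):
    return a + b * z


g = filter_kwargs_for(log_likelihood)
value = g(a=1.0, b=2.0, unused=99.0)  # `unused` is dropped, z keeps its default
print(value)
```

A wrapper like this lets one dictionary of named values be passed to several functions that each need only some of its keys.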

F[source]
FV[source]
chunked_pmap(f, chunk_size=None, unroll=1)[source]

A version of pmap which chunks the input into smaller pieces to avoid memory issues.

Parameters:
  • f (Callable[Ellipsis, FV]) – callable

  • chunk_size (Optional[int]) – the size of the chunks. Default is len(devices())

  • unroll (int) – the number of times to unroll the computation

Returns:

a chunked version of f

Return type:

Callable[Ellipsis, FV]
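The chunking idea can be sketched without JAX: pad the inputs to a multiple of chunk_size, process one chunk at a time, then drop the padded results. This is a pure-Python model under stated assumptions, not the jaxns code. Padding at the front with copies of the first element is a guess suggested by the name prepad, and chunked_map is a hypothetical name.

```python
def chunked_map(f, chunk_size):
    """Pure-Python model of a chunked map over a list of inputs."""

    def mapped(xs):
        n = len(xs)
        if n == 0:
            return []
        # Number of padding items needed to reach a multiple of chunk_size.
        extra = (-n) % chunk_size
        # Assumed convention: pad at the front with copies of the first item.
        padded = [xs[0]] * extra + list(xs)
        results = []
        # Process one chunk at a time, so only chunk_size items are handled at once.
        for start in range(0, len(padded), chunk_size):
            chunk = padded[start:start + chunk_size]
            results.extend(f(x) for x in chunk)
        # Drop the results that came from padding.
        return results[extra:]

    return mapped


square_in_chunks = chunked_map(lambda x: x * x, chunk_size=4)
out = square_in_chunks(list(range(10)))
print(out)
```

Here ten inputs are padded to twelve and processed as three chunks of four. The caller sees only the ten real results.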

prepad(a, chunksize)[source]
Parameters:

chunksize (int)

T[source]
remove_chunk_dim(py_tree)[source]

Remove the chunk dimension from a pytree

Parameters:

py_tree (T) – pytree to remove chunk dimension from

Returns:

pytree with chunk dimension removed

Return type:

T

add_chunk_dim(py_tree, chunk_size)[source]

Add a chunk dimension to a pytree

Parameters:
  • py_tree (T) – pytree to add chunk dimension to

  • chunk_size (int) – size of chunk dimension

Returns:

pytree with chunk dimension added

Return type:

T
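The two helpers above behave like a reshape and its inverse applied to every leaf. The sketch below models that round trip with a flat dict of lists. It illustrates the idea only: the axis order of the real chunk dimension is an assumption, and the function names are hypothetical.

```python
def add_chunk_dim_model(leaf, chunk_size):
    """Split a flat list into consecutive chunks of length chunk_size."""
    if len(leaf) % chunk_size != 0:
        raise ValueError("length must be a multiple of chunk_size")
    return [leaf[i:i + chunk_size] for i in range(0, len(leaf), chunk_size)]


def remove_chunk_dim_model(chunked_leaf):
    """Concatenate the chunks back into one flat list."""
    return [x for chunk in chunked_leaf for x in chunk]


tree = {"x": [1, 2, 3, 4, 5, 6], "y": [10, 20, 30, 40, 50, 60]}
# Apply the same transformation to every leaf of the (flat) tree.
chunked = {k: add_chunk_dim_model(v, 3) for k, v in tree.items()}
restored = {k: remove_chunk_dim_model(v) for k, v in chunked.items()}
print(chunked["x"])
print(restored == tree)
```

The leaf length must be a multiple of chunk_size, so inputs are padded before the chunk dimension is added.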

chunked_vmap(f, chunk_size=None, unroll=1)[source]

A version of vmap which chunks the input into smaller pieces to avoid memory issues.

Parameters:
  • f – the function to be mapped

  • chunk_size (Optional[int]) – the size of the chunks. Default is len(devices())

  • unroll (int) – the number of times to unroll the computation

Returns:

a chunked version of f

PT[source]
pytree_unravel(example_tree)[source]

Returns functions to ravel and unravel a pytree.

Parameters:

example_tree (PT) – a pytree to be unravelled

Returns:

ravel_fun: a function to ravel the pytree
unravel_fun: a function to unravel
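A ravel/unravel pair can be illustrated with a flat dict of lists. This is a pure-Python sketch, not the jaxns implementation: the leaf ordering by sorted key is an assumption and make_ravel_unravel is a hypothetical name. For real JAX arrays, jax.flatten_util.ravel_pytree provides a comparable flattening plus an unravel function.

```python
def make_ravel_unravel(example_tree):
    """Build (ravel, unravel) for flat dicts whose values are lists of numbers."""
    keys = sorted(example_tree)
    sizes = [len(example_tree[k]) for k in keys]

    def ravel(tree):
        # Concatenate leaves in a fixed key order.
        return [x for k in keys for x in tree[k]]

    def unravel(flat):
        # Slice the flat list back into leaves of the recorded sizes.
        out, offset = {}, 0
        for k, size in zip(keys, sizes):
            out[k] = flat[offset:offset + size]
            offset += size
        return out

    return ravel, unravel


example = {"b": [3.0], "a": [1.0, 2.0]}
ravel, unravel = make_ravel_unravel(example)
flat = ravel(example)
print(flat)
print(unravel(flat) == example)
```

The example tree is needed only for its structure and leaf sizes. The returned functions then work on any tree with that layout.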

pytree_unpack(example_tree)[source]

Returns a pair of functions: one that unpacks a pytree into a list of arrays, and one that rebuilds the pytree from such a list.

Parameters:

example_tree (PT)

Return type:

Tuple[Callable[[PT], List[jax.Array]], Callable[[List[jax.Array]], PT]]

PV[source]
class PyTree(tree)[source]

For acting on W space.

Parameters:

tree (PV)

tree[source]
__add__(other)[source]
Parameters:

other (PV)

Return type:

PV

__sub__(other)[source]
Parameters:

other (PV)

Return type:

PV

__mul__(other)[source]
Parameters:

other (PV)

Return type:

PV

__truediv__(other)[source]
Parameters:

other (PV)

Return type:

PV

__pow__(other)[source]
Parameters:

other (PV)

Return type:

PV

__neg__()[source]
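The operator signatures above take a plain pytree (PV) and return a plain pytree. The sketch below shows one way such a wrapper could work. It assumes each operator is applied leaf by leaf, which the source does not state, and it uses a flat dict as the tree. The class name LeafwiseTree is hypothetical.

```python
import operator


class LeafwiseTree:
    """Hypothetical stand-in: arithmetic applied leaf by leaf to a flat dict."""

    def __init__(self, tree):
        self.tree = tree

    def _combine(self, other, op):
        # `other` is a plain dict with the same keys; the result is a plain dict too.
        return {k: op(v, other[k]) for k, v in self.tree.items()}

    def __add__(self, other):
        return self._combine(other, operator.add)

    def __sub__(self, other):
        return self._combine(other, operator.sub)

    def __mul__(self, other):
        return self._combine(other, operator.mul)

    def __neg__(self):
        return {k: -v for k, v in self.tree.items()}


w = LeafwiseTree({"mu": 1.0, "sigma": 2.0})
total = w + {"mu": 0.5, "sigma": 0.5}
print(total)
print(-w)
```

Wrapping only the left operand keeps results as ordinary pytrees, so they can be passed straight to other functions.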
create_mesh(shape, axis_names, devices=None)[source]

Create a mesh from a shape and axis names.

Parameters:
  • shape – the shape of the mesh; its total size must evenly divide the number of devices.

  • axis_names – the axis names of the mesh.

  • devices – the devices to use; if None, all devices are used.

Returns:

the mesh

SPT[source]
tree_device_put(tree, mesh, axis_names)[source]

Put a pytree on a device.

Parameters:
  • tree (SPT) – the pytree to put on a device.

  • mesh (jax._src.mesh.Mesh) – the mesh to put the pytree on.

  • axis_names (Tuple[Union[str, None], Ellipsis]) – the axis names of the mesh.

Returns:

the pytree on the device.

Return type:

SPT

BUX[source]
block_until_ready(x)[source]
Parameters:

x (BUX)

Return type:

BUX