maps
==============

.. py:module:: jaxns.internals.maps


.. rubric:: :code:`jaxns.internals.maps`


.. rubric:: Module Contents

.. py:function:: replace_index(operand, update, start_index)

   Replaces an index or slice of ``operand`` with ``update``, starting at ``start_index``.

   If ``update`` is too large to fit in ``operand`` when placed at ``start_index``, then ``start_index`` is shifted so that it fits. This can give non-intuitive results.


.. py:function:: get_index(operand, start_index, length)


.. py:function:: prepare_func_args(f)

   Takes a ``callable(a,b,...,z=Z)`` and prepares it into ``callable(**kwargs)``, such that only ``a,b,...,z`` are taken from ``**kwargs`` and the rest are ignored.

   This allows ``f(**kwargs)`` to work even if ``kwargs`` contains keys that ``f`` does not accept.

   :param f: ``callable(a,b,...,z=Z)``

   :returns: ``callable(**kwargs)``, where ``**kwargs`` are filtered down to the arguments of the original function.


.. py:data:: F

.. py:data:: FV

.. py:function:: chunked_pmap(f, chunk_size = None, unroll = 1)

   A version of ``pmap`` which chunks the input into smaller pieces to avoid memory issues.

   :param f: the function to be mapped
   :param chunk_size: the size of the chunks. Defaults to ``len(devices())``.
   :param unroll: the number of times to unroll the computation

   :returns: a chunked version of ``f``


.. py:function:: prepad(a, chunksize)


.. py:data:: T

.. py:function:: remove_chunk_dim(py_tree)

   Remove the chunk dimension from a pytree.

   :param py_tree: pytree to remove the chunk dimension from

   :returns: pytree with the chunk dimension removed


.. py:function:: add_chunk_dim(py_tree, chunk_size)

   Add a chunk dimension to a pytree.

   :param py_tree: pytree to add the chunk dimension to
   :param chunk_size: size of the chunk dimension

   :returns: pytree with the chunk dimension added


.. py:function:: chunked_vmap(f, chunk_size = None, unroll = 1)

   A version of ``vmap`` which chunks the input into smaller pieces to avoid memory issues.

   :param f: the function to be mapped
   :param chunk_size: the size of the chunks. Defaults to ``len(devices())``.
   :param unroll: the number of times to unroll the computation

   :returns: a chunked version of ``f``


.. py:data:: PT

.. py:function:: pytree_unravel(example_tree)

   Returns functions to ravel and unravel a pytree.

   :param example_tree: an example pytree whose structure defines how to ravel and unravel

   :returns: a tuple ``(ravel_fun, unravel_fun)``, where ``ravel_fun`` ravels a pytree and ``unravel_fun`` unravels it back.


.. py:function:: pytree_unpack(example_tree)

   Returns functions to ravel and unravel a pytree.


.. py:data:: PV

.. py:class:: PyTree(tree)

   For acting on W space.


   .. py:attribute:: tree


   .. py:method:: __add__(other)


   .. py:method:: __sub__(other)


   .. py:method:: __mul__(other)


   .. py:method:: __truediv__(other)


   .. py:method:: __pow__(other)


   .. py:method:: __neg__()


.. py:function:: create_mesh(shape, axis_names, devices=None)

   Create a mesh from a shape and axis names.

   :param shape: the shape of the mesh; its total size must evenly divide the number of devices.
   :param axis_names: the axis names of the mesh.
   :param devices: the devices to use; if ``None``, all devices are used.

   :returns: the mesh


.. py:data:: SPT

.. py:function:: tree_device_put(tree, mesh, axis_names)

   Put a pytree on the devices of a mesh.

   :param tree: the pytree to put on the devices.
   :param mesh: the mesh to put the pytree on.
   :param axis_names: the axis names of the mesh.

   :returns: the pytree on the devices.


.. py:data:: BUX

.. py:function:: block_until_ready(x)
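
.. rubric:: Example sketch: ``prepare_func_args``

The keyword filtering described for ``prepare_func_args`` can be sketched with the standard-library ``inspect`` module. This is a hypothetical re-implementation (``prepare_func_args_sketch`` and ``model`` are illustrative names, not part of jaxns), assuming the wrapper keeps exactly the keywords that name a parameter of ``f`` and passes everything through unchanged when ``f`` already accepts ``**kwargs``.

.. code-block:: python

   import inspect
   from typing import Any, Callable


   def prepare_func_args_sketch(f: Callable[..., Any]) -> Callable[..., Any]:
       """Hypothetical sketch: wrap f so that it accepts arbitrary keyword arguments."""
       params = inspect.signature(f).parameters
       # Names that can legally be passed to f by keyword.
       keyword_names = {
           name
           for name, p in params.items()
           if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
       }
       takes_var_kwargs = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())

       def wrapped(**kwargs: Any) -> Any:
           if takes_var_kwargs:
               # f already accepts **kwargs, so nothing needs to be dropped.
               return f(**kwargs)
           # Keep only the keys f understands; parameters absent from kwargs use their defaults.
           return f(**{k: v for k, v in kwargs.items() if k in keyword_names})

       return wrapped


   def model(x, y, scale=2.0):
       return scale * (x + y)


   g = prepare_func_args_sketch(model)
   print(g(x=1.0, y=2.0, unused="ignored"))  # 6.0
   print(g(x=1.0, y=2.0, scale=1.0, z=3))  # 3.0

Dropping unknown keys instead of raising ``TypeError`` is what lets a single shared dictionary of values be passed to several functions that each need only a subset of it.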
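
.. rubric:: Example sketch: index shifting in ``replace_index`` and ``get_index``

The "shifted" start index mentioned for ``replace_index`` matches how ``jax.lax.dynamic_update_slice`` and ``jax.lax.dynamic_slice`` clamp start indices so that the slice stays in bounds. The snippet below demonstrates that clamping directly on the primitives. It assumes, without confirming it from the jaxns source, that ``replace_index`` and ``get_index`` are thin wrappers around these two primitives acting on the leading axis.

.. code-block:: python

   import jax.numpy as jnp
   from jax import lax

   operand = jnp.zeros(5, dtype=jnp.int32)
   update = jnp.array([1, 2, 3], dtype=jnp.int32)

   # Start index 1: the update fits, so positions 1..3 are overwritten.
   print(lax.dynamic_update_slice(operand, update, (1,)))  # [0 1 2 3 0]

   # Start index 4: three elements cannot start there in a length-5 array,
   # so the start is clamped to 5 - 3 = 2 and positions 2..4 are written.
   print(lax.dynamic_update_slice(operand, update, (4,)))  # [0 0 1 2 3]

   # dynamic_slice clamps the same way when reading a fixed-length window.
   print(lax.dynamic_slice(jnp.arange(5), (4,), (2,)))  # [3 4]

The clamping is why a too-large update silently lands earlier than requested instead of raising an error: the slice size must be static under ``jit``, so only the start position can move.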
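
.. rubric:: Example sketch: ``add_chunk_dim`` and ``remove_chunk_dim``

A chunk dimension can be added to, and removed from, every leaf of a pytree with ``jax.tree_util.tree_map`` and a reshape of the leading axis. The helper names below are hypothetical, and placing the chunk axis first with size ``chunk_size`` (so each leaf goes from ``(N, ...)`` to ``(chunk_size, N // chunk_size, ...)``) is an assumption about the layout, not a statement of what jaxns does.

.. code-block:: python

   import jax
   import jax.numpy as jnp


   def add_chunk_dim_sketch(py_tree, chunk_size):
       """Reshape every leaf from (N, ...) to (chunk_size, N // chunk_size, ...).

       Assumes N is divisible by chunk_size (pad the leading axis first otherwise).
       """
       def _add(a):
           return jnp.reshape(a, (chunk_size, a.shape[0] // chunk_size) + a.shape[1:])
       return jax.tree_util.tree_map(_add, py_tree)


   def remove_chunk_dim_sketch(py_tree):
       """Merge the two leading axes of every leaf back into one."""
       def _remove(a):
           return jnp.reshape(a, (a.shape[0] * a.shape[1],) + a.shape[2:])
       return jax.tree_util.tree_map(_remove, py_tree)


   tree = {"x": jnp.arange(6.0), "y": jnp.ones((6, 2))}
   chunked = add_chunk_dim_sketch(tree, chunk_size=3)
   print(chunked["x"].shape, chunked["y"].shape)  # (3, 2) (3, 2, 2)
   restored = remove_chunk_dim_sketch(chunked)
   print(restored["x"].shape, restored["y"].shape)  # (6,) (6, 2)

Because both reshapes preserve row-major order, ``remove_chunk_dim_sketch`` exactly inverts ``add_chunk_dim_sketch``.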
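
.. rubric:: Example sketch: chunked mapping as in ``chunked_vmap``

One way to realise the memory-saving idea behind ``chunked_vmap`` is to pad the leading axis to a multiple of the chunk size, split it into chunks, ``vmap`` within each chunk and step through the chunks sequentially with ``jax.lax.scan`` (forwarding ``unroll``). This is a hypothetical sketch: ``chunked_vmap_sketch`` and ``sq_norm`` are illustrative names, here ``chunk_size`` means the number of rows evaluated at once, and whether jaxns pads and orders chunks this way (for example via ``prepad``) is an assumption.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from jax import lax


   def chunked_vmap_sketch(f, chunk_size, unroll=1):
       """Hypothetical sketch: map f over the leading axis, chunk_size rows at a time."""

       def mapped(*args):
           n = jax.tree_util.tree_leaves(args)[0].shape[0]
           num_chunks = -(-n // chunk_size)  # ceiling division
           pad = num_chunks * chunk_size - n

           def to_chunks(a):
               # Pad by repeating the last row so every chunk is full, then split
               # the leading axis into (num_chunks, chunk_size).
               a = jnp.concatenate([a, jnp.repeat(a[-1:], pad, axis=0)], axis=0)
               return a.reshape((num_chunks, chunk_size) + a.shape[1:])

           def body(carry, chunk):
               # Vectorise f inside one chunk only; chunks are processed one after another.
               return carry, jax.vmap(f)(*chunk)

           _, out = lax.scan(body, None, jax.tree_util.tree_map(to_chunks, args), unroll=unroll)

           def from_chunks(b):
               # Merge the chunk axes again and drop the padded rows.
               return b.reshape((num_chunks * chunk_size,) + b.shape[2:])[:n]

           return jax.tree_util.tree_map(from_chunks, out)

       return mapped


   def sq_norm(x):
       return jnp.sum(x ** 2)


   xs = jnp.arange(10.0).reshape(5, 2)
   out = chunked_vmap_sketch(sq_norm, chunk_size=2)(xs)
   print(bool(jnp.allclose(out, jax.vmap(sq_norm)(xs))))  # True

Peak memory now scales with ``chunk_size`` rather than with the full batch. Padding with a copy of the last row, rather than with zeros, keeps the extra evaluations on realistic inputs; their results are discarded.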
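
.. rubric:: Example sketch: ``pytree_unravel``

The ``(ravel_fun, unravel_fun)`` pair returned by ``pytree_unravel`` can be illustrated with ``jax.flatten_util.ravel_pytree``, which flattens a pytree into a single 1-D array and returns the inverse function. ``pytree_unravel_sketch`` is a hypothetical name, and whether jaxns builds on ``ravel_pytree`` internally is an assumption.

.. code-block:: python

   import jax.numpy as jnp
   from jax.flatten_util import ravel_pytree


   def pytree_unravel_sketch(example_tree):
       """Hypothetical sketch returning (ravel_fun, unravel_fun) for trees shaped like example_tree."""
       _, unravel_fun = ravel_pytree(example_tree)

       def ravel_fun(tree):
           # Concatenate all leaves, each flattened, into one 1-D array.
           return ravel_pytree(tree)[0]

       return ravel_fun, unravel_fun


   example_tree = {"a": jnp.zeros(2), "b": jnp.ones((2, 2))}
   ravel_fun, unravel_fun = pytree_unravel_sketch(example_tree)
   flat = ravel_fun(example_tree)
   print(flat.shape)  # (6,)
   restored = unravel_fun(flat)
   print(restored["b"].shape)  # (2, 2)

A flat vector is convenient for operations that expect a single array, such as sampling or linear algebra over all parameters at once, while ``unravel_fun`` restores the named structure afterwards.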
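
.. rubric:: Example sketch: arithmetic on ``PyTree``

``PyTree`` wraps a pytree and defines ``__add__``, ``__sub__``, ``__mul__``, ``__truediv__``, ``__pow__`` and ``__neg__``. A minimal stand-in with the same operator set is sketched below. ``PyTreeSketch`` is hypothetical, and its semantics are assumptions: two wrapped trees are combined leaf by leaf, and any other operand (such as a scalar) is broadcast against every leaf.

.. code-block:: python

   import operator

   import jax
   import jax.numpy as jnp


   class PyTreeSketch:
       """Hypothetical stand-in for PyTree: leaf-wise arithmetic on a wrapped pytree."""

       def __init__(self, tree):
           self.tree = tree

       def _binary(self, other, op):
           if isinstance(other, PyTreeSketch):
               # Both operands must share the same tree structure.
               return PyTreeSketch(jax.tree_util.tree_map(op, self.tree, other.tree))
           # Otherwise broadcast `other` (e.g. a scalar) against every leaf.
           return PyTreeSketch(jax.tree_util.tree_map(lambda x: op(x, other), self.tree))

       def __add__(self, other):
           return self._binary(other, operator.add)

       def __sub__(self, other):
           return self._binary(other, operator.sub)

       def __mul__(self, other):
           return self._binary(other, operator.mul)

       def __truediv__(self, other):
           return self._binary(other, operator.truediv)

       def __pow__(self, other):
           return self._binary(other, operator.pow)

       def __neg__(self):
           return PyTreeSketch(jax.tree_util.tree_map(operator.neg, self.tree))


   p = PyTreeSketch({"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)})
   q = p ** 2 - (p + p) / 2.0  # w -> [0, 2], b -> 6
   neg = -p                    # w -> [-1, -2], b -> -3

Like the method list above, this sketch defines no reflected operators such as ``__rmul__``, so a scalar has to appear on the right: ``p * 0.5`` works, while ``0.5 * p`` raises ``TypeError``.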
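
.. rubric:: Example sketch: ``create_mesh`` and ``tree_device_put``

The two helpers can be approximated with ``jax.sharding.Mesh``, ``NamedSharding`` and ``PartitionSpec``. In this hypothetical sketch, ``create_mesh_sketch`` uses the first ``prod(shape)`` devices after checking that this size divides the device count, and ``tree_device_put_sketch`` shards the leading axis of every leaf over the named mesh axes. Both of those choices are assumptions, not the confirmed jaxns behaviour.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   import numpy as np
   from jax.sharding import Mesh, NamedSharding, PartitionSpec


   def create_mesh_sketch(shape, axis_names, devices=None):
       """Hypothetical sketch: build a Mesh over the first prod(shape) devices."""
       if devices is None:
           devices = jax.devices()
       size = int(np.prod(shape))
       if len(devices) % size != 0:
           raise ValueError(f"mesh size {size} does not divide {len(devices)} devices")
       return Mesh(np.asarray(devices[:size]).reshape(shape), axis_names)


   def tree_device_put_sketch(tree, mesh, axis_names):
       """Hypothetical sketch: shard the leading axis of every leaf over the named mesh axes."""
       sharding = NamedSharding(mesh, PartitionSpec(axis_names))
       return jax.tree_util.tree_map(lambda x: jax.device_put(x, sharding), tree)


   n_dev = len(jax.devices())
   mesh = create_mesh_sketch((n_dev,), ("data",))
   placed = tree_device_put_sketch({"x": jnp.arange(4.0 * n_dev)}, mesh, ("data",))
   print(placed["x"].sharding)

On a single CPU the mesh has one device and the "sharding" is trivial, but the same calls split the leading axis evenly across devices when several are available, provided that axis is divisible by the mesh size.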