context
=================

.. py:module:: jaxns.framework.context

.. rubric:: :code:`jaxns.framework.context`

.. rubric:: Module Contents

.. py:class:: ScopedDict(_dict=None, _scopes=None)

   Dictionary that prefixes all keys with a given scope, as ``{scope}.{key}``.

   .. py:attribute:: scopes
      :type: List[str]
      :value: []

   .. py:attribute:: dict

   .. py:method:: push_scope(scope)

   .. py:method:: pop_scope()

   .. py:property:: scope_prefix

   .. py:method:: to_dict()

   .. py:method:: __repr__()

   .. py:method:: __getitem__(item)

   .. py:method:: __setitem__(key, value)

   .. py:method:: __contains__(item)

   .. py:method:: __iter__()

   .. py:method:: __len__()

   .. py:method:: keys()

   .. py:method:: values()

   .. py:method:: items()

.. py:function:: scope(name)

   Create a new scope that prefixes parameter and state names as
   ``{current_scope}.{name}.{param_name}``.

   :param name: the name of the scope
   :returns: The scope

.. py:function:: get_parameter(name, shape = None, dtype = None, *, init = default_init)

   Get a parameter variable.

   :param name: the name of the parameter
   :param shape: the shape of the parameter; must be provided if ``init`` is not a ``jax.Array``
   :param dtype: the dtype of the parameter; must be provided if ``init`` is not a ``jax.Array``
   :param init: the initializer
   :returns: The parameter variable as a ``jax.Array``

.. py:function:: convert_external_params(external_params, prefix)

   Convert external parameters to context parameters.

   This can be used to convert Haiku or Flax parameters to jaxns parameters for use in models.

   :param external_params: map of name -> value
   :returns: The context parameters

.. py:function:: wrap_random(f)

   Wrap a function to use a random number generator from the context.

   :param f: the function to wrap
   :returns: The wrapped function

.. py:function:: get_state(name, shape = None, dtype = None, *, init = default_init)

   Get a state variable.

   :param name: the name of the state
   :param shape: the shape of the state; must be provided if ``init`` is not a ``jax.Array``
   :param dtype: the dtype of the state; must be provided if ``init`` is not a ``jax.Array``
   :param init: the initializer
   :returns: The state variable as a ``jax.Array``

.. py:function:: set_state(name, value)

   Set a state variable.

   :param name: the name of the state
   :param value: the value to set
   :returns: The state variable as a ``jax.Array``

.. py:function:: transform_with_state(f)

   Transform a function to use parameters and states.

   :param f: the function to transform
   :returns: A tuple of the init and apply functions

.. py:function:: transform(f)

   Transform a function to use parameters and states.

   :param f: the function to transform
   :returns: A tuple of the init and apply functions

.. py:function:: next_rng_key()

   Get the next random number generator key.

   :returns: The next random number generator key
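
The key-naming rule documented for :py:class:`ScopedDict` and :py:func:`scope` can be illustrated with a small, self-contained sketch. This is a hypothetical toy, not the jaxns implementation: ``ToyScopedDict`` and ``toy_scope`` are illustrative names, and it is an assumption that keys set outside any scope are stored without a prefix. Only the ``{scope}.{key}`` and ``{current_scope}.{name}.{param_name}`` naming comes from the documentation above.

.. code-block:: python

   from contextlib import contextmanager
   from typing import Dict, Iterator, List, Optional


   class ToyScopedDict:
       """Hypothetical stand-in for ScopedDict: keys are stored as '{scope}.{key}'."""

       def __init__(self, _dict: Optional[Dict[str, object]] = None,
                    _scopes: Optional[List[str]] = None):
           self.dict: Dict[str, object] = dict(_dict or {})
           self.scopes: List[str] = list(_scopes or [])

       def push_scope(self, scope: str) -> None:
           self.scopes.append(scope)

       def pop_scope(self) -> None:
           self.scopes.pop()

       @property
       def scope_prefix(self) -> str:
           # Nested scopes are joined with dots, e.g. ['model', 'layer1'] -> 'model.layer1'
           return ".".join(self.scopes)

       def _full_key(self, key: str) -> str:
           # Assumption: with no active scope, the key is stored unprefixed
           prefix = self.scope_prefix
           return f"{prefix}.{key}" if prefix else key

       def __setitem__(self, key: str, value: object) -> None:
           self.dict[self._full_key(key)] = value

       def __getitem__(self, key: str) -> object:
           return self.dict[self._full_key(key)]

       def __contains__(self, key: str) -> bool:
           return self._full_key(key) in self.dict

       def to_dict(self) -> Dict[str, object]:
           # Fully qualified keys, independent of the currently active scope
           return dict(self.dict)


   @contextmanager
   def toy_scope(d: ToyScopedDict, name: str) -> Iterator[None]:
       """Hypothetical analogue of scope(name): nest `name` under the current scope."""
       d.push_scope(name)
       try:
           yield
       finally:
           d.pop_scope()


   params = ToyScopedDict()
   params["w"] = 1.0                  # stored as 'w'
   with toy_scope(params, "model"):
       params["w"] = 2.0              # stored as 'model.w'
       with toy_scope(params, "layer1"):
           params["b"] = 3.0          # stored as 'model.layer1.b'
       assert "w" in params           # looked up as 'model.w'

   print(params.to_dict())  # {'w': 1.0, 'model.w': 2.0, 'model.layer1.b': 3.0}

Popping the scope in a ``finally`` clause keeps the scope stack balanced even if the body of the ``with`` block raises, so later keys are not accidentally stored under a stale prefix.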