context

jaxns.framework.context

Module Contents

class ScopedDict(_dict=None, _scopes=None)

Prefixes all keys with the current scope, as {scope}.{key}.

scopes: List[str] = []
dict
push_scope(scope)
pop_scope()
property scope_prefix
to_dict()
__repr__()
__getitem__(item)
__setitem__(key, value)
__contains__(item)
__iter__()
__len__()
keys()
values()
items()
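
The key-prefixing behaviour described for ScopedDict can be pictured with a small stand-in class. This is only a sketch under stated assumptions, not the jaxns implementation: ToyScopedDict is a hypothetical name, scopes are assumed to be joined with dots, and keys set with no active scope are assumed to be stored without a prefix.

```python
from typing import Any, Dict, List


class ToyScopedDict:
    """Hypothetical stand-in: stores values under '{scope}.{key}' in a flat dict."""

    def __init__(self) -> None:
        self._dict: Dict[str, Any] = {}
        self._scopes: List[str] = []

    def push_scope(self, scope: str) -> None:
        # Enter a nested scope.
        self._scopes.append(scope)

    def pop_scope(self) -> None:
        # Leave the innermost scope.
        self._scopes.pop()

    @property
    def scope_prefix(self) -> str:
        # Active scopes joined with dots; empty string when no scope is active.
        return ".".join(self._scopes)

    def _full_key(self, key: str) -> str:
        # Assumption: no prefix is added when no scope is active.
        prefix = self.scope_prefix
        return f"{prefix}.{key}" if prefix else key

    def __setitem__(self, key: str, value: Any) -> None:
        self._dict[self._full_key(key)] = value

    def __getitem__(self, key: str) -> Any:
        return self._dict[self._full_key(key)]

    def __contains__(self, key: str) -> bool:
        return self._full_key(key) in self._dict

    def to_dict(self) -> Dict[str, Any]:
        # Return a copy of the flat, fully prefixed mapping.
        return dict(self._dict)


d = ToyScopedDict()
d["lr"] = 0.1
d.push_scope("layer1")
d["w"] = 1.0
d.push_scope("dense")
d["b"] = 0.0
d.pop_scope()
d.pop_scope()
flat = d.to_dict()
```

In this sketch, lookups are relative to the active scope, so "w" is only visible through d["w"] while "layer1" is pushed, whereas to_dict() always exposes the fully prefixed keys.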
scope(name)

Create a new scope, to prefix parameters and states, as {current_scope}.{name}.{param_name}.

Parameters:

name (str) – the name of the scope

Returns:

The scope

get_parameter(name, shape=None, dtype=None, *, init=default_init)

Get a parameter variable.

Parameters:
  • name (str) – the name of the parameter

  • shape (Optional[Tuple[int, Ellipsis]]) – the shape of the parameter; must be provided if init is not a jax.Array

  • dtype (Optional[jax._src.typing.SupportsDType]) – the dtype of the parameter; must be provided if init is not a jax.Array

  • init (InitType) – the initializer

Returns:

The parameter variable as a jax.Array

Return type:

PT
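
The rule that shape and dtype are required unless init is already a jax.Array can be sketched as follows. This is an illustration only: resolve_init is a hypothetical helper, and the assumption that a callable initializer takes (shape, dtype) is inferred from the parameter descriptions above rather than from the jaxns source.

```python
from typing import Any, Callable, Optional, Tuple, Union

import jax
import jax.numpy as jnp

Shape = Tuple[int, ...]
# Assumption: an initializer is either a concrete array or a callable of (shape, dtype).
Init = Union[jax.Array, Callable[[Shape, Any], jax.Array]]


def resolve_init(init: Init, shape: Optional[Shape] = None, dtype: Any = None) -> jax.Array:
    """Hypothetical helper: turn an initializer into a concrete array."""
    if isinstance(init, jax.Array):
        # A concrete array already carries its own shape and dtype.
        return init
    if shape is None or dtype is None:
        raise ValueError("shape and dtype are required when init is not a jax.Array")
    # Otherwise call the initializer with the requested shape and dtype.
    return init(shape, dtype)


from_array = resolve_init(jnp.ones((2, 3)))
from_callable = resolve_init(jnp.zeros, shape=(4,), dtype=jnp.float32)
```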

convert_external_params(external_params, prefix)

Convert external parameters to context parameters. This can be used to convert haiku or flax parameters to jaxns parameters for use in models.

Parameters:
  • external_params (ExtParam) – map of name -> value

  • prefix (str)

Returns:

The context parameters

Return type:

ExtParam

wrap_random(f)

Wrap a function to use a random number generator from the context.

Parameters:

f – the function to wrap

Returns:

The wrapped function

get_state(name, shape=None, dtype=None, *, init=default_init)

Get a state variable.

Parameters:
  • name (str) – the name of the state

  • shape (Optional[Tuple[int, Ellipsis]]) – the shape of the state; must be provided if init is not a jax.Array

  • dtype (Optional[jax._src.typing.SupportsDType]) – the dtype of the state; must be provided if init is not a jax.Array

  • init (InitType) – the initializer

Returns:

The state variable as a jax.Array

Return type:

PT

set_state(name, value)

Set a state variable.

Parameters:
  • name (str) – the name of the state

  • value (PT) – the value to set

Returns:

The state variable as a jax.Array

transform_with_state(f)

Transform a function to use parameters and states.

Parameters:

f (Callable) – the function to transform

Returns:

A tuple of the init and apply functions

Return type:

TransformedWithStateFn

transform(f)

Transform a function to use parameters and states.

Parameters:

f (Callable) – the function to transform

Returns:

A tuple of the init and apply functions

Return type:

TransformedFn

next_rng_key()

Get the next random number generator key.

Returns:

The next random number generator key
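
A common way to hand out fresh keys in JAX is to keep one key and split it on every request. The sketch below shows that pattern with a hypothetical ToyKeySource class; it is an assumption about the general technique, not a description of how jaxns stores or advances its key.

```python
import jax
import jax.numpy as jnp


class ToyKeySource:
    """Hypothetical key source: splits its stored key on every request."""

    def __init__(self, seed: int) -> None:
        self._key = jax.random.PRNGKey(seed)

    def next_rng_key(self) -> jax.Array:
        # Keep one half for future requests and hand out the other half.
        self._key, subkey = jax.random.split(self._key)
        return subkey


source = ToyKeySource(seed=0)
k1 = source.next_rng_key()
k2 = source.next_rng_key()

# Each key is used once, so the two draws come from independent streams.
x1 = jax.random.normal(k1, (3,))
x2 = jax.random.normal(k2, (3,))
```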