context
jaxns.framework.context
Module Contents
- class ScopedDict(_dict=None, _scopes=None)
A dictionary that prefixes every key with the current scope, as {scope}.{key}.
- scope(name)
Create a new scope, to prefix parameters and states, as {current_scope}.{name}.{param_name}.
- Parameters:
name (str) – the name of the scope
- Returns:
The scope
- get_parameter(name, shape=None, dtype=None, *, init=default_init)
Get a parameter variable.
- Parameters:
name (str) – the name of the parameter
shape (Optional[Tuple[int, Ellipsis]]) – the shape of the parameter; must be provided if init is not a jax.Array
dtype (Optional[jax._src.typing.SupportsDType]) – the dtype of the parameter; must be provided if init is not a jax.Array
init (InitType) – the initializer
- Returns:
The parameter variable as a jax.Array
- Return type:
PT
- convert_external_params(external_params, prefix)
Convert external parameters to context parameters. This can be used to convert haiku or flax parameters to jaxns parameters for use in models.
- Parameters:
external_params (ExtParam) – map of name -> value
prefix (str)
- Returns:
The context parameters
- Return type:
ExtParam
- wrap_random(f)
Wrap a function to use a random number generator from the context.
- Parameters:
f – the function to wrap
- Returns:
The wrapped function
- get_state(name, shape=None, dtype=None, *, init=default_init)
Get a state variable.
- Parameters:
name (str) – the name of the state
shape (Optional[Tuple[int, Ellipsis]]) – the shape of the state; must be provided if init is not a jax.Array
dtype (Optional[jax._src.typing.SupportsDType]) – the dtype of the state; must be provided if init is not a jax.Array
init (InitType) – the initializer
- Returns:
The state variable as a jax.Array
- Return type:
PT
- set_state(name, value)
Set a state variable.
- Parameters:
name (str) – the name of the state
value (PT) – the value to set
- Returns:
The state variable as a jax.Array
- transform_with_state(f)
Transform a function to use parameters and states.
- Parameters:
f (Callable) – the function to transform
- Returns:
A tuple of the init and apply functions
- Return type:
TransformedWithStateFn