model
jaxns.framework.model
Module Contents
- class Model(prior_model, log_likelihood, params=None)
Bases: jaxns.framework.bases.BaseAbstractModel
Represents a Bayesian model in terms of a generative prior and a likelihood function.
- Parameters:
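prior_model (jaxns.framework.bases.PriorModelType) – A generator function that yields priors and returns the values passed to log_likelihood.
log_likelihood (jaxns.internals.types.LikelihoodType) – A function that maps the values returned by prior_model to a scalar log-likelihood.
params (Optional[jaxns.framework.context.MutableParams]) – Optional parameters for the model.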
- set_params(params)
Create a new parametrised model with the given parameters.
- Parameters:
params (jaxns.framework.context.MutableParams) – The parameters to use.
- Returns:
A new model with the given parameters set.
- Return type:
Model
- __call__(params)
Create a new parametrised model with the given parameters.
This is (and must be) a pure function.
- Parameters:
params (jaxns.framework.context.MutableParams) – The parameters to use.
- Returns:
A new model with the given parameters set.
- Return type:
Model
- init_params(rng)
Initialise the parameters of the model.
- Parameters:
rng (jaxns.internals.types.PRNGKey) – PRNG key used to initialise the parameters.
- Returns:
The initialised parameters.
- Return type:
jaxns.framework.context.MutableParams
- sample_U(key)
Sample a point U from the prior model's base space (the unit hypercube).
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey to use.
- Returns:
The sampled U.
- Return type:
jaxns.internals.types.UType
- sample_W(key)
Sample from the prior model.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey to use.
- Returns:
The sampled W.
- Return type:
jaxns.internals.types.WType
- transform(U)
Transform a point U into the prior variables X.
- Parameters:
U (jaxns.internals.types.UType)
- Return type:
jaxns.internals.types.XType
- transform_parametrised(U)
- Parameters:
U (jaxns.internals.types.UType)
- Return type:
jaxns.internals.types.XType
- forward(U, allow_nan=False)
Compute the model's log-likelihood at U.
- Parameters:
U (jaxns.internals.types.UType)
allow_nan (bool)
- Return type:
jaxns.internals.types.FloatArray
- log_prob_prior(U)
Compute the log-probability of the prior at U.
- Parameters:
U (jaxns.internals.types.UType)
- Return type:
jaxns.internals.types.FloatArray