jaxns.framework.model
Module Contents
- class Model(prior_model, log_likelihood, params=None)
Bases: jaxns.framework.bases.BaseAbstractModel
Represents a Bayesian model in terms of a generative prior and a likelihood function.
- Parameters:
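prior_model (jaxns.framework.bases.PriorModelType) – The generative prior model.
log_likelihood (jaxns.internals.types.LikelihoodType) – The log-likelihood function.
params (Optional[haiku.MutableParams]) – Optional parameters for a parametrised model.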
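To illustrate what a "generative prior" can look like, here is a library-independent sketch (not the JAXNS implementation) in which the prior model is a Python generator: each `yield` declares one prior variable and receives its value back, and the `return` value is what gets passed to the log-likelihood. `UniformPrior` and the driver `run_prior_model` are hypothetical names introduced only for this sketch; the driver feeds one unit-cube coordinate per yielded prior through that prior's inverse CDF.

```python
import math
from dataclasses import dataclass
from typing import Callable, Generator, List, Tuple


@dataclass(frozen=True)
class UniformPrior:
    """Hypothetical prior: Uniform(low, high), sampled by inverse CDF."""
    low: float
    high: float

    def quantile(self, u: float) -> float:
        # Inverse CDF maps u in [0, 1] onto [low, high].
        return self.low + (self.high - self.low) * u


def prior_model() -> Generator[UniformPrior, float, Tuple[float, float]]:
    # Each `yield` declares one prior variable and receives its value back.
    x = yield UniformPrior(0.0, 1.0)
    y = yield UniformPrior(-2.0, 2.0)
    return x, y


def log_likelihood(x: float, y: float) -> float:
    # Toy 2D Gaussian log-likelihood centred at (0.5, 0.0) with unit variance.
    return -0.5 * ((x - 0.5) ** 2 + y ** 2) - math.log(2.0 * math.pi)


def run_prior_model(model: Callable[[], Generator], U: List[float]):
    """Hypothetical driver: consume one unit-cube coordinate per yielded prior."""
    gen = model()
    prior = next(gen)  # first declared prior
    i = 0
    try:
        while True:
            value = prior.quantile(U[i])
            i += 1
            prior = gen.send(value)  # hand the value back, get the next prior
    except StopIteration as stop:
        return stop.value  # whatever the prior model returned


X = run_prior_model(prior_model, [0.5, 0.75])
print(X)  # (0.5, 1.0)
print(log_likelihood(*X))
```

The generator form keeps the prior declaration and the variable dependencies in ordinary Python control flow, while the driver decides how unit-cube values are turned into variable values.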
- set_params(params)
Create a new parametrised model with the given parameters.
- Parameters:
params (haiku.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
Model
- __call__(params)
Create a new parametrised model with the given parameters.
This is (and must be) a pure function.
- Parameters:
params (haiku.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
Model
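Both set_params and __call__ create a new parametrised model instead of modifying the existing one, which is what allows __call__ to be a pure function. A minimal sketch of that pattern, assuming a hypothetical frozen dataclass `ToyModel` in place of Model and a plain nested dict in place of haiku.MutableParams:

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Dict, Optional

Params = Dict[str, Dict[str, Any]]  # stand-in for haiku.MutableParams (nested dict)


@dataclass(frozen=True)
class ToyModel:
    """Hypothetical stand-in for Model: holds callables plus optional params."""
    prior_model: Callable
    log_likelihood: Callable
    params: Optional[Params] = None

    def set_params(self, params: Params) -> "ToyModel":
        # Return a copy carrying the new params; `self` is never modified.
        return replace(self, params=params)

    def __call__(self, params: Params) -> "ToyModel":
        # Pure: the result depends only on `self` and `params`.
        return self.set_params(params)


base = ToyModel(prior_model=lambda: None, log_likelihood=lambda x: 0.0)
fitted = base({"prior": {"scale": 2.0}})
print(base.params)    # None
print(fitted.params)  # {'prior': {'scale': 2.0}}
```

Because the dataclass is frozen, any attempt to assign `base.params` directly raises an error, so the only way to change parameters is to obtain a new model.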
- init_params(rng)
Initialise the parameters of the model.
- Parameters:
rng (jaxns.internals.types.PRNGKey) – PRNG key used to initialise the parameters.
- Returns:
The initialised parameters.
- Return type:
haiku.MutableParams
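The key contract of init_params is that initialisation is driven entirely by the supplied PRNG key, so the same key gives the same parameters. The sketch below imitates that contract with a seeded `random.Random` standing in for a JAX PRNGKey; the function `init_params_sketch` and the parameter names `loc` and `log_scale` are hypothetical, and the nested-dict layout only mimics the shape of haiku.MutableParams.

```python
import random
from typing import Dict

Params = Dict[str, Dict[str, float]]  # stand-in for haiku.MutableParams


def init_params_sketch(seed: int) -> Params:
    """Hypothetical initialiser: fully determined by the seed (the 'key')."""
    rng = random.Random(seed)  # local generator, no global random state touched
    return {
        "prior": {
            "loc": rng.gauss(0.0, 1.0),           # e.g. a learnable prior location
            "log_scale": rng.uniform(-1.0, 1.0),  # e.g. a learnable log scale
        }
    }


p1 = init_params_sketch(0)
p2 = init_params_sketch(0)
p3 = init_params_sketch(1)
print(p1 == p2)  # True: same key, same parameters
```

In real JAX code the key is an explicit value rather than a seeded object, but the reproducibility property being illustrated is the same.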
- sample_U(key)
- Parameters:
key (jaxns.internals.types.PRNGKey) –
- Return type:
jaxns.internals.types.FloatArray
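sample_U takes a PRNG key and returns a float array; in nested sampling such a point U presumably lives in the unit hypercube and is drawn uniformly. Under that assumption, a stdlib sketch with a hypothetical `n_dims` (one coordinate per prior dimension) and a seed standing in for the key:

```python
import random
from typing import List


def sample_U_sketch(seed: int, n_dims: int) -> List[float]:
    """Hypothetical: draw U uniformly from the unit hypercube [0, 1)^n_dims."""
    rng = random.Random(seed)  # seed plays the role of the PRNG key
    return [rng.random() for _ in range(n_dims)]


U = sample_U_sketch(seed=42, n_dims=3)
print(len(U))  # 3
```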
- transform(U)
- Parameters:
U (jaxns.internals.types.UType) –
- Return type:
jaxns.internals.types.XType
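transform maps a point U (UType) to the model's variables (XType). A common way to build such a map, assumed here rather than taken from this reference, is to push each coordinate of U through the inverse CDF of the corresponding prior. The sketch below does that for two hypothetical priors, x ~ Uniform(-2, 2) and y ~ Normal(1, 0.5), using `statistics.NormalDist.inv_cdf`:

```python
from statistics import NormalDist
from typing import List


def transform_sketch(U: List[float]) -> List[float]:
    """Hypothetical prior transform for x ~ Uniform(-2, 2), y ~ Normal(1, 0.5)."""
    u_x, u_y = U
    x = -2.0 + 4.0 * u_x                            # inverse CDF of Uniform(-2, 2)
    y = NormalDist(mu=1.0, sigma=0.5).inv_cdf(u_y)  # inverse CDF of Normal(1, 0.5)
    return [x, y]


print(transform_sketch([0.5, 0.5]))  # [0.0, 1.0]
```

Note that `inv_cdf` requires its argument to lie strictly between 0 and 1, so a Normal coordinate of U must avoid the cube's boundary.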
- transform_parametrised(U)
- Parameters:
U (jaxns.internals.types.UType) –
- Return type:
jaxns.internals.types.XType
- forward(U, allow_nan=False)
- Parameters:
U (jaxns.internals.types.UType) –
allow_nan (bool) –
- Return type:
jaxns.internals.types.FloatArray
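forward evaluates the model at a point U and takes an allow_nan flag. One plausible reading, stated here as an assumption rather than documented behaviour, is: map U to the model variables, evaluate the log-likelihood there, and, unless allow_nan is True, replace a NaN log-likelihood with -inf so the point is treated as impossible. A stdlib sketch of that reading, with hypothetical helpers `to_x` and `toy_log_likelihood`:

```python
import math
from typing import Callable, List


def forward_sketch(
    U: List[float],
    transform: Callable[[List[float]], List[float]],
    log_likelihood: Callable[..., float],
    allow_nan: bool = False,
) -> float:
    """Hypothetical forward pass: U -> X -> log L(X), with an optional NaN guard."""
    X = transform(U)
    log_L = log_likelihood(*X)
    if not allow_nan and math.isnan(log_L):
        # Assumed behaviour: an undefined likelihood counts as zero likelihood.
        return -math.inf
    return log_L


def to_x(U: List[float]) -> List[float]:
    return [2.0 * U[0] - 1.0]  # Uniform(-1, 1) prior on a single variable


def toy_log_likelihood(x: float) -> float:
    # Returns NaN outside its domain instead of raising.
    return math.log(x) if x > 0 else math.nan


print(forward_sketch([0.75], to_x, toy_log_likelihood))                  # log(0.5), about -0.693
print(forward_sketch([0.25], to_x, toy_log_likelihood))                  # -inf
print(forward_sketch([0.25], to_x, toy_log_likelihood, allow_nan=True))  # nan
```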
- log_prob_prior(U)
- Parameters:
U (jaxns.internals.types.UType) –
- Return type:
jaxns.internals.types.FloatArray
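log_prob_prior returns a prior log-probability for a point U. Assuming it is the log-density of the prior evaluated at the transformed variable (an assumption; the reference above does not say), a sketch for a single hypothetical Normal(mu, sigma) prior is:

```python
import math
from statistics import NormalDist


def log_prob_prior_sketch(U: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Hypothetical: log-density of Normal(mu, sigma) at x = inverse-CDF(U)."""
    x = NormalDist(mu, sigma).inv_cdf(U)  # map U to the prior variable
    z = (x - mu) / sigma
    # log N(x | mu, sigma) = -z^2 / 2 - log(sigma) - log(2 pi) / 2
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


print(log_prob_prior_sketch(0.5, mu=3.0, sigma=2.0))
```

At U = 0.5 the variable sits at the prior mode, so the result is the largest value this prior density can take.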