model

jaxns.framework.model

Module Contents

class Model(prior_model, log_likelihood, params=None)

Bases: jaxns.framework.bases.BaseAbstractModel

Represents a Bayesian model in terms of a generative prior and a likelihood function.

Parameters:
  • prior_model (jaxns.framework.bases.PriorModelType) – The generative prior model.

  • log_likelihood (jaxns.internals.types.LikelihoodType) – The log-likelihood function.

  • params (Optional[haiku.MutableParams]) – Optional parameters of the model.
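To make "generative prior" concrete, here is a minimal sketch in plain Python, not jaxns code. The prior model is written as a generator that yields one prior per variable and receives that variable's value back, and the log-likelihood takes the prior model's return values as arguments. The names UniformPrior and run_prior_model, the generator protocol, and the toy data are assumptions for illustration; jaxns's own prior classes and internals may differ.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable, Generator, Sequence


@dataclass(frozen=True)
class UniformPrior:
    """Hypothetical prior object: Uniform(low, high)."""
    low: float
    high: float

    def quantile(self, u: float) -> float:
        # Inverse CDF: maps u in [0, 1] onto [low, high].
        return self.low + (self.high - self.low) * u


def run_prior_model(prior_model: Callable[[], Generator], U: Sequence[float]) -> Any:
    """Hypothetical driver: spends one U coordinate per prior the model yields."""
    gen = prior_model()
    i = 0
    try:
        prior = next(gen)               # first declared prior
        while True:
            x = prior.quantile(U[i])    # U -> X for this variable (IndexError if U is too short)
            i += 1
            prior = gen.send(x)         # hand the value back, receive the next prior
    except StopIteration as stop:
        return stop.value               # whatever the prior model returns


def prior_model():
    # Each `yield` declares a prior and receives the value drawn for it.
    mu = yield UniformPrior(-5.0, 5.0)
    sigma = yield UniformPrior(0.1, 2.0)
    return mu, sigma


def log_likelihood(mu: float, sigma: float) -> float:
    # Gaussian log-likelihood of a small toy data set.
    data = [0.3, -0.1, 0.4]
    return sum(-0.5 * ((d - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))
               for d in data)


mu, sigma = run_prior_model(prior_model, [0.5, 0.5])
log_L = log_likelihood(mu, sigma)
```

In this sketch, evaluating the likelihood at a point U amounts to log_likelihood(*run_prior_model(prior_model, U)); the prior and the likelihood stay separate callables, as in the Model constructor.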

property num_params: int
Return type:

int

property params
set_params(params)

Create a new parametrised model with the given parameters.

Parameters:

params (haiku.MutableParams) – The parameters to use.

Returns:

A new model with the given parameters set.

Return type:

Model

__call__(params)

Create a new parametrised model with the given parameters.

This is (and must be) a pure function.

Parameters:

params (haiku.MutableParams) – The parameters to use.

Returns:

A new model with the given parameters set.

Return type:

Model
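set_params and __call__ share one contract: they return a new model and leave the original untouched. A plausible reason for the purity requirement is that JAX transformations such as jax.jit trace a function and assume it has no side effects (Python side effects run only during tracing). The sketch below illustrates the contract with a frozen dataclass; ParametrisedModel and the nested-dict Params alias are hypothetical stand-ins, not jaxns or haiku types.

```python
from dataclasses import dataclass, replace
from typing import Dict

# Hypothetical stand-in for haiku.MutableParams: {module_name: {param_name: value}}.
Params = Dict[str, Dict[str, float]]


@dataclass(frozen=True)
class ParametrisedModel:
    """Hypothetical model that never mutates its parameters in place."""
    params: Params

    def set_params(self, params: Params) -> "ParametrisedModel":
        # Build a new instance; `self` keeps its old parameters.
        return replace(self, params=params)

    def __call__(self, params: Params) -> "ParametrisedModel":
        # Same contract as set_params: the result depends only on the arguments.
        return self.set_params(params)


base = ParametrisedModel(params={"prior": {"scale": 1.0}})
updated = base({"prior": {"scale": 2.0}})
```

Because base is never modified, both models can be used side by side, for example when comparing the same model under two parameter settings.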

init_params(rng)

Initialise the parameters of the model.

Parameters:

rng (jaxns.internals.types.PRNGKey) – PRNG key used to initialise the parameters.

Returns:

The initialised parameters.

Return type:

haiku.MutableParams
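Since JAX random number generation is a deterministic function of the key, initialising with the same rng should reproduce the same parameters. The sketch below mirrors that property using Python's random.Random with an integer seed in place of a JAX PRNGKey. The function, the nested-dict layout, and the parameter names loc and log_scale are hypothetical.

```python
import random
from typing import Dict

# Hypothetical stand-in for haiku.MutableParams.
Params = Dict[str, Dict[str, float]]


def init_params(seed: int) -> Params:
    """Hypothetical initialiser: every random draw comes from `seed`."""
    rng = random.Random(seed)  # private generator; no global random state involved
    return {
        "prior": {
            "loc": rng.gauss(0.0, 1.0),
            "log_scale": rng.uniform(-1.0, 1.0),
        }
    }
```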

__hash__()
__repr__()
sample_U(key)

Draw a random point in U space.

Parameters:

key (jaxns.internals.types.PRNGKey) – PRNG key used for sampling.

Return type:

jaxns.internals.types.FloatArray
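A minimal sketch of drawing a point in U space, assuming U space is the unit hypercube with one coordinate per prior dimension, the usual convention in nested sampling. In JAX the analogous draw would be jax.random.uniform(key, shape=(U_ndims,)). Here a stdlib random.Random stands in for the PRNG key, and the name U_ndims is an assumption.

```python
import random
from typing import List


def sample_U(rng: random.Random, U_ndims: int) -> List[float]:
    """Hypothetical: draw one point uniformly from [0, 1)^U_ndims."""
    # Each coordinate is an independent Uniform[0, 1) draw.
    return [rng.random() for _ in range(U_ndims)]


U = sample_U(random.Random(42), U_ndims=3)
```

As with a JAX key, reusing the same seed reproduces the same point.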

transform(U)

Transform a point in U space to the corresponding point in X space.

Parameters:

U (jaxns.internals.types.UType) – A point in U space.

Return type:

jaxns.internals.types.XType
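A standard way to realise a U to X map is inverse-transform sampling: pass each uniform coordinate through the inverse CDF (quantile function) of its prior, so that uniform U yields X distributed according to the prior. The sketch below assumes three independent priors and represents X as a dict of named values. Both the priors and that representation are assumptions for illustration, not a description of jaxns internals.

```python
import math
from statistics import NormalDist
from typing import Dict, List


def transform(U: List[float]) -> Dict[str, float]:
    """Hypothetical U -> X map for a prior with three independent variables."""
    u_loc, u_scale, u_rate = U
    return {
        # loc ~ Normal(0, 1): apply the normal inverse CDF (needs 0 < u < 1).
        "loc": NormalDist(0.0, 1.0).inv_cdf(u_loc),
        # scale ~ Uniform(0.1, 2.0): an affine map of u.
        "scale": 0.1 + (2.0 - 0.1) * u_scale,
        # rate ~ Exponential(1): inverse CDF is -log(1 - u).
        "rate": -math.log1p(-u_rate),
    }


X = transform([0.5, 0.0, 1.0 - math.exp(-1.0)])
```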

transform_parametrised(U)
Parameters:

U (jaxns.internals.types.UType) – A point in U space.

Return type:

jaxns.internals.types.XType

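forward(U, allow_nan=False)

Evaluate the log-likelihood at a point in U space.

Parameters:
  • U (jaxns.internals.types.UType) – A point in U space.

  • allow_nan (bool) – Whether NaN values are allowed in the returned log-likelihood.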

Return type:

jaxns.internals.types.FloatArray
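A minimal sketch of a forward pass from U to X to log-likelihood. This page does not say what happens to NaN values when allow_nan=False, so the sketch assumes one common policy: replace NaN with -inf so that the point is treated as having zero likelihood. Check the jaxns source before relying on that behaviour. The helper names to_x and log_likelihood are hypothetical.

```python
import math
from typing import Callable, List, Sequence


def forward(U: List[float],
            transform: Callable[[List[float]], Sequence[float]],
            log_likelihood: Callable[..., float],
            allow_nan: bool = False) -> float:
    """Hypothetical forward pass: U -> X -> log-likelihood."""
    log_L = log_likelihood(*transform(U))
    if math.isnan(log_L) and not allow_nan:
        # Assumed policy: a NaN log-likelihood is treated as an impossible point.
        return -math.inf
    return log_L


def to_x(U: List[float]) -> List[float]:
    # Uniform(-1, 1) prior on a single variable.
    return [2.0 * U[0] - 1.0]


def log_likelihood(x: float) -> float:
    # Toy likelihood that is undefined (NaN) for negative x.
    return -0.5 * x * x if x >= 0 else math.nan
```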

log_prob_prior(U)

Compute the log-probability of the prior at a point in U space.

Parameters:

U (jaxns.internals.types.UType) – A point in U space.

Return type:

jaxns.internals.types.FloatArray
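A sketch assuming log_prob_prior maps U to X and then sums the prior log-densities of the resulting variables. That reading is an assumption based on the method name. For independent priors the joint log-density is the sum of the per-variable log-densities. The two priors below are hypothetical.

```python
import math
from statistics import NormalDist
from typing import List


def log_prob_prior(U: List[float]) -> float:
    """Hypothetical prior: loc ~ Normal(0, 1), scale ~ Uniform(0.1, 2.0), independent."""
    u_loc, u_scale = U
    normal = NormalDist(0.0, 1.0)
    loc = normal.inv_cdf(u_loc)             # U -> X for the normal prior
    scale = 0.1 + (2.0 - 0.1) * u_scale     # U -> X for the uniform prior
    log_p_loc = math.log(normal.pdf(loc))   # normal log-density at loc
    # Uniform density is 1 / (high - low) everywhere on [low, high].
    log_p_scale = -math.log(2.0 - 0.1) if 0.1 <= scale <= 2.0 else -math.inf
    # Independent priors: the joint log-density is the sum of the parts.
    return log_p_loc + log_p_scale
```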

prepare_input(U)

Prepare the input to the log-likelihood function from a point in U space.

Parameters:

U (jaxns.internals.types.UType) – A point in U space.

Return type:

jaxns.internals.types.LikelihoodInputType

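sanity_check(key, S)

Run sanity checks on the model using S random samples.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key used to draw the samples.

  • S (int) – Number of samples to check.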
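This page does not document which checks sanity_check performs, so the following is only an illustrative sketch under stated assumptions. It draws S random points in the unit hypercube and rejects any NaN or +inf log-likelihood, both of which usually indicate a bug in the prior or likelihood. The function signature, the integer seed standing in for a PRNG key, and the U_ndims and log_L_fn parameters are hypothetical.

```python
import math
import random
from typing import Callable, List


def sanity_check(seed: int, S: int, U_ndims: int,
                 log_L_fn: Callable[[List[float]], float]) -> None:
    """Hypothetical check: evaluate the log-likelihood at S random points in U space."""
    if S < 1:
        raise ValueError("S must be a positive number of samples")
    rng = random.Random(seed)
    for _ in range(S):
        U = [rng.random() for _ in range(U_ndims)]
        log_L = log_L_fn(U)
        # NaN or +inf usually points to a bug in the prior or likelihood.
        if math.isnan(log_L) or log_L == math.inf:
            raise ValueError(f"invalid log-likelihood {log_L} at U={U}")


# A well-behaved toy log-likelihood passes silently.
sanity_check(seed=0, S=10, U_ndims=2, log_L_fn=lambda U: -sum(u * u for u in U))
```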