abc
jaxns.framework.abc
Module Contents
- class AbstractModel
Bases: abc.ABC
Represents a Bayesian model in terms of a generative prior and a likelihood function.
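`AbstractModel` derives from `abc.ABC`, so a subclass cannot be instantiated until it overrides every abstract method. The sketch below illustrates that mechanism using a hypothetical, trimmed-down base class (`ToyModelBase`, mirroring only three of the methods) and a toy subclass. It does not use jaxns. A `random.Random` instance stands in for the JAX PRNGKey.

```python
import abc
import math
import random


class ToyModelBase(abc.ABC):
    """Hypothetical, trimmed-down mirror of part of the AbstractModel interface."""

    @abc.abstractmethod
    def sample_U(self, key):
        """Draw a point uniformly from U-space (the unit hypercube)."""

    @abc.abstractmethod
    def transform(self, U):
        """Map a U-space point to a prior sample."""

    @abc.abstractmethod
    def forward(self, U, allow_nan=False):
        """Return the log-likelihood at U."""


class UniformMeanModel(ToyModelBase):
    """Toy model: mu ~ Uniform(-10, 10), each datum ~ Normal(mu, 1)."""

    def __init__(self, data):
        self.data = list(data)

    def sample_U(self, key):
        # `key` is a random.Random instance standing in for a JAX PRNGKey.
        return [key.random()]

    def transform(self, U):
        # Quantile (inverse CDF) of Uniform(-10, 10) applied to the one coordinate.
        return [-10.0 + 20.0 * U[0]]

    def forward(self, U, allow_nan=False):
        # Unit-variance Gaussian log-likelihood; allow_nan is unused in this toy.
        (mu,) = self.transform(U)
        return sum(-0.5 * (y - mu) ** 2 - 0.5 * math.log(2 * math.pi) for y in self.data)


model = UniformMeanModel([0.1, -0.2])
U = model.sample_U(random.Random(0))
print(model.transform(U), model.forward(U))
```

Calling `ToyModelBase()` directly raises `TypeError`, because its abstract methods are not implemented.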
- abstract sample_U(key)
Draw a sample from the prior in U-space, where the prior is uniform.
- Parameters:
key (jaxns.internals.types.PRNGKey) – JAX PRNG key
- Returns:
U-space sample
- Return type:
jaxns.internals.types.UType
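In nested sampling, U-space is the unit hypercube. Drawing uniformly there therefore amounts to drawing each coordinate independently from Uniform[0, 1). The minimal sketch below uses a hypothetical helper, `sample_unit_cube`, that takes the dimension explicitly. The real method takes only `key`, so the model presumably knows its own dimension. A `random.Random` replaces the JAX PRNGKey.

```python
import random


def sample_unit_cube(key: random.Random, num_dims: int) -> list[float]:
    # Each coordinate is an independent draw from Uniform[0, 1).
    return [key.random() for _ in range(num_dims)]


U = sample_unit_cube(random.Random(42), num_dims=3)
```

With JAX, the equivalent draw is `jax.random.uniform(key, shape=(num_dims,))`, which also returns values in [0, 1) by default.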
- abstract transform(U)
Transform a U-space sample into a prior sample.
- Parameters:
U (jaxns.internals.types.UType) – U-space sample
- Returns:
prior sample
- Return type:
jaxns.internals.types.XType
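A standard way to turn a U-space sample into a prior sample is the inverse-CDF (quantile) transform: if u ~ Uniform(0, 1), then F^-1(u) is distributed according to F. The sketch below is illustrative and is not jaxns's implementation. It maps a two-dimensional U to a hypothetical prior with one Normal(mu, sigma) coordinate and one Uniform(low, high) coordinate.

```python
from statistics import NormalDist


def transform(U, mu=0.0, sigma=1.0, low=-10.0, high=10.0):
    # U[0] must lie strictly in (0, 1): the normal quantile is infinite at 0 and 1.
    x_normal = NormalDist(mu, sigma).inv_cdf(U[0])
    # U[1] in [0, 1) -> Uniform(low, high) via its linear quantile function.
    x_uniform = low + (high - low) * U[1]
    return x_normal, x_uniform


x_normal, x_uniform = transform([0.5, 0.25])
```

The median of U maps to the median of each prior: 0.5 goes to mu, and the 0.25 quantile of Uniform(-10, 10) is -5.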
- abstract transform_parametrised(U)
Compute the parametrised prior variables from a U-space sample.
- Parameters:
U (jaxns.internals.types.UType) – U-space sample
- Returns:
parametrised prior variables
- Return type:
jaxns.internals.types.XType
- abstract forward(U, allow_nan=False)
Compute the log-likelihood.
- Parameters:
U (jaxns.internals.types.UType) – U-space sample
allow_nan (bool) – whether to allow NaN values in the likelihood
- Returns:
log-likelihood at the sample
- Return type:
jaxns.internals.types.MeasureType
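Here is a hypothetical log-likelihood for a toy model with mu = -10 + 20 * U[0] and data points drawn from Normal(mu, 1). These docs do not say how `allow_nan=False` is enforced. The sketch assumes one plausible policy: a NaN log-likelihood is replaced with -inf, so the point is treated as having zero likelihood.

```python
import math


def forward(U, data, allow_nan=False):
    # Hypothetical toy model: mu = -10 + 20 * U[0]; each datum ~ Normal(mu, 1).
    mu = -10.0 + 20.0 * U[0]
    log_L = sum(-0.5 * (y - mu) ** 2 - 0.5 * math.log(2.0 * math.pi) for y in data)
    if math.isnan(log_L) and not allow_nan:
        # Assumed policy (not confirmed by these docs): NaN -> -inf.
        log_L = -math.inf
    return log_L
```

Mapping NaN to -inf instead of raising lets a sampler reject the offending point and continue.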
- log_prob_likelihood(U, allow_nan=False)
Compute the log-likelihood. Unlike forward, this method is concrete rather than abstract.
- Parameters:
U (jaxns.internals.types.UType) – U-space sample
allow_nan (bool) – whether to allow NaN values in the likelihood
- Returns:
log-likelihood at the sample
- Return type:
jaxns.internals.types.MeasureType
- abstract log_prob_prior(U)
Compute the log-probability of the prior.
- Parameters:
U (jaxns.internals.types.UType) – The U-space sample
- Returns:
the prior log-probability
- Return type:
jaxns.internals.types.MeasureType
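These docs do not say whether the prior log-probability is taken with respect to U or to the transformed variables. The minimal sketch below assumes the latter: log p(X) evaluated at X = transform(U). It uses a hypothetical one-dimensional Normal(mu, sigma) prior defined through the inverse-CDF transform.

```python
import math
from statistics import NormalDist


def log_prob_prior(U, mu=0.0, sigma=1.0):
    prior = NormalDist(mu, sigma)
    # Map U to X with the quantile function, then evaluate the prior log-density at X.
    x = prior.inv_cdf(U[0])
    return math.log(prior.pdf(x))
```

Far in the tails `pdf` can underflow to 0. Evaluating the log-density in closed form avoids this: `-0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)`.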
- log_prob_joint(U, allow_nan)
Compute the joint log-probability of the model, i.e. of the prior and the likelihood together.
- Parameters:
U (jaxns.internals.types.UType) – The U-space sample
allow_nan (bool) – whether to allow NaN values in the likelihood
- Returns:
the joint log-probability of the model
- Return type:
jaxns.internals.types.MeasureType
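By Bayes' rule, the joint log-probability of parameters and data is the sum of the prior log-probability and the log-likelihood. The sketch below assumes `log_prob_joint` combines `log_prob_prior` and `log_prob_likelihood` in exactly that way. The prior, data, and functions shown are hypothetical toy components.

```python
import math
from statistics import NormalDist

PRIOR = NormalDist(0.0, 1.0)  # hypothetical prior on a single parameter x
DATA = [0.3, -0.1]            # hypothetical observations, each ~ Normal(x, 1)


def log_prob_prior(U):
    x = PRIOR.inv_cdf(U[0])
    return math.log(PRIOR.pdf(x))


def log_prob_likelihood(U, allow_nan=False):
    # allow_nan is accepted for interface parity but ignored in this toy.
    x = PRIOR.inv_cdf(U[0])
    return sum(math.log(NormalDist(x, 1.0).pdf(y)) for y in DATA)


def log_prob_joint(U, allow_nan):
    # log p(x, data) = log p(x) + log L(x)
    return log_prob_prior(U) + log_prob_likelihood(U, allow_nan=allow_nan)
```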
- abstract prepare_input(U)
Prepare the input for the likelihood function.
- Parameters:
U (jaxns.internals.types.UType) – The U-space sample
- Returns:
the input to the likelihood function
- Return type:
jaxns.internals.types.LikelihoodInputType
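These docs do not specify the form of `LikelihoodInputType`. The hypothetical sketch below assumes it is a tuple of positional arguments for a user-supplied likelihood function. The pattern is to transform U and then unpack the result into the likelihood's signature.

```python
import math


def log_likelihood(mu, sigma):
    # Hypothetical user likelihood: one observation y = 1.0 ~ Normal(mu, sigma).
    y = 1.0
    return -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def prepare_input(U):
    # Map U-space coordinates to the likelihood's arguments:
    # mu ~ Uniform(-5, 5), sigma ~ Uniform(0.1, 2.1).
    mu = -5.0 + 10.0 * U[0]
    sigma = 0.1 + 2.0 * U[1]
    return (mu, sigma)


log_L = log_likelihood(*prepare_input([0.5, 0.45]))
```

Keeping the U-to-argument mapping in one place means the user's likelihood never sees U-space at all.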
- abstract sanity_check(key, S)
Perform a sanity check on the model.
- Parameters:
key (jaxns.internals.types.PRNGKey) – JAX PRNG key
S (int) – number of samples to check
- Raises:
AssertionError – if any of the samples are NaN.
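The docstring says only that an `AssertionError` is raised when samples are NaN. The minimal sketch below assumes the check draws `S` U-space samples and asserts that neither the sample nor its log-likelihood is NaN. The callables `sample_U` and `log_likelihood` are hypothetical stand-ins, and `random.Random` replaces the JAX PRNGKey.

```python
import math
import random


def sanity_check(sample_U, log_likelihood, key, S):
    for i in range(S):
        U = sample_U(key)
        # Reject any sample whose coordinates or log-likelihood are NaN.
        assert not any(math.isnan(u) for u in U), f"NaN in U-space sample {i}"
        log_L = log_likelihood(U)
        assert not math.isnan(log_L), f"NaN log-likelihood at sample {i}: U={U}"


def draw(key):
    return [key.random()]


sanity_check(draw, lambda U: -U[0] ** 2, random.Random(0), S=100)
```

Because the check relies on `assert`, it is silently skipped when Python runs with the `-O` flag.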