jaxns.framework
Submodules
Package Contents
- class Model(prior_model, log_likelihood, params=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractModel
Represents a Bayesian model in terms of a generative prior and a likelihood function.
- Parameters:
prior_model (jaxns.framework.bases.PriorModelType) –
log_likelihood (jaxns.internals.types.LikelihoodType) –
params (Optional[haiku.MutableParams]) –
- property params
- set_params(params)[source]
Create a new parametrised model with the given parameters.
- Parameters:
params (haiku.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
Model
- __call__(params)[source]
Create a new parametrised model with the given parameters.
This is (and must be) a pure function.
- Parameters:
params (haiku.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
Model
- init_params(rng)[source]
Initialise the parameters of the model.
- Parameters:
rng (jaxns.internals.types.PRNGKey) – PRNG key used to initialise the parameters.
- Returns:
The initialised parameters.
- Return type:
haiku.MutableParams
- sample_U(key)[source]
- Parameters:
key (jaxns.internals.types.PRNGKey) –
- Return type:
jaxns.internals.types.FloatArray
- transform(U)[source]
- Parameters:
U (jaxns.internals.types.UType) –
- Return type:
jaxns.internals.types.XType
- transform_parametrised(U)[source]
- Parameters:
U (jaxns.internals.types.UType) –
- Return type:
jaxns.internals.types.XType
- forward(U, allow_nan=False)[source]
- Parameters:
U (jaxns.internals.types.UType) –
allow_nan (bool) –
- Return type:
jaxns.internals.types.FloatArray
- log_prob_prior(U)[source]
- Parameters:
U (jaxns.internals.types.UType) –
- Return type:
jaxns.internals.types.FloatArray
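The Model methods above suggest a pipeline from the unit hypercube (U-space) to prior space (X-space) and then to the log-likelihood. The following is a minimal stdlib-only sketch of that idea for a hypothetical one-parameter model. It assumes the prior acts as an inverse-CDF (quantile) transform of uniform variates, as is usual in nested sampling. The function names mirror sample_U, transform and forward, but their bodies are illustrative assumptions, not the jaxns implementation; in particular, treating forward as "the log-likelihood at transform(U)" is an assumption.

```python
from __future__ import annotations

import math
import random
from statistics import NormalDist

# Hypothetical model: one parameter x ~ Normal(0, 1) and a Gaussian likelihood
# for a single observed value. Names are illustrative, not the jaxns API.
PRIOR_X = NormalDist(mu=0.0, sigma=1.0)
OBSERVED = 0.7      # hypothetical data point
NOISE_SIGMA = 0.5   # hypothetical measurement noise
EPS = 1e-12         # keeps u strictly inside (0, 1), as inv_cdf requires


def sample_U(rng: random.Random, n_dims: int = 1) -> list[float]:
    """Draw one point uniformly from the unit hypercube (U-space)."""
    return [rng.random() for _ in range(n_dims)]


def transform(U: list[float]) -> dict[str, float]:
    """Map a U-space point to X-space via the prior's inverse CDF."""
    u = min(max(U[0], EPS), 1.0 - EPS)
    return {"x": PRIOR_X.inv_cdf(u)}


def log_likelihood(x: float) -> float:
    """Gaussian log-likelihood of OBSERVED given x."""
    r = (OBSERVED - x) / NOISE_SIGMA
    return -0.5 * r * r - math.log(NOISE_SIGMA * math.sqrt(2.0 * math.pi))


def forward(U: list[float]) -> float:
    """Log-likelihood evaluated at the X-space image of U (assumed semantics)."""
    return log_likelihood(**transform(U))
```

Keeping the sampler in U-space means the sampler itself never needs to know the prior; all prior structure lives in the transform.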
- class Prior(dist_or_value, name=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractPrior
Represents a generative prior.
- Parameters:
dist_or_value (Union[tfpd, jaxns.framework.bases.BaseAbstractDistribution, jax.numpy.ndarray]) –
name (Optional[str]) –
- property dist: jaxns.framework.bases.BaseAbstractDistribution
- Return type:
jaxns.framework.bases.BaseAbstractDistribution
- property value: jax.numpy.ndarray
- Return type:
jax.numpy.ndarray
- parametrised(random_init=False)[source]
Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an associated log_prob. The parameter is registered as a hk.Parameter, with the suffix _param added to its name.
- Parameters:
random_init (bool) – Whether to initialise the parameter randomly (True) or at the median of the distribution (False).
- Returns:
A singular prior.
- Return type:
SingularPrior
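The choice controlled by random_init can be pictured with a small stdlib sketch, assuming (hypothetically) a Normal prior represented by statistics.NormalDist. The default initialises the parameter at the distribution's median, inv_cdf(0.5), while random_init=True draws it from the prior through the inverse CDF of a uniform variate. The helper init_parameter and its returned (name, value, log_prob) tuple are illustrative, not jaxns functions; only the _param suffix and the median-or-random choice come from the description above.

```python
from __future__ import annotations

import math
import random
from statistics import NormalDist


def init_parameter(name: str, dist: NormalDist, random_init: bool = False,
                   rng: random.Random | None = None) -> tuple[str, float, float]:
    """Hypothetical helper: pick a single value for a parametrised prior.

    Returns (parameter name, initial value, log_prob of that value under dist).
    """
    if random_init:
        rng = rng or random.Random()
        u = rng.random()
        while u == 0.0:  # inv_cdf rejects exactly 0.0
            u = rng.random()
        value = dist.inv_cdf(u)
    else:
        value = dist.inv_cdf(0.5)  # median of the distribution
    # The parameter keeps an associated log_prob under its original prior.
    log_prob = math.log(dist.pdf(value))
    return f"{name}_param", value, log_prob
```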
- exception InvalidPriorName(name=None)[source]
Bases:
Exception
Raised when a prior name is already taken.
- Parameters:
name (Optional[str]) –
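A minimal sketch of the kind of uniqueness check this exception signals, assuming a hypothetical registry that records prior names as priors are declared. The InvalidPriorName class below is a local stand-in with the same constructor signature as the documented one, and PriorNameRegistry is invented for illustration; how jaxns actually tracks names is not described here.

```python
from __future__ import annotations

from typing import Optional


class InvalidPriorName(Exception):
    """Local stand-in for the documented exception: a prior name is already taken."""

    def __init__(self, name: Optional[str] = None):
        super().__init__(f"Prior name {name!r} is already taken.")
        self.name = name


class PriorNameRegistry:
    """Hypothetical registry enforcing unique prior names within one model."""

    def __init__(self) -> None:
        self._names: set[str] = set()

    def register(self, name: str) -> str:
        # Reject a name that an earlier prior already claimed.
        if name in self._names:
            raise InvalidPriorName(name=name)
        self._names.add(name)
        return name
```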
- class Bernoulli(*, logits=None, probs=None, name=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractPrior
Bernoulli prior, parametrised by logits or probs.
- Parameters:
name (Optional[str]) –
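A stdlib sketch of how a Bernoulli prior could be driven from a single uniform variate, assuming an inverse-CDF mapping like the one used for continuous priors. The helper names are hypothetical, and the rule that exactly one of logits or probs must be given is an assumption (it matches TensorFlow Probability's Bernoulli). Logits are converted to a probability with a numerically stable sigmoid.

```python
from __future__ import annotations

import math
from typing import Optional


def resolve_probs(logits: Optional[float] = None, probs: Optional[float] = None) -> float:
    """Return P(X = 1), requiring exactly one of logits / probs (assumed rule)."""
    if (logits is None) == (probs is None):
        raise ValueError("Specify exactly one of logits or probs.")
    if probs is not None:
        return probs
    # Numerically stable sigmoid: avoids overflow in exp for large |logits|.
    if logits >= 0:
        return 1.0 / (1.0 + math.exp(-logits))
    e = math.exp(logits)
    return e / (1.0 + e)


def bernoulli_quantile(u: float, p: float) -> int:
    """Inverse CDF of Bernoulli(p): F(0) = 1 - p, so X = 0 iff u <= 1 - p."""
    return 0 if u <= 1.0 - p else 1
```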
- class Beta(*, concentration0=None, concentration1=None, name=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractPrior
Beta prior, parametrised by concentration0 and concentration1.
- Parameters:
name (Optional[str]) –
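The roles of the two concentration parameters are not spelled out above. Assuming this prior follows TensorFlow Probability's convention (concentration1 is the exponent on x, usually written alpha, and concentration0 is the exponent on 1 - x, usually beta), the mean is concentration1 / (concentration0 + concentration1). The NumPy sketch below checks that against samples from numpy.random.Generator.beta, whose a and b arguments are alpha and beta; the helper names are hypothetical.

```python
import numpy as np


def beta_mean(concentration0: float, concentration1: float) -> float:
    """Mean of Beta under TFP naming: alpha = concentration1, beta = concentration0."""
    return concentration1 / (concentration0 + concentration1)


def sample_beta(rng: np.random.Generator, concentration0: float,
                concentration1: float, size: int) -> np.ndarray:
    """Draw Beta samples, mapping the TFP-style names onto NumPy's (a, b)."""
    return rng.beta(a=concentration1, b=concentration0, size=size)
```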
- class Categorical(parametrisation, *, logits=None, probs=None, name=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractPrior
Categorical special prior over a finite set of categories, parametrised by logits or probs.
- Parameters:
parametrisation (Literal['gumbel_max', 'cdf']) – 'cdf' suits discrete parameters whose neighbouring categories are correlated; otherwise 'gumbel_max' is the better choice.
logits – Log-probability of each category.
probs – Probability of each category.
name (Optional[str]) – Optional name of the prior.
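The two parametrisations can be sketched with NumPy, assuming each maps uniform variates to a category index. In the 'cdf' form a single uniform u is pushed through the inverse CDF, so neighbouring values of u give neighbouring categories, which is consistent with the advice above. In the 'gumbel_max' form each category gets its own uniform, converted to Gumbel(0, 1) noise as -log(-log(u)) and added to the logits; the argmax is then distributed as Categorical(softmax(logits)) (the Gumbel-max trick). This is an illustrative sketch, not the jaxns code, and the function names are hypothetical.

```python
import numpy as np


def categorical_cdf(u: float, probs: np.ndarray) -> int:
    """Inverse-CDF draw: smallest k with cumsum(probs)[k] >= u."""
    cdf = np.cumsum(probs)
    k = int(np.searchsorted(cdf, u, side="left"))
    return min(k, len(probs) - 1)  # guard against round-off when u is close to 1


def categorical_gumbel_max(u: np.ndarray, logits: np.ndarray) -> int:
    """Gumbel-max draw from one uniform per category, each u in (0, 1)."""
    gumbel = -np.log(-np.log(u))  # Gumbel(0, 1) noise via its inverse CDF
    return int(np.argmax(logits + gumbel))
```

The 'cdf' form uses one dimension of U regardless of the number of categories, while 'gumbel_max' uses one per category.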
- class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractPrior
Prior for a sequence of n random variables, each uniformly distributed on [low, high], constrained such that U[i, ...] <= U[i+1, ...]. With broadcasting, the resulting random variable is sorted elementwise along the first dimension.
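One standard way to obtain n sorted uniforms directly from n independent unit-cube coordinates, without an explicit sort, is to build the order statistics from the top down. The maximum of n uniforms on [0, 1] can be written as U_n^(1/n), and given the (i+1)-th smallest value t, the i-th smallest is t * U_i^(1/i). The NumPy sketch below implements that transform and rescales to [low, high], giving values that are non-decreasing elementwise along the first axis. Whether jaxns uses exactly this construction, and how fix_left and fix_right modify it, are not stated here; the function name is hypothetical.

```python
import numpy as np


def sorted_uniform_transform(U: np.ndarray, low=0.0, high=1.0) -> np.ndarray:
    """Map U in [0, 1]^(n, ...) to values in [low, high], non-decreasing along axis 0.

    Top-down order-statistic construction: the largest value is U[n-1]**(1/n),
    and each lower value is the one above it times U[i]**(1/(i+1)).
    """
    U = np.asarray(U, dtype=float)
    n = U.shape[0]
    t = np.empty_like(U)
    t[n - 1] = U[n - 1] ** (1.0 / n)
    for i in range(n - 2, -1, -1):
        t[i] = t[i + 1] * U[i] ** (1.0 / (i + 1))
    # low and high broadcast against the trailing dimensions of U.
    return low + (np.asarray(high) - np.asarray(low)) * t
```

Because the map is monotone in each coordinate and needs no sort, nearby points in U-space stay nearby in X-space.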
- class Poisson(*, rate=None, log_rate=None, name=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractPrior
Poisson prior, parametrised by rate or log_rate.
- Parameters:
name (Optional[str]) –
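A stdlib sketch of drawing a Poisson variate from a single uniform by inverting the CDF, with rate and log_rate treated as alternative parametrisations (rate = exp(log_rate)). Requiring exactly one of the two is an assumption, as are the helper names. The sequential pmf recurrence used here is adequate for moderate rates and is not meant to reflect how jaxns evaluates the quantile.

```python
import math
from typing import Optional


def resolve_rate(rate: Optional[float] = None, log_rate: Optional[float] = None) -> float:
    """Return the Poisson rate, requiring exactly one parametrisation (assumed rule)."""
    if (rate is None) == (log_rate is None):
        raise ValueError("Specify exactly one of rate or log_rate.")
    return rate if rate is not None else math.exp(log_rate)


def poisson_quantile(u: float, rate: float) -> int:
    """Smallest k with P(X <= k) >= u, for X ~ Poisson(rate).

    Accumulates pmf(k) = pmf(k-1) * rate / k. Assumes exp(-rate) does not
    underflow (rate below roughly 700).
    """
    k = 0
    pmf = math.exp(-rate)
    cdf = pmf
    while cdf < u:
        k += 1
        pmf *= rate / k
        cdf += pmf
        if pmf == 0.0 and k > rate:
            break  # round-off: the cdf can no longer grow towards u
    return k
```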
- class UnnormalisedDirichlet(*, concentration, name=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractPrior
Represents an unnormalised Dirichlet distribution over K classes. That is, normalising the output maps it onto the probability simplex in R^K:
If X ~ UnnormalisedDirichlet(alpha) and Y = X / sum(X), then Y ~ Dirichlet(alpha).
- Parameters:
name (Optional[str]) –
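The relation above is the standard Gamma construction of the Dirichlet: if the components X_k are independent with X_k ~ Gamma(alpha_k, 1), then X / sum(X) ~ Dirichlet(alpha). The NumPy sketch below assumes the unnormalised components are exactly such Gamma(alpha_k, 1) draws; the docstring only guarantees the normalised relation, so the distribution of X itself in jaxns is an assumption. The function names are hypothetical.

```python
import numpy as np


def sample_unnormalised_dirichlet(rng: np.random.Generator, concentration,
                                  size: int) -> np.ndarray:
    """Draw `size` vectors X with independent components X_k ~ Gamma(alpha_k, 1)."""
    alpha = np.asarray(concentration, dtype=float)
    return rng.gamma(shape=alpha, scale=1.0, size=(size,) + alpha.shape)


def normalise(X: np.ndarray) -> np.ndarray:
    """Project onto the probability simplex: Y = X / sum(X) along the last axis."""
    return X / X.sum(axis=-1, keepdims=True)
```

Any common scale for the Gamma components gives the same Dirichlet after normalisation, since the scale cancels in X / sum(X).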