jaxns.framework
Submodules
Package Contents
- jaxify_likelihood(log_likelihood, vectorised=False)
Wraps a non-JAX log likelihood function.
- Parameters:
log_likelihood (Callable[Ellipsis, numpy.ndarray]) – a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar log-likelihood.
vectorised (bool) – if True, log_likelihood must accept batched inputs: every input carries the same set of batch dimensions, and the function is responsible for handling them.
- Returns:
A JAX-compatible log-likelihood function.
- Return type:
jaxns.internals.types.LikelihoodType
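The vectorised flag decides whether the wrapped function or the wrapper deals with batch dimensions. The NumPy-only sketch below mimics that contract. evaluate_batched, scalar_ll and batched_ll are hypothetical names, not jaxns API, and how jaxns dispatches the host call is not stated here (in JAX, host functions are commonly called from traced code with jax.pure_callback).

```python
import numpy as np


def evaluate_batched(log_likelihood, args, batch_shape, vectorised=False):
    """Hypothetical helper: evaluate a NumPy log-likelihood on inputs that
    share the leading dimensions `batch_shape`."""
    batch_shape = tuple(batch_shape)
    if vectorised:
        # The function sees the full batched arrays and must return one
        # log-likelihood per batch element.
        result = np.asarray(log_likelihood(*args))
        if result.shape != batch_shape:
            raise ValueError(f"expected shape {batch_shape}, got {result.shape}")
        return result
    # Otherwise call the function once per batch element on unbatched slices.
    out = np.empty(batch_shape)
    for idx in np.ndindex(*batch_shape):
        out[idx] = log_likelihood(*(a[idx] for a in args))
    return out


def scalar_ll(x, y):
    # Unbatched: x has shape (3,), y is a scalar; returns a scalar.
    return -0.5 * np.sum((x - y) ** 2)


def batched_ll(x, y):
    # Batched: x has shape (B, 3), y has shape (B,); reduce only the event axis.
    return -0.5 * np.sum((x - y[..., None]) ** 2, axis=-1)


x = np.arange(12.0).reshape(4, 3)
y = np.array([0.0, 1.0, 2.0, 3.0])
looped = evaluate_batched(scalar_ll, (x, y), (4,), vectorised=False)
direct = evaluate_batched(batched_ll, (x, y), (4,), vectorised=True)
```

Both calls agree. The common mistake with vectorised=True is reducing over the batch axis as well, which returns a single number instead of one log-likelihood per batch element.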
- class Model(prior_model, log_likelihood, params=None)
Bases:
jaxns.framework.bases.BaseAbstractModel
Represents a Bayesian model in terms of a generative prior and a likelihood function.
- Parameters:
prior_model (jaxns.framework.bases.PriorModelType)
log_likelihood (jaxns.internals.types.LikelihoodType)
params (Optional[jaxns.framework.context.MutableParams])
- property params
- set_params(params)
Create a new parametrised model with the given parameters.
- Parameters:
params (jaxns.framework.context.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
Model
- __call__(params)
Create a new parametrised model with the given parameters.
This is (and must be) a pure function.
- Parameters:
params (jaxns.framework.context.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
Model
- init_params(rng)
Initialise the parameters of the model.
- Parameters:
rng (jaxns.internals.types.PRNGKey) – PRNGKey used to initialise the parameters.
- Returns:
The initialised parameters.
- Return type:
jaxns.framework.context.MutableParams
- sample_U(key)
Sample U from the prior model.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey to use.
- Returns:
The sampled U.
- Return type:
jaxns.internals.types.UType
- sample_W(key)
Sample W from the prior model.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey to use.
- Returns:
The sampled W.
- Return type:
jaxns.internals.types.WType
- transform(U)
- Parameters:
U (jaxns.internals.types.UType)
- Return type:
jaxns.internals.types.XType
- transform_parametrised(U)
- Parameters:
U (jaxns.internals.types.UType)
- Return type:
jaxns.internals.types.XType
- forward(U, allow_nan=False)
- Parameters:
U (jaxns.internals.types.UType)
allow_nan (bool)
- Return type:
jaxns.internals.types.FloatArray
- log_prob_prior(U)
- Parameters:
U (jaxns.internals.types.UType)
- Return type:
jaxns.internals.types.FloatArray
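The Model methods above list only types. The following conceptual sketch shows how sample_U, transform and forward could fit together. It assumes, as is usual in nested sampling, that U is a point in the unit hypercube, that transform maps U to prior space X through each prior's quantile function, and that forward returns the log-likelihood at transform(U). The toy model and every name in it are hypothetical; this is not the jaxns API.

```python
import math
import random
from statistics import NormalDist

# Hypothetical toy model: mu ~ Normal(0, 1), sigma ~ Uniform(0.5, 2.0),
# and data y_i ~ Normal(mu, sigma). None of these names come from jaxns.
DATA = [0.3, -0.1, 0.8]
MU_PRIOR = NormalDist(0.0, 1.0)


def sample_U(rng):
    # One unit-cube coordinate per scalar prior parameter.
    return [rng.random(), rng.random()]


def transform(U):
    # Push each coordinate through its prior's quantile function: U -> X.
    return {"mu": MU_PRIOR.inv_cdf(U[0]), "sigma": 0.5 + 1.5 * U[1]}


def forward(U):
    # Log-likelihood of the data at the prior-space point X = transform(U).
    X = transform(U)
    lik = NormalDist(X["mu"], X["sigma"])
    return sum(math.log(lik.pdf(y)) for y in DATA)


U = sample_U(random.Random(0))
X = transform(U)
log_l = forward(U)
```

Because the sampler only ever moves U inside the unit cube, all prior structure lives in transform, which is why the page exposes it separately from the likelihood.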
- class Prior(dist_or_value, name=None)
Bases:
jaxns.framework.bases.BaseAbstractPrior
Represents a generative prior.
- Parameters:
dist_or_value (Union[tfpd, jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray, jaxns.internals.types.BoolArray])
name (Optional[str])
- name = None
- property dist: jaxns.framework.bases.BaseAbstractDistribution
- Return type:
jaxns.framework.bases.BaseAbstractDistribution
- parametrised(random_init=False)
Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an associated log_prob. The parameter is registered via get_parameter under the prior's name with a _param suffix, so the prior must have a name.
- Parameters:
random_init (bool) – whether to initialise the parameter randomly or at the median of the distribution.
- Returns:
A singular prior.
- Raises:
ValueError – if the prior has no name.
- Return type:
SingularPrior
- exception InvalidPriorName(name=None)
Bases:
Exception
Raised when a prior name is already taken.
- Parameters:
name (Optional[str])
- class Bernoulli(*, logits=None, probs=None, name=None)
Bases:
SpecialPrior
A Bernoulli prior, parametrised by logits or probs.
- Parameters:
name (Optional[str])
- dist
- class Beta(*, concentration0=None, concentration1=None, name=None)
Bases:
SpecialPrior
A Beta prior with concentration parameters concentration0 and concentration1.
- Parameters:
name (Optional[str])
- class Categorical(parametrisation, *, logits=None, probs=None, name=None)
Bases:
SpecialPrior
A Categorical special prior, parametrised by logits or probs.
- Parameters:
parametrisation (Literal['gumbel_max', 'cdf']) – ‘cdf’ suits discrete parameters whose neighbouring categories are correlated; otherwise ‘gumbel_max’ is better.
logits – log-prob of each category
probs – prob of each category
name (Optional[str]) – optional name
- dist
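The two parametrisations can be sketched as two ways of turning unit-cube coordinates into a category: ‘cdf’ inverts the cumulative distribution using one coordinate, and ‘gumbel_max’ uses the Gumbel-max trick with one coordinate per category. This NumPy sketch illustrates the standard techniques under that assumption; the helper names are hypothetical and the jaxns internals may differ.

```python
import numpy as np


def categorical_cdf(u, probs):
    """Map one uniform per draw to a category by inverting the CDF."""
    cdf = np.cumsum(probs)
    # Clip guards against u landing just above a cdf[-1] that rounds below 1.
    return np.minimum(np.searchsorted(cdf, u, side="right"), len(probs) - 1)


def categorical_gumbel_max(u, logits):
    """Map one uniform per category to a category via the Gumbel-max trick."""
    gumbel = -np.log(-np.log(u))  # standard Gumbel noise from uniforms in (0, 1)
    return np.argmax(logits + gumbel, axis=-1)


rng = np.random.default_rng(0)
probs = np.array([0.2, 0.5, 0.3])
n = 200_000
c_cdf = categorical_cdf(rng.random(n), probs)
c_gumbel = categorical_gumbel_max(rng.random((n, probs.size)), np.log(probs))
```

Both produce the same category frequencies. The ‘cdf’ map is monotone in its single coordinate, so nearby U values land in the same or adjacent categories, which is what makes it suitable when neighbouring categories are correlated; Gumbel-max has no such ordering.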
- class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)
Bases:
SpecialPrior
Prior for a sequence of n random variables, each uniformly distributed on [low, high], constrained so that U[i,…] <= U[i+1,…]. For broadcasting, the resulting random variable is sorted elementwise along the first dimension.
- Parameters:
- n
- low
- high
- fix_left = False
- fix_right = False
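One standard way to turn n unit-cube coordinates into n sorted uniforms without an explicit sort is an order-statistics recursion: the largest of n uniforms on [0, 1] has quantile u**(1/n), and, given that value, the next one down is the largest of the remaining uniforms below it. The sketch below assumes this construction; the helper name is hypothetical, jaxns's own implementation may differ, and fix_left/fix_right are not modelled.

```python
import numpy as np


def forced_identifiability(u, low, high):
    """Map unit-cube samples u of shape (n, ...) to values sorted along axis 0."""
    n = u.shape[0]
    x = np.empty_like(u, dtype=float)
    # The largest of n uniforms has CDF t**n, so its quantile is u**(1/n).
    x[n - 1] = u[n - 1] ** (1.0 / n)
    # Given x[i + 1], x[i] is the largest of i + 1 uniforms on [0, x[i + 1]].
    for i in range(n - 2, -1, -1):
        x[i] = x[i + 1] * u[i] ** (1.0 / (i + 1))
    # Rescale from [0, 1] to [low, high]; this preserves the ordering.
    return low + (high - low) * x


rng = np.random.default_rng(1)
x = forced_identifiability(rng.random((4, 100_000)), 2.0, 5.0)
```

Each output column is sorted along the first axis, matching the broadcasting rule above, and the k-th smallest value has mean low + (high - low) * k / (n + 1), as for uniform order statistics.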
- class Poisson(*, rate=None, log_rate=None, name=None)
Bases:
SpecialPrior
A Poisson prior, parametrised by rate or log_rate.
- Parameters:
name (Optional[str])
- dist
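If, as the U to X transform of Model suggests, priors are driven by unit-cube samples, a discrete prior such as Poisson needs an inverse-CDF map from u in [0, 1) to a count. A minimal stdlib sketch under that assumption, accepting either rate or log_rate; poisson_quantile is a hypothetical helper, not the jaxns transform.

```python
import math


def poisson_quantile(u, rate=None, log_rate=None):
    """Smallest k with P(K <= k) >= u for K ~ Poisson(rate); pass rate or log_rate."""
    if (rate is None) == (log_rate is None):
        raise ValueError("pass exactly one of rate or log_rate")
    if rate is None:
        rate = math.exp(log_rate)
    k = 0
    pmf = math.exp(-rate)  # P(K = 0); underflows for very large rates
    cdf = pmf
    # Walk up the CDF; the pmf > 0 guard stops the loop if rounding in the
    # far tail means u can never be reached.
    while cdf < u and pmf > 0.0:
        k += 1
        pmf *= rate / k
        cdf += pmf
    return k
```

For example, with rate 3 the cumulative probabilities for k = 0..3 are about 0.050, 0.199, 0.423 and 0.647, so u = 0.5 maps to 3. The linear walk is fine for moderate rates only.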
- class UnnormalisedDirichlet(*, concentration, name=None)
Bases:
SpecialPrior
Represents an unnormalised Dirichlet distribution over K classes: the output is mapped to the K-simplex by normalisation.
X ~ UnnormalisedDirichlet(alpha)
Y = X / sum(X) ==> Y ~ Dirichlet(alpha)
- Parameters:
name (Optional[str])
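The normalisation relation can be checked numerically with the standard construction in which the unnormalised components are independent Gamma(alpha_k, 1) draws; that jaxns parametrises X this way is an assumption here. The helper name is hypothetical.

```python
import numpy as np


def unnormalised_dirichlet(rng, alpha, n):
    """Draw n unnormalised vectors X with independent Gamma(alpha_k, 1) components."""
    alpha = np.asarray(alpha, dtype=float)
    return rng.gamma(shape=alpha, scale=1.0, size=(n, alpha.size))


rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 5.0])
X = unnormalised_dirichlet(rng, alpha, 200_000)
Y = X / X.sum(axis=-1, keepdims=True)  # each row of Y lies on the simplex
```

The row means of Y approach alpha / sum(alpha) = [0.125, 0.25, 0.625], the Dirichlet(alpha) mean. Keeping X unnormalised gives K free coordinates instead of a K-1 dimensional constrained simplex.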
- class Empirical(*, samples, support_min=None, support_max=None, resolution=100, name=None)
Bases:
SpecialPrior
Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimensions.
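A common way to use samples as a prior is to approximate their quantile function piecewise-linearly and push unit-cube draws through it. In this minimal sketch, n_knots is a hypothetical parameter; that resolution plays a similar role in jaxns is an assumption, and support_min, support_max and batch dimensions are not modelled. Values outside the sample range are never produced.

```python
import numpy as np


def empirical_quantile(samples, n_knots=100):
    """Piecewise-linear approximation of the empirical quantile function."""
    probs = np.linspace(0.0, 1.0, n_knots)
    knots = np.quantile(samples, probs)  # non-decreasing sample quantiles
    return lambda u: np.interp(u, probs, knots)


rng = np.random.default_rng(0)
samples = rng.normal(loc=1.0, scale=2.0, size=50_000)
quantile = empirical_quantile(samples)
draws = quantile(rng.random(50_000))  # push unit-cube draws through the quantile
```

The interior of the distribution is reproduced well, while the outermost bins are flattened to straight lines between the sample extremes and the first and last knots.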
- class TruncationWrapper(prior, low, high, name=None)
Bases:
SpecialPrior
Wraps another prior to truncate it to the interval [low, high].
For a truncated distribution, the quantile function becomes:
Q_truncated(p) = Q_untruncated(p * (F_untruncated(high) - F_untruncated(low)) + F_untruncated(low))
and the CDF becomes:
F_truncated(x) = (F_untruncated(x) - F_untruncated(low)) / (F_untruncated(high) - F_untruncated(low))
- Parameters:
- prior
- low
- high
- cdf_low
- cdf_diff
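The truncated quantile and CDF can be sketched with the standard library's NormalDist standing in for the wrapped prior. The local names cdf_low and cdf_diff mirror the attributes listed above, read here as F_untruncated(low) and F_untruncated(high) - F_untruncated(low); that jaxns precomputes them exactly this way is an assumption, and the helper names are hypothetical.

```python
from statistics import NormalDist


def truncated_quantile(dist, low, high, p):
    """Quantile of `dist` truncated to [low, high], built from the untruncated CDF and quantile."""
    cdf_low = dist.cdf(low)
    cdf_diff = dist.cdf(high) - cdf_low
    return dist.inv_cdf(p * cdf_diff + cdf_low)


def truncated_cdf(dist, low, high, x):
    """CDF of `dist` truncated to [low, high], valid for low <= x <= high."""
    cdf_low = dist.cdf(low)
    cdf_diff = dist.cdf(high) - cdf_low
    return (dist.cdf(x) - cdf_low) / cdf_diff


base = NormalDist(mu=0.0, sigma=1.0)
x = truncated_quantile(base, -1.0, 2.0, 0.3)
```

The quantile maps p = 0 to low and p = 1 to high, and truncated_cdf inverts it, so a uniform p always yields a value inside [low, high].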