
jaxns.framework

Submodules

Package Contents

class Model(prior_model, log_likelihood, params=None)

Bases: jaxns.framework.bases.BaseAbstractModel

Represents a Bayesian model in terms of a generative prior and a likelihood function.

Parameters:
  • prior_model (jaxns.framework.bases.PriorModelType) – The generative prior.

  • log_likelihood (jaxns.internals.types.LikelihoodType) – The log-likelihood function.

  • params (Optional[haiku.MutableParams]) – Optional model parameters; see set_params.
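A prior_model is written as a Python generator. Each `yield` hands a prior to the framework and receives that prior's value back. The generator's return value becomes the arguments of log_likelihood. The stdlib-only sketch below mimics this control flow so that it runs without jaxns. `HypotheticalNormalPrior` and the driver `run_prior_model` are illustrative assumptions, not jaxns internals, and each prior here consumes exactly one uniform draw.

```python
import math
from dataclasses import dataclass
from statistics import NormalDist
from typing import Callable, Generator, Sequence, Tuple


@dataclass
class HypotheticalNormalPrior:
    """Stand-in for a prior that maps one uniform draw u in (0, 1) to a value."""
    mu: float
    sigma: float

    def transform(self, u: float) -> float:
        # Inverse-CDF (quantile) transform from U space to X space.
        return NormalDist(self.mu, self.sigma).inv_cdf(u)


def prior_model() -> Generator[HypotheticalNormalPrior, float, Tuple[float, float]]:
    x = yield HypotheticalNormalPrior(0.0, 1.0)
    y = yield HypotheticalNormalPrior(x, 0.5)  # the prior on y depends on x
    return x, y


def log_likelihood(x: float, y: float) -> float:
    # Gaussian log-density of one observation (0.3) with mean x + y and unit noise.
    return -0.5 * (0.3 - (x + y)) ** 2 - 0.5 * math.log(2.0 * math.pi)


def run_prior_model(gen_fn: Callable[[], Generator], U: Sequence[float]):
    """Hypothetical driver: feed one uniform to each yielded prior, in order."""
    gen = gen_fn()
    prior = next(gen)  # run up to the first `yield`
    i = 0
    try:
        while True:
            value = prior.transform(U[i])
            i += 1
            prior = gen.send(value)  # resume the generator with the sampled value
    except StopIteration as stop:
        return stop.value  # the generator's `return` value


x, y = run_prior_model(prior_model, [0.5, 0.5])
print(x, y)  # 0.0 0.0 (u = 0.5 is the median of each normal prior)
print(log_likelihood(x, y))
```

Because the generator suspends at every `yield`, later priors (here the one on y) can depend on earlier sampled values.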

property num_params: int
Return type:

int

property params
set_params(params)

Create a new parametrised model with the given parameters.

Parameters:

params (haiku.MutableParams) – The parameters to use.

Returns:

A model with set parameters.

Return type:

Model

__call__(params)

Create a new parametrised model with the given parameters.

This is (and must be) a pure function.

Parameters:

params (haiku.MutableParams) – The parameters to use.

Returns:

A model with set parameters.

Return type:

Model
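Both set_params and __call__ return a new model rather than modifying the current one, so calling a model has no side effects. The sketch below shows that copy-on-set pattern with a frozen dataclass. `HypotheticalModel` and its fields are illustrative and do not reproduce the jaxns Model class.

```python
from dataclasses import dataclass, replace
from typing import Dict, Optional


@dataclass(frozen=True)
class HypotheticalModel:
    name: str
    params: Optional[Dict[str, float]] = None

    def set_params(self, params: Dict[str, float]) -> "HypotheticalModel":
        # Return a copy that carries the new parameters; `self` is left untouched.
        return replace(self, params=dict(params))

    def __call__(self, params: Dict[str, float]) -> "HypotheticalModel":
        # Delegates to set_params, so calling the model is side-effect free.
        return self.set_params(params)


base = HypotheticalModel("demo")
fitted = base({"sigma_param": 0.7})
print(base.params, fitted.params)  # None {'sigma_param': 0.7}
```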

init_params(rng)

Initialise the parameters of the model.

Parameters:

rng (jaxns.internals.types.PRNGKey) – PRNG key used to initialise the parameters.

Returns:

The initialised parameters.

Return type:

haiku.MutableParams

__hash__()
__repr__()
sample_U(key)
Parameters:

key (jaxns.internals.types.PRNGKey) – PRNG key used for sampling.

Return type:

jaxns.internals.types.FloatArray

transform(U)
Parameters:

U (jaxns.internals.types.UType) –

Return type:

jaxns.internals.types.XType
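sample_U and transform split sampling into two steps: draw a point in U space, then map it to the prior's X space. This page does not define U space. The sketch assumes it is the unit hypercube, with one coordinate per independent prior, which is a common nested-sampling convention. Under that assumption both steps reduce to inverse CDFs. The priors and helper names below are hypothetical.

```python
import random
from statistics import NormalDist
from typing import List, Sequence

# Hypothetical independent priors, one per U coordinate.
PRIORS = [NormalDist(0.0, 1.0), NormalDist(5.0, 2.0)]


def sample_U(rng: random.Random) -> List[float]:
    """Draw a point uniformly from the open unit hypercube (0, 1)^d."""
    U = []
    for _ in PRIORS:
        u = rng.random()  # in [0, 1)
        while u == 0.0:  # inv_cdf below needs 0 < u < 1
            u = rng.random()
        U.append(u)
    return U


def transform(U: Sequence[float]) -> List[float]:
    """Map each U coordinate to X space through its prior's inverse CDF."""
    return [prior.inv_cdf(u) for prior, u in zip(PRIORS, U)]


print(transform([0.5, 0.5]))  # [0.0, 5.0], the prior medians
print(transform(sample_U(random.Random(42))))
```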

transform_parametrised(U)
Parameters:

U (jaxns.internals.types.UType) –

Return type:

jaxns.internals.types.XType

forward(U, allow_nan=False)
Parameters:
  • U (jaxns.internals.types.UType) –

  • allow_nan (bool) –

Return type:

jaxns.internals.types.FloatArray
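forward evaluates the model at a point U. This page does not spell out its semantics. A plausible reading of the signature, stated here as an assumption, has three steps:

- Map U to the likelihood's inputs (compare prepare_input).
- Evaluate the log-likelihood.
- Unless allow_nan is True, treat a NaN result as impossible, that is, a log-likelihood of -inf.

A stdlib sketch of that reading, using a hypothetical one-dimensional normal prior:

```python
import math
from statistics import NormalDist
from typing import Callable, List, Sequence

PRIOR = NormalDist(0.0, 1.0)  # hypothetical single-parameter prior


def prepare_input(U: Sequence[float]) -> List[float]:
    # Map U space to the likelihood's arguments via the inverse CDF.
    return [PRIOR.inv_cdf(u) for u in U]


def forward(log_likelihood: Callable[..., float], U: Sequence[float],
            allow_nan: bool = False) -> float:
    value = log_likelihood(*prepare_input(U))
    if math.isnan(value) and not allow_nan:
        # Assumed behaviour: NaN is mapped to -inf so the point is rejected.
        return -math.inf
    return value


def flaky_log_likelihood(x: float) -> float:
    # Deliberately returns NaN for x > 0 to exercise the allow_nan switch.
    return math.nan if x > 0.0 else -0.5 * x * x


print(forward(flaky_log_likelihood, [0.9]))                  # -inf
print(forward(flaky_log_likelihood, [0.9], allow_nan=True))  # nan
```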

log_prob_prior(U)
Parameters:

U (jaxns.internals.types.UType) –

Return type:

jaxns.internals.types.FloatArray

prepare_input(U)
Parameters:

U (jaxns.internals.types.UType) –

Return type:

jaxns.internals.types.LikelihoodInputType

sanity_check(key, S)
Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key used to draw the samples.

  • S (int) –

class Prior(dist_or_value, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Represents a generative prior.

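Parameters:
  • dist_or_value – either a distribution or a fixed value; see the dist and value properties

  • name (Optional[str]) – optional name
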
property dist: jaxns.framework.bases.BaseAbstractDistribution
Return type:

jaxns.framework.bases.BaseAbstractDistribution

property value: jax.numpy.ndarray
Return type:

jax.numpy.ndarray

parametrised(random_init=False)

Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an associated log_prob. The parameter is registered as an hk.Parameter, with the suffix _param added to its name.

Parameters:

random_init (bool) – whether to initialise the parameter randomly or at the median of the distribution.

Returns:

A singular prior.

Return type:

SingularPrior
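The random_init switch can be illustrated without jaxns. Initialising "at the median" means evaluating the inverse CDF at 0.5, while a random initialisation first draws u uniformly. In both cases the value keeps a log-density under the original distribution. `hypothetical_initial_value` and the NormalDist prior are assumptions for illustration. They do not show how jaxns registers the Haiku parameter.

```python
import math
import random
from statistics import NormalDist


def hypothetical_initial_value(dist: NormalDist, random_init: bool,
                               rng: random.Random) -> float:
    """Pick a starting value for a parametrised prior (illustration only)."""
    if random_init:
        u = rng.random()
        while u == 0.0:  # inv_cdf requires 0 < u < 1
            u = rng.random()
    else:
        u = 0.5  # the median is the inverse CDF evaluated at 1/2
    return dist.inv_cdf(u)


prior = NormalDist(mu=2.0, sigma=3.0)
value = hypothetical_initial_value(prior, random_init=False, rng=random.Random(0))
print(value)                       # 2.0, the median
print(math.log(prior.pdf(value)))  # the log_prob the parameter still carries
```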

exception InvalidPriorName(name=None)

Bases: Exception

Raised when a prior name is already taken.

Parameters:

name (Optional[str]) – the prior name that caused the error
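InvalidPriorName is raised by a uniqueness check on prior names. A stdlib sketch of such a check follows. The local exception class only mirrors the documented signature, and `register_names` is a hypothetical helper, not part of jaxns.

```python
from typing import Iterable, Optional, Set


class InvalidPriorName(Exception):
    """Local stand-in mirroring the documented exception (not imported from jaxns)."""

    def __init__(self, name: Optional[str] = None):
        super().__init__(f"Prior name {name!r} is already taken.")
        self.name = name


def register_names(names: Iterable[str]) -> Set[str]:
    """Hypothetical check: reject the second occurrence of any prior name."""
    seen: Set[str] = set()
    for name in names:
        if name in seen:
            raise InvalidPriorName(name=name)
        seen.add(name)
    return seen


try:
    register_names(["x", "y", "x"])
except InvalidPriorName as err:
    print(err)  # Prior name 'x' is already taken.
```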

class Bernoulli(*, logits=None, probs=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

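Bernoulli prior over the outcomes 0 and 1, specified through logits or probs.

Parameters:
  • logits – log-odds of the outcome 1

  • probs – probability of the outcome 1

  • name (Optional[str]) – optional name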
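logits and probs are two encodings of the same success probability, related by the logistic function. A Bernoulli value can be obtained from a single uniform draw by inverting its CDF. The stdlib sketch below shows both ideas. The helper names are hypothetical, and this page does not say whether jaxns uses exactly this mapping.

```python
import math


def probs_from_logits(logits: float) -> float:
    # Logistic sigmoid: logits are the log-odds log(p / (1 - p)).
    return 1.0 / (1.0 + math.exp(-logits))


def bernoulli_from_uniform(u: float, probs: float) -> int:
    # Inverse CDF: P(X = 0) = 1 - probs, so draws below that threshold give 0.
    return 0 if u < 1.0 - probs else 1


p = probs_from_logits(0.0)
print(p)  # 0.5
print(bernoulli_from_uniform(0.3, p), bernoulli_from_uniform(0.7, p))  # 0 1
```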

class Beta(*, concentration0=None, concentration1=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

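Beta prior on the interval [0, 1].

Parameters:
  • concentration0 – concentration parameter, conventionally called beta

  • concentration1 – concentration parameter, conventionally called alpha

  • name (Optional[str]) – optional name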

class Categorical(parametrisation, *, logits=None, probs=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Categorical prior over a finite set of categories, specified through logits or probs.

Parameters:
  • parametrisation (Literal['gumbel_max', 'cdf']) – 'cdf' suits discrete parameters whose neighbouring categories are correlated; otherwise 'gumbel_max' is usually better.

  • logits – unnormalised log-probability of each category

  • probs – probability of each category

  • name (Optional[str]) – optional name
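The two parametrisations can be read as two ways of turning uniform draws into a category:

- 'cdf' inverts the cumulative distribution using a single uniform. Nearby u values therefore land in neighbouring categories, which would explain why it suits correlated neighbours.
- 'gumbel_max' adds Gumbel noise to each logit, using one uniform per category, and takes the argmax.

This stdlib sketch shows how such parametrisations generally work. It is an assumption, not a copy of the jaxns code, and the function names are hypothetical.

```python
import math
from typing import Sequence


def categorical_cdf(u: float, probs: Sequence[float]) -> int:
    """Inverse-CDF parametrisation: one uniform u in [0, 1) picks the category."""
    cumulative = 0.0
    for k, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return k
    return len(probs) - 1  # guard against rounding when u is close to 1


def categorical_gumbel_max(us: Sequence[float], logits: Sequence[float]) -> int:
    """Gumbel-max parametrisation: one uniform in (0, 1) per category."""
    # Gumbel noise g = -log(-log(u)) is added to each logit; the argmax is the draw.
    scores = [l - math.log(-math.log(u)) for l, u in zip(logits, us)]
    return max(range(len(scores)), key=scores.__getitem__)


probs = [0.2, 0.5, 0.3]
print([categorical_cdf(u, probs) for u in (0.1, 0.5, 0.95)])  # [0, 1, 2]
logits = [math.log(p) for p in probs]
print(categorical_gumbel_max([0.5, 0.5, 0.5], logits))  # 1: equal noise keeps the largest logit
```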

class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Prior for a sequence of n random variables, each uniformly distributed on [low, high], constrained so that U[i,…] <= U[i+1,…]. When the inputs carry extra (broadcast) dimensions, the result is sorted elementwise along the first dimension.

Parameters:
  • n (int) – number of samples within [low, high]

  • low – minimum of the distribution

  • high – maximum of the distribution

  • fix_left (bool) – if True, the leftmost value is fixed to low

  • fix_right (bool) – if True, the rightmost value is fixed to high

  • name (Optional[str]) – optional name
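One way to turn n independent uniforms into n sorted uniforms on [low, high] is the order-statistics recursion:

- The largest of n iid U(0, 1) draws is distributed as u^(1/n).
- Given the largest, the remaining n - 1 draws are iid uniform below it, so the same step repeats downwards.

The sketch below uses that recursion to illustrate the ordering constraint. It is not necessarily how jaxns implements ForcedIdentifiability. It ignores fix_left, fix_right and broadcasting, and `sorted_uniforms` is a hypothetical name.

```python
from typing import List, Sequence


def sorted_uniforms(U: Sequence[float], low: float, high: float) -> List[float]:
    """Map n >= 1 uniforms in (0, 1) to n non-decreasing draws on [low, high]."""
    n = len(U)
    x = [0.0] * n
    # Largest order statistic of n iid U(0, 1): CDF t**n, so invert with u**(1/n).
    x[n - 1] = U[n - 1] ** (1.0 / n)
    # Walk downwards: given x[i + 1], the i + 1 smaller values are iid U(0, x[i + 1]).
    for i in range(n - 2, -1, -1):
        x[i] = x[i + 1] * U[i] ** (1.0 / (i + 1))
    # Rescale from [0, 1] to [low, high].
    return [low + (high - low) * xi for xi in x]


print(sorted_uniforms([0.5, 0.5, 0.5], 0.0, 10.0))  # increasing values in [0, 10]
```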

class Poisson(*, rate=None, log_rate=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

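Poisson prior, specified through rate or log_rate.

Parameters:
  • rate – rate (mean) of the distribution

  • log_rate – natural logarithm of the rate

  • name (Optional[str]) – optional name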
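rate and log_rate are alternative ways to give the same Poisson rate (log_rate = log(rate)). For a discrete prior driven by a uniform draw, the usual construction is the inverse CDF, sketched below with the standard library. The function is hypothetical, and jaxns may map U to the count differently.

```python
import math
from typing import Optional


def poisson_from_uniform(u: float, rate: Optional[float] = None,
                         log_rate: Optional[float] = None) -> int:
    """Inverse-CDF draw of a Poisson variable from one uniform u in [0, 1).

    Suitable for moderate rates only: exp(-rate) underflows for very large rates.
    """
    if (rate is None) == (log_rate is None):
        raise ValueError("pass exactly one of rate or log_rate")
    if rate is None:
        rate = math.exp(log_rate)
    k = 0
    pmf = math.exp(-rate)  # P(X = 0)
    cdf = pmf
    while u >= cdf and pmf > 0.0:
        k += 1
        pmf *= rate / k  # P(X = k) = P(X = k - 1) * rate / k
        cdf += pmf
    return k


print(poisson_from_uniform(0.5, rate=3.0))                # 3, the median of Poisson(3)
print(poisson_from_uniform(0.5, log_rate=math.log(3.0)))  # same draw via log_rate
```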

class UnnormalisedDirichlet(*, concentration, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Represents an unnormalised Dirichlet distribution over K classes. Normalising the output so that it sums to one maps it onto the simplex of K class probabilities:

  X ~ UnnormalisedDirichlet(alpha)
  Y = X / sum(X)  ==>  Y ~ Dirichlet(alpha)

Parameters:
  • concentration – concentration parameters alpha, one per class

  • name (Optional[str]) – optional name
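The normalisation relation above matches the standard Gamma construction of the Dirichlet distribution. If the components X_k are independent Gamma(alpha_k, 1) draws, then X / sum(X) ~ Dirichlet(alpha). The page states only the normalisation relation, so the sketch assumes this Gamma form for the unnormalised components. The helper names are hypothetical.

```python
import random
from typing import List, Sequence


def unnormalised_dirichlet(alpha: Sequence[float], rng: random.Random) -> List[float]:
    # Independent X_k ~ Gamma(shape=alpha_k, scale=1); these do not sum to one.
    return [rng.gammavariate(a, 1.0) for a in alpha]


def normalise(x: Sequence[float]) -> List[float]:
    # Y = X / sum(X) lies on the simplex and is Dirichlet(alpha)-distributed.
    total = sum(x)
    return [xi / total for xi in x]


rng = random.Random(0)
X = unnormalised_dirichlet([1.0, 2.0, 3.0], rng)
Y = normalise(X)
print(X)
print(Y, sum(Y))  # sum(Y) equals 1 up to rounding
```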

PriorModelGen
PriorModelType