framework

jaxns.framework

Submodules

Package Contents

PriorModelGen
PriorModelType
jaxify_likelihood(log_likelihood, vectorised=False)

Wraps a non-JAX log-likelihood function so that it can be used with JAX.

Parameters:
  • log_likelihood (Callable[Ellipsis, numpy.ndarray]) – a non-JAX log-likelihood function that accepts a number of arguments and returns a scalar log-likelihood.

  • vectorised (bool) – if True, log_likelihood must accept batched inputs, in which every argument carries the same set of batch dimensions.

Returns:

A JAX-compatible log-likelihood function.

Return type:

jaxns.internals.types.LikelihoodType
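The `vectorised` flag determines how often the wrapped function has to be called. A minimal NumPy sketch of the two calling conventions, independent of jaxns: a scalar likelihood has to be looped over the batch, while a vectorised one receives the whole batch in a single call. The function names are hypothetical, and the loop only emulates what a wrapper would need to do; it is not the implementation of jaxify_likelihood.

```python
import numpy as np


def log_likelihood_scalar(x, y):
    # Non-vectorised: x and y are scalars; returns a scalar Gaussian log-likelihood (up to a constant).
    return -0.5 * (x - 1.0) ** 2 - 0.5 * (y + 2.0) ** 2


def log_likelihood_batched(x, y):
    # Vectorised: x and y share the same batch shape; the arithmetic broadcasts elementwise.
    x = np.asarray(x)
    y = np.asarray(y)
    return -0.5 * (x - 1.0) ** 2 - 0.5 * (y + 2.0) ** 2


def apply_over_batch(fn, *args):
    # Emulate vectorisation for a scalar function by looping over the shared batch dimensions.
    args = [np.asarray(a, dtype=float) for a in args]
    batch_shape = args[0].shape
    out = np.empty(batch_shape)
    for idx in np.ndindex(batch_shape):
        out[idx] = fn(*(a[idx] for a in args))
    return out


xs = np.linspace(-1.0, 3.0, 6).reshape(2, 3)
ys = np.zeros((2, 3))
looped = apply_over_batch(log_likelihood_scalar, xs, ys)
direct = log_likelihood_batched(xs, ys)
print(np.allclose(looped, direct))  # True
```

The looped path costs one Python call per batch element, which is the overhead a genuinely vectorised likelihood avoids.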

class Model(prior_model, log_likelihood, params=None)

Bases: jaxns.framework.bases.BaseAbstractModel

Represents a Bayesian model in terms of a generative prior and a likelihood function.

Parameters:
  • prior_model (jaxns.framework.bases.PriorModelType)

  • log_likelihood (jaxns.internals.types.LikelihoodType)

  • params (Optional[jaxns.framework.context.MutableParams])
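The prior_model argument is a generative prior: a Python generator function that yields priors and receives their sampled values back (the PriorModelGen and PriorModelType aliases at the top of this page presumably describe such a generator and the function that creates it). To show the mechanics without depending on jaxns, the sketch below drives such a generator by hand, mapping each coordinate of a point U through a prior's inverse CDF. `UniformPrior`, `NormalPrior` and `run_prior_model` are hypothetical stand-ins, not jaxns classes, and treating U as a point in the unit hypercube is an assumption based on the usual nested-sampling convention.

```python
from statistics import NormalDist


class UniformPrior:
    # Hypothetical stand-in for a prior: maps u in [0, 1] to Uniform(low, high).
    def __init__(self, low, high):
        self.low, self.high = low, high

    def quantile(self, u):
        return self.low + (self.high - self.low) * u


class NormalPrior:
    # Hypothetical stand-in: maps u in (0, 1) to Normal(mu, sigma) through the inverse CDF.
    def __init__(self, mu, sigma):
        self.dist = NormalDist(mu, sigma)

    def quantile(self, u):
        return self.dist.inv_cdf(u)


def prior_model():
    # Generative prior: each yield hands a prior to the driver and receives its value back.
    x = yield UniformPrior(0.0, 10.0)
    y = yield NormalPrior(mu=x, sigma=1.0)  # later priors may depend on earlier values
    return x, y


def run_prior_model(model_fn, U):
    # Drive the generator, consuming one coordinate of U per yielded prior.
    gen = model_fn()
    try:
        prior = next(gen)
        for u in U:
            prior = gen.send(prior.quantile(u))
    except StopIteration as stop:
        return stop.value  # whatever the prior model returns
    raise ValueError("U has fewer coordinates than the prior model has priors")


print(run_prior_model(prior_model, [0.3, 0.5]))  # (3.0, 3.0)
```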

__repr__()
property num_params: int
Return type:

int

property params
set_params(params)

Create a new parametrised model with the given parameters.

Parameters:

params (jaxns.framework.context.MutableParams) – The parameters to use.

Returns:

A model with set parameters.

Return type:

Model

__call__(params)

Create a new parametrised model with the given parameters.

This is (and must be) a pure function.

Parameters:

params (jaxns.framework.context.MutableParams) – The parameters to use.

Returns:

A model with set parameters.

Return type:

Model

init_params(rng)

Initialise the parameters of the model.

Parameters:

rng (jaxns.internals.types.PRNGKey) – PRNGKey used to initialise the parameters.

Returns:

The initialised parameters.

Return type:

jaxns.framework.context.MutableParams

__hash__()
sample_U(key)

Draw a sample of U from the prior model.

Parameters:

key (jaxns.internals.types.PRNGKey) – PRNGKey to use.

Returns:

The sampled U.

Return type:

jaxns.internals.types.UType

sample_W(key)

Draw a sample of W from the prior model.

Parameters:

key (jaxns.internals.types.PRNGKey) – PRNGKey to use.

Returns:

The sampled W.

Return type:

jaxns.internals.types.WType

transform(U)
Parameters:

U (jaxns.internals.types.UType)

Return type:

jaxns.internals.types.XType

transform_parametrised(U)
Parameters:

U (jaxns.internals.types.UType)

Return type:

jaxns.internals.types.XType

forward(U, allow_nan=False)
Parameters:
  • U (jaxns.internals.types.UType)

  • allow_nan (bool)

Return type:

jaxns.internals.types.FloatArray

log_prob_prior(U)
Parameters:

U (jaxns.internals.types.UType)

Return type:

jaxns.internals.types.FloatArray

prepare_input(U)
Parameters:

U (jaxns.internals.types.UType)

Return type:

jaxns.internals.types.LikelihoodInputType

sanity_check(key, S)
Parameters:
  • key (jaxns.internals.types.PRNGKey)

  • S (int)

class Prior(dist_or_value, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Represents a generative prior.

Parameters:
  • dist_or_value (Union[tfpd, jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray, jaxns.internals.types.BoolArray])

  • name (Optional[str])

name = None
property dist: jaxns.framework.bases.BaseAbstractDistribution
Return type:

jaxns.framework.bases.BaseAbstractDistribution

property value: jax.Array
Return type:

jax.Array

parametrised(random_init=False)

Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an associated log_prob. The parameter is registered with get_parameter under the prior's name plus a _param suffix. The prior must have a name.

Parameters:

random_init (bool) – if True, initialise the parameter randomly; otherwise initialise it at the median of the distribution.

Returns:

A singular prior.

Raises:

ValueError – if the prior has no name.

Return type:

SingularPrior

exception InvalidPriorName(name=None)

Bases: Exception

Raised when a prior name is already taken.

Parameters:

name (Optional[str])

class Bernoulli(*, logits=None, probs=None, name=None)

Bases: SpecialPrior

A Bernoulli prior, specified by either logits or probs.

Parameters:

name (Optional[str])

dist
class Beta(*, concentration0=None, concentration1=None, name=None)

Bases: SpecialPrior

A Beta prior, specified by the concentration parameters concentration1 and concentration0.

Parameters:

name (Optional[str])

class Categorical(parametrisation, *, logits=None, probs=None, name=None)

Bases: SpecialPrior

A Categorical special prior, specified by either logits or probs.

Parameters:
  • parametrisation (Literal['gumbel_max', 'cdf']) – 'cdf' suits discrete parameters whose neighbouring categories are correlated; otherwise 'gumbel_max' is better.

  • logits – log-probability of each category

  • probs – probability of each category

  • name (Optional[str]) – optional name
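The two parametrisations correspond to two standard ways of turning uniform variates into a category: inverse-CDF lookup on the cumulative probabilities, and the Gumbel-max trick. The NumPy sketch below shows the generic techniques, not jaxns's own code, and the function names are hypothetical. The inverse-CDF map is monotone in u, so nearby values of u land in the same or adjacent categories, which is presumably why 'cdf' suits ordered, correlated categories; it needs one uniform per draw, whereas Gumbel-max uses one per category.

```python
import numpy as np


def categorical_cdf(u, probs):
    # Inverse-CDF lookup: index of the first category whose cumulative probability exceeds u.
    cdf = np.cumsum(probs)
    cdf = cdf / cdf[-1]  # guard against probs that do not sum exactly to 1
    return int(np.searchsorted(cdf, u, side="right"))


def categorical_gumbel_max(u, logits):
    # Gumbel-max trick: one uniform per category (strictly inside (0, 1)) becomes Gumbel noise.
    gumbel = -np.log(-np.log(u))
    return int(np.argmax(np.asarray(logits) + gumbel))


probs = np.array([0.2, 0.3, 0.5])
print(categorical_cdf(0.1, probs), categorical_cdf(0.6, probs))  # 0 2

rng = np.random.default_rng(0)
n = 20000
cdf_draws = [categorical_cdf(u, probs) for u in rng.uniform(size=n)]
gm_draws = [categorical_gumbel_max(u, np.log(probs)) for u in rng.uniform(size=(n, 3))]
print(np.bincount(cdf_draws, minlength=3) / n)  # both frequency vectors are close to probs
print(np.bincount(gm_draws, minlength=3) / n)
```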

dist
class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)

Bases: SpecialPrior

Prior for a sequence of n random variables, each uniformly distributed on [low, high], constrained so that U[i,…] <= U[i+1,…]. When the variables are batched, the result is sorted elementwise along the first dimension.

Parameters:
  • n (int) – number of samples within [low, high]

  • low – minimum of distribution

  • high – maximum of distribution

  • fix_left (bool) – if True, the leftmost value is fixed to low

  • fix_right (bool) – if True, the rightmost value is fixed to high

  • name (Optional[str])

n
low
high
fix_left = False
fix_right = False
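One standard way to obtain n sorted uniform variables from a point in the unit cube, without rejection, is to build the uniform order statistics from the top down: the largest of n uniforms is u^(1/n), and, given the (k+1)-th smallest, the k-th smallest is the maximum of k uniforms below it. The NumPy sketch assumes that construction (jaxns may implement it differently), ignores fix_left and fix_right, and uses a hypothetical function name.

```python
import numpy as np


def forced_identifiability(u, low, high):
    # Map n unit-cube coordinates (first axis) to n sorted Uniform(low, high) variables.
    u = np.asarray(u, dtype=float)
    n = u.shape[0]
    x = np.empty_like(u)
    upper = 1.0  # current upper bound; broadcasts over any trailing batch dimensions
    for k in range(n, 0, -1):
        # Given the (k+1)-th order statistic, the k-th is the max of k uniforms on [0, upper].
        upper = upper * u[k - 1] ** (1.0 / k)
        x[k - 1] = upper
    return low + (high - low) * x


u = np.array([0.9, 0.1, 0.5, 0.7])
x = forced_identifiability(u, low=-1.0, high=1.0)
print(np.all(np.diff(x) >= 0))  # True

rng = np.random.default_rng(1)
samples = np.stack([forced_identifiability(rng.uniform(size=4), 0.0, 1.0) for _ in range(20000)])
print(samples.mean(axis=0))  # close to k / (n + 1) = [0.2, 0.4, 0.6, 0.8]
```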
class Poisson(*, rate=None, log_rate=None, name=None)

Bases: SpecialPrior

A Poisson prior, specified by either rate or log_rate.

Parameters:

name (Optional[str])

dist
class UnnormalisedDirichlet(*, concentration, name=None)

Bases: SpecialPrior

Represents an unnormalised Dirichlet distribution over K classes with concentration alpha. Normalising its output gives a Dirichlet-distributed point on the simplex:

X ~ UnnormalisedDirichlet(alpha),  Y = X / sum(X)  ==>  Y ~ Dirichlet(alpha)
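This relation matches the standard Gamma construction of the Dirichlet distribution: if X_k ~ Gamma(alpha_k, 1) independently for k = 1..K, then X / sum(X) ~ Dirichlet(alpha). A NumPy check of that relation, assuming (the docstring does not say so) that the unnormalised draw is built this way:

```python
import numpy as np

rng = np.random.default_rng(2)
alpha = np.array([1.0, 2.0, 5.0])

# Unnormalised draws: independent Gamma(alpha_k, 1) variables, one column per class.
X = rng.gamma(shape=alpha, scale=1.0, size=(50000, alpha.size))

# Normalising each row onto the simplex gives Y ~ Dirichlet(alpha).
Y = X / X.sum(axis=-1, keepdims=True)

print(Y.sum(axis=-1)[:3])  # each row sums to 1
print(Y.mean(axis=0))  # close to alpha / alpha.sum() = [0.125, 0.25, 0.625]
```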

Parameters:

name (Optional[str])

class Empirical(*, samples, support_min=None, support_max=None, resolution=100, name=None)

Bases: SpecialPrior

Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimension.

Parameters:
  • samples (jax.Array)

  • support_min (Optional[jaxns.internals.types.FloatArray])

  • support_max (Optional[jaxns.internals.types.FloatArray])

  • resolution (int)

  • name (Optional[str])
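Using a set of samples as a prior generally means building an inverse CDF from them, so that uniform variates can be mapped onto the empirical distribution. A minimal 1-D NumPy sketch, assuming (this page does not confirm it) that resolution controls how many quantile levels are tabulated; it ignores support_min, support_max and batch dimensions, and `empirical_quantile_fn` is a hypothetical name.

```python
import numpy as np


def empirical_quantile_fn(samples, resolution=100):
    # Tabulate the empirical quantile function at `resolution` evenly spaced levels in [0, 1].
    samples = np.asarray(samples, dtype=float)
    levels = np.linspace(0.0, 1.0, resolution)
    table = np.quantile(samples, levels)

    def quantile(u):
        # Piecewise-linear inverse CDF: maps u in [0, 1] into [min(samples), max(samples)].
        return np.interp(u, levels, table)

    return quantile


rng = np.random.default_rng(3)
samples = rng.normal(loc=2.0, scale=0.5, size=5000)
q = empirical_quantile_fn(samples, resolution=101)

# Feeding uniform draws through q produces draws that follow the empirical distribution.
draws = q(rng.uniform(size=20000))
print(q(0.0) == samples.min(), q(1.0) == samples.max())  # True True
print(draws.mean(), draws.std())  # close to 2.0 and 0.5
```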

class TruncationWrapper(prior, low, high, name=None)

Bases: SpecialPrior

Wraps another prior to truncate it to the interval [low, high].

For a truncated distribution, the quantile function becomes:

Q_truncated(p) = Q_untruncated(p * (F_untruncated(high) - F_untruncated(low)) + F_untruncated(low))

and the CDF becomes:

F_truncated(x) = (F_untruncated(x) - F_untruncated(low)) / (F_untruncated(high) - F_untruncated(low))
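The two truncation formulas can be checked numerically with a standard normal as the untruncated distribution; note that the untruncated CDF F appears on the right-hand side of both. The sketch uses only the standard library's statistics.NormalDist. The variable names cdf_low and cdf_diff mirror the TruncationWrapper attributes of the same name, but that jaxns computes them exactly this way is an assumption.

```python
from statistics import NormalDist

base = NormalDist(mu=0.0, sigma=1.0)  # the untruncated distribution
low, high = -1.0, 2.0

cdf_low = base.cdf(low)  # F_untruncated(low)
cdf_diff = base.cdf(high) - cdf_low  # F_untruncated(high) - F_untruncated(low)


def truncated_quantile(p):
    # Q_truncated(p) = Q_untruncated(p * cdf_diff + cdf_low); p must lie strictly in (0, 1).
    return base.inv_cdf(p * cdf_diff + cdf_low)


def truncated_cdf(x):
    # F_truncated(x) = (F_untruncated(x) - F_untruncated(low)) / cdf_diff, for low <= x <= high.
    return (base.cdf(x) - cdf_low) / cdf_diff


for p in (0.01, 0.5, 0.99):
    x = truncated_quantile(p)
    # x stays inside [low, high], and the truncated CDF maps it back to p.
    print(p, round(x, 4), round(truncated_cdf(x), 6))
```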

Parameters:
prior
low
high
cdf_low
cdf_diff
class ExplicitDensityPrior(*, axes, density, regular_grid=False, name=None)

Bases: SpecialPrior

A prior specified by an explicit density defined over the given axes.

Parameters: