special_priors

jaxns.framework.special_priors

Module Contents

class Bernoulli(*, logits=None, probs=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Prior for a binary random variable following a Bernoulli distribution, parametrised by either logits or probs.

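Parameters:
  • logits – log-odds of the positive outcome (supply exactly one of logits or probs)

  • probs – probability of the positive outcome

  • name (Optional[str]) – optional name of the random variable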
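Nested sampling explores a unit hypercube, so a prior like this one needs a map from a uniform coordinate u in [0, 1) to the variable. The sketch below uses a hypothetical helper, `bernoulli_from_uniform`, that is not part of jaxns: it inverts the Bernoulli CDF and assumes logits are log-odds, as in TensorFlow Probability. The library's own transform may differ.

```python
import math


def bernoulli_from_uniform(u, *, logits=None, probs=None):
    """Hypothetical helper: map u in [0, 1) to a Bernoulli outcome by inverting its CDF."""
    if (logits is None) == (probs is None):
        raise ValueError("specify exactly one of logits or probs")
    if probs is None:
        probs = 1.0 / (1.0 + math.exp(-logits))  # sigmoid: log-odds -> probability
    # P(X = 0) = 1 - p, so u below 1 - p maps to 0 and the rest of [0, 1) maps to 1.
    return 0 if u < 1.0 - probs else 1


print(bernoulli_from_uniform(0.2, probs=0.7))   # 0, since 0.2 < 1 - 0.7
print(bernoulli_from_uniform(0.9, logits=0.0))  # 1, since sigmoid(0) = 0.5
```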

class Beta(*, concentration0=None, concentration1=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

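Prior for a Beta-distributed random variable on the unit interval, with density proportional to x^(concentration1 - 1) (1 - x)^(concentration0 - 1).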

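Parameters:
  • concentration0 – shape parameter β, paired with the (1 - x) factor of the density

  • concentration1 – shape parameter α, paired with the x factor of the density

  • name (Optional[str]) – optional name of the random variable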
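Which concentration goes with which factor is easy to mix up. The sketch below assumes the TensorFlow Probability convention (concentration1 is α, concentration0 is β) and checks it with Python's `random.betavariate`, whose mean is α / (α + β). The helper `sample_beta` is hypothetical and not a jaxns function.

```python
import random


def sample_beta(concentration1, concentration0, n, seed=0):
    """Hypothetical helper: n Beta draws, assuming concentration1 = alpha (paired with x)
    and concentration0 = beta (paired with 1 - x), as in TensorFlow Probability."""
    rng = random.Random(seed)
    # random.betavariate(alpha, beta) has density proportional to x**(alpha - 1) * (1 - x)**(beta - 1).
    return [rng.betavariate(concentration1, concentration0) for _ in range(n)]


samples = sample_beta(concentration1=8.0, concentration0=2.0, n=20000)
mean = sum(samples) / len(samples)
print(mean)  # close to concentration1 / (concentration1 + concentration0) = 0.8
```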

class Categorical(parametrisation, *, logits=None, probs=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Prior for a categorical random variable whose value is the index of one of K categories, where K is the length of logits or probs.

Parameters:
  • parametrisation (Literal['gumbel_max', 'cdf']) – 'cdf' suits discrete parameters whose neighbouring categories are correlated; otherwise 'gumbel_max' is the better choice.

  • logits – log-probability of each category

  • probs – probability of each category

  • name (Optional[str]) – optional name of the random variable
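The trade-off between the two parametrisations can be illustrated with the standard constructions they are named after. This is a sketch, not jaxns's implementation: the hypothetical `categorical_cdf` inverts the CDF with a single uniform, so nearby u values land on the same or adjacent categories, while the hypothetical `categorical_gumbel_max` uses one uniform per category and takes argmax(logits + Gumbel noise), which carries no notion of neighbouring categories.

```python
import math
import random


def categorical_cdf(u, probs):
    """'cdf'-style map (hypothetical): one uniform u in [0, 1) -> category index.

    Inverts the categorical CDF, so nearby u values give the same or
    neighbouring categories.
    """
    cumulative = 0.0
    for k, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return k
    return len(probs) - 1  # guard against round-off in the cumulative sum


def categorical_gumbel_max(us, logits):
    """'gumbel_max'-style map (hypothetical): one uniform per category in (0, 1).

    Adds Gumbel(0, 1) noise, obtained by inverting the Gumbel CDF, to each
    logit and returns the index of the largest score.
    """
    gumbels = [-math.log(-math.log(u)) for u in us]
    scores = [logit + g for logit, g in zip(logits, gumbels)]
    return max(range(len(scores)), key=scores.__getitem__)


probs = [0.25, 0.5, 0.25]
logits = [math.log(p) for p in probs]

# 'cdf': a fine grid of u values maps onto contiguous runs of categories.
grid = [i / 1000 for i in range(1000)]
cdf_counts = [sum(categorical_cdf(u, probs) == k for u in grid) for k in range(3)]
print(cdf_counts)  # [250, 500, 250]

# 'gumbel_max': category frequencies approach probs as the number of draws grows.
rng = random.Random(0)
draws = [categorical_gumbel_max([rng.random() for _ in probs], logits) for _ in range(20000)]
gumbel_freqs = [draws.count(k) / len(draws) for k in range(3)]
print(gumbel_freqs)
```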

class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Prior for a sequence of n random variables on [low, high] that are uniformly distributed subject to the ordering constraint U[i,…] <= U[i+1,…]. Under broadcasting, the resulting random variable is sorted elementwise along the first dimension.

Parameters:
  • n (int) – number of ordered values in [low, high]

  • low – lower bound of the interval

  • high – upper bound of the interval

  • fix_left (bool) – if True, the leftmost value is fixed to low

  • fix_right (bool) – if True, the rightmost value is fixed to high

  • name (Optional[str]) – optional name of the random variable
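One way to obtain ordered uniform values directly from a unit-cube point, without sorting, is the classical recursion for uniform order statistics. The sketch below assumes this construction for illustration; jaxns may build the prior differently, and the fix_left/fix_right options are not modelled. `forced_identifiability` is a hypothetical helper.

```python
import random


def forced_identifiability(us, low, high):
    """Hypothetical map from n uniforms in [0, 1) to n sorted values in [low, high].

    Uses the order-statistic recursion (1-based k):
        x_n = v_n ** (1 / n),   x_k = x_{k+1} * v_k ** (1 / k)
    x_n is distributed as the maximum of n uniforms, and each x_k as the maximum
    of the k uniforms left below x_{k+1}, so no explicit sort is needed.
    """
    n = len(us)
    xs = [0.0] * n
    upper = 1.0
    for k in range(n, 0, -1):  # k = n, n - 1, ..., 1
        upper *= us[k - 1] ** (1.0 / k)
        xs[k - 1] = upper
    return [low + (high - low) * x for x in xs]


rng = random.Random(0)
samples = [forced_identifiability([rng.random() for _ in range(4)], -1.0, 1.0) for _ in range(20000)]

# The smallest of 4 sorted uniforms on [-1, 1] has mean -1 + 2 * 1/5 = -0.6.
mean_first = sum(s[0] for s in samples) / len(samples)
print(mean_first)
```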

class Poisson(*, rate=None, log_rate=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Prior for a Poisson-distributed count (a non-negative integer), parametrised by either rate or log_rate.

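Parameters:
  • rate – expected count λ (supply exactly one of rate or log_rate)

  • log_rate – natural logarithm of the rate

  • name (Optional[str]) – optional name of the random variable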
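A Poisson count can be drawn from a single uniform by walking its CDF upwards from zero, using the recurrence P(X = k) = P(X = k - 1) · λ / k. The helper below, `poisson_from_uniform`, is hypothetical and not part of jaxns; it mirrors the rate/log_rate choice only for illustration and is intended for moderate rates, since exp(-λ) underflows to zero once λ exceeds roughly 745.

```python
import math


def poisson_from_uniform(u, *, rate=None, log_rate=None, max_count=10_000):
    """Hypothetical map from u in [0, 1) to a Poisson count via the inverse CDF."""
    if (rate is None) == (log_rate is None):
        raise ValueError("specify exactly one of rate or log_rate")
    if rate is None:
        rate = math.exp(log_rate)
    pmf = math.exp(-rate)  # P(X = 0); underflows to 0.0 for rate above about 745
    cdf = pmf
    k = 0
    while u >= cdf and k < max_count:
        k += 1
        pmf *= rate / k  # P(X = k) = P(X = k - 1) * rate / k
        cdf += pmf
    return k


print(poisson_from_uniform(0.5, rate=2.0))  # 2, since P(X <= 1) ≈ 0.406 and P(X <= 2) ≈ 0.677
print(poisson_from_uniform(0.5, log_rate=math.log(2.0)))  # same count via log_rate
```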

class UnnormalisedDirichlet(*, concentration, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Represents an unnormalised Dirichlet distribution over K classes: the output is a length-K vector that, once divided by its sum, lies on the probability simplex and is Dirichlet distributed.

X ~ UnnormalisedDirichlet(alpha)
Y = X / sum(X)  ==>  Y ~ Dirichlet(alpha)

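Parameters:
  • concentration – concentration parameters alpha, one per class (K values)

  • name (Optional[str]) – optional name of the random variable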
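The normalisation identity above holds for the standard Gamma construction: if X_k ~ Gamma(alpha_k, 1) independently, then X / sum(X) ~ Dirichlet(alpha). The sketch below assumes that construction (whether jaxns uses it internally is not stated here) and checks that the normalised draws have mean alpha_k / sum(alpha). `unnormalised_dirichlet` is a hypothetical helper, not a jaxns function.

```python
import random


def unnormalised_dirichlet(concentration, rng):
    """Hypothetical helper: independent X_k ~ Gamma(alpha_k, scale=1), left unnormalised."""
    return [rng.gammavariate(alpha, 1.0) for alpha in concentration]


rng = random.Random(0)
alpha = [1.0, 2.0, 7.0]
draws = []
for _ in range(20000):
    x = unnormalised_dirichlet(alpha, rng)
    total = sum(x)
    draws.append([xk / total for xk in x])  # Y = X / sum(X) ~ Dirichlet(alpha)

means = [sum(d[k] for d in draws) / len(draws) for k in range(len(alpha))]
print(means)  # each close to alpha_k / sum(alpha), i.e. near [0.1, 0.2, 0.7]
```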