special_priors

jaxns.framework.special_priors

Module Contents

class Bernoulli(*, logits=None, probs=None, name=None)

Bases: SpecialPrior

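A Bernoulli prior on a binary outcome (0 or 1), specified by either logits or probs.

Parameters:
  • logits – log-odds of the outcome 1

  • probs – probability of the outcome 1

  • name (Optional[str]) – optional name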
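As an illustration only, the pure-Python sketch below shows how a single uniform variate u (the kind of unit-cube coordinate used in nested sampling) could be mapped to a Bernoulli outcome by inverse-CDF thresholding. The helper names `bernoulli_probs` and `bernoulli_quantile` are hypothetical, the "exactly one of logits or probs" check is this sketch's own rule, and none of this is jaxns's actual implementation.

```python
import math


def bernoulli_probs(logits=None, probs=None):
    """Return P(X = 1) from either logits (log-odds) or probs; hypothetical helper."""
    # This sketch insists on exactly one parametrisation being supplied.
    if (logits is None) == (probs is None):
        raise ValueError("pass exactly one of logits or probs")
    if probs is not None:
        return probs
    # Numerically stable logistic sigmoid: maps log-odds to a probability.
    if logits >= 0:
        return 1.0 / (1.0 + math.exp(-logits))
    z = math.exp(logits)
    return z / (1.0 + z)


def bernoulli_quantile(u, logits=None, probs=None):
    """Inverse-CDF map from u in [0, 1) to a Bernoulli outcome (0 or 1)."""
    p = bernoulli_probs(logits=logits, probs=probs)
    # P(X = 0) = 1 - p, so u below 1 - p maps to 0 and the rest maps to 1.
    return 0 if u < 1.0 - p else 1
```

For example, `bernoulli_quantile(0.8, probs=0.3)` returns 1, since 0.8 lies above 1 - 0.3 = 0.7.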

dist

class Beta(*, concentration0=None, concentration1=None, name=None)

Bases: SpecialPrior

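A Beta prior on [0, 1]. The parameter names follow TensorFlow Probability's Beta, whose density is proportional to x^(concentration1 - 1) (1 - x)^(concentration0 - 1).

Parameters:
  • concentration0 – exponent parameter on (1 - x), often written beta

  • concentration1 – exponent parameter on x, often written alpha

  • name (Optional[str]) – optional name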
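The sketch below makes the parameter ordering concrete using the standard library's `random.betavariate(alpha, beta)`, assuming the TensorFlow Probability convention in which concentration1 plays the role of alpha and concentration0 the role of beta. Under that assumption the sample mean approaches concentration1 / (concentration1 + concentration0). The function name `beta_sample_mean` is hypothetical.

```python
import random


def beta_sample_mean(concentration0, concentration1, n=20000, seed=0):
    """Monte Carlo mean of Beta draws under the assumed TFP naming convention.

    concentration1 is treated as alpha (exponent on x) and concentration0 as
    beta (exponent on 1 - x), so the exact mean is c1 / (c1 + c0).
    """
    rng = random.Random(seed)
    draws = [rng.betavariate(concentration1, concentration0) for _ in range(n)]
    return sum(draws) / n


# With concentration1 = 2 and concentration0 = 6 the exact mean is 2 / 8 = 0.25.
estimate = beta_sample_mean(concentration0=6.0, concentration1=2.0)
```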

class Categorical(parametrisation, *, logits=None, probs=None, name=None)

Bases: SpecialPrior

Categorical special prior, specified by either logits or probs.

Parameters:
  • parametrisation (Literal['gumbel_max', 'cdf']) – ‘cdf’ suits discrete parameters whose neighbouring categories are correlated; otherwise ‘gumbel_max’ is usually the better choice.

  • logits – unnormalised log-probability of each category

  • probs – probability of each category

  • name (Optional[str]) – optional name
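To illustrate the difference between the two parametrisations, the pure-Python sketch below implements a generic "cdf" transform (one uniform, inverted through the cumulative probabilities) and a generic Gumbel-max transform (one uniform per category, argmax of logits plus Gumbel noise). Both are standard techniques; the function names are hypothetical, and this is not necessarily how jaxns implements either option.

```python
import bisect
import itertools
import math
import random

_EPS = 1e-12  # keeps uniforms strictly inside (0, 1) so the logarithms stay finite


def categorical_cdf(u, probs):
    """'cdf'-style transform: a single uniform u selects the first category whose CDF exceeds u."""
    cdf = list(itertools.accumulate(probs))
    # bisect_right counts the CDF entries <= u, which is the index of the first entry > u;
    # min() guards against round-off leaving the last CDF entry slightly below 1.
    return min(bisect.bisect_right(cdf, u), len(probs) - 1)


def categorical_gumbel_max(us, logits):
    """Gumbel-max transform: one uniform per category, argmax of logit plus Gumbel noise."""
    scores = []
    for u, logit in zip(us, logits):
        u = min(max(u, _EPS), 1.0 - _EPS)
        # -log(-log(u)) is a standard Gumbel(0, 1) sample when u ~ Uniform(0, 1).
        scores.append(logit - math.log(-math.log(u)))
    return max(range(len(scores)), key=scores.__getitem__)


# Usage: both transforms target the same categorical distribution.
rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
logits = [math.log(p) for p in probs]
k_cdf = categorical_cdf(rng.random(), probs)
k_gumbel = categorical_gumbel_max([rng.random() for _ in probs], logits)
```

Under the "cdf" transform, nearby values of u fall in the same or adjacent categories, which is why it suits categories whose neighbours are correlated; the Gumbel-max version imposes no ordering between categories.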

dist

class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)

Bases: SpecialPrior

Prior for a sequence of n random variables X, each uniformly distributed on [low, high], constrained so that X[i,…] <= X[i+1,…]. Under broadcasting, the result is sorted elementwise along the first dimension.

Parameters:
  • n (int) – number of ordered values within [low, high]

  • low – lower bound of the support

  • high – upper bound of the support

  • fix_left (bool) – if True, the leftmost value is fixed to low

  • fix_right (bool) – if True, the rightmost value is fixed to high

  • name (Optional[str])
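One standard way to turn n independent Uniform(0, 1) variates into n sorted values, shown here as a pure-Python sketch, is the sequential order-statistics construction: the maximum of k independent uniforms is distributed as u ** (1 / k), and the remaining values are uniform below it. This page does not say whether jaxns uses exactly this transform; the function name `sorted_uniforms` is hypothetical, and fix_left and fix_right are ignored.

```python
import random


def sorted_uniforms(us, low, high):
    """Map n independent Uniform(0, 1) variates to n sorted values on [low, high].

    Order-statistics recursion: the maximum of k iid U(0, 1) variates is
    distributed as u ** (1 / k), and given that maximum m the other k - 1
    values are iid U(0, m). fix_left / fix_right are not modelled here.
    """
    n = len(us)
    xs = [0.0] * n
    upper = 1.0
    # Fill from the largest value down; xs[k] ends up as the (k + 1)-th smallest.
    for k in range(n - 1, -1, -1):
        upper *= us[k] ** (1.0 / (k + 1))
        xs[k] = upper
    # The affine map from [0, 1] to [low, high] preserves the ordering.
    return [low + (high - low) * x for x in xs]


rng = random.Random(0)
draw = sorted_uniforms([rng.random() for _ in range(4)], low=0.0, high=10.0)
```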

n
low
high
fix_left = False
fix_right = False

class Poisson(*, rate=None, log_rate=None, name=None)

Bases: SpecialPrior

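A Poisson prior over the non-negative integers, specified by either rate or log_rate.

Parameters:
  • rate – mean (rate) of the Poisson distribution

  • log_rate – natural logarithm of the rate

  • name (Optional[str]) – optional name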
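For intuition, a Poisson prior can be driven by a single uniform through its inverse CDF. The sketch below accumulates the probability mass function with the recurrence P(k) = P(k - 1) * rate / k; `poisson_quantile` is a hypothetical helper, not jaxns's code, and it only suits moderate rates because exp(-rate) underflows for very large rates.

```python
import math


def poisson_quantile(u, rate=None, log_rate=None, max_k=10_000):
    """Inverse-CDF map from u in [0, 1) to a Poisson count; hypothetical helper."""
    if (rate is None) == (log_rate is None):
        raise ValueError("pass exactly one of rate or log_rate")
    if rate is None:
        rate = math.exp(log_rate)
    pmf = math.exp(-rate)  # P(X = 0); underflows to 0 for very large rates
    cdf = pmf
    k = 0
    # Walk up the support using P(k) = P(k - 1) * rate / k until the CDF exceeds u.
    while cdf <= u and k < max_k:
        k += 1
        pmf *= rate / k
        cdf += pmf
    return k
```

With rate = 1, P(X <= 0) is about 0.368 and P(X <= 1) about 0.736, so u = 0.5 maps to the count 1.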

dist

class UnnormalisedDirichlet(*, concentration, name=None)

Bases: SpecialPrior

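Represents an unnormalised Dirichlet distribution over K classes. That is, the output is related to the probability simplex over K classes (the (K-1)-simplex) via normalisation:

X ~ UnnormalisedDirichlet(alpha)
Y = X / sum(X)  ==>  Y ~ Dirichlet(alpha)

Parameters:
  • concentration – concentration parameter alpha, one entry per class

  • name (Optional[str]) – optional name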
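The normalisation relation can be checked with a standard construction: if X_k ~ Gamma(alpha_k, 1) independently, then X / sum(X) ~ Dirichlet(alpha). The sketch below assumes the unnormalised components take this Gamma form (the text above only states the normalisation relation), and the function names are hypothetical.

```python
import random


def unnormalised_dirichlet(alpha, rng):
    """Draw X_k ~ Gamma(alpha_k, 1) independently (the assumed unnormalised form)."""
    return [rng.gammavariate(a, 1.0) for a in alpha]


def normalise(xs):
    """Map onto the probability simplex: Y = X / sum(X)."""
    total = sum(xs)
    return [x / total for x in xs]


rng = random.Random(0)
alpha = [1.0, 2.0, 5.0]
y = normalise(unnormalised_dirichlet(alpha, rng))
```

The Dirichlet mean of Y_k is alpha_k / sum(alpha), here [0.125, 0.25, 0.625], which averaging many normalised draws should approach.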

class Empirical(*, samples, support_min=None, support_max=None, resolution=100, name=None)

Bases: SpecialPrior

Represents the empirical distribution of a set of 1D samples, with an arbitrary batch dimension.

Parameters:
  • samples (jax.Array)

  • support_min (Optional[jaxns.internals.types.FloatArray])

  • support_max (Optional[jaxns.internals.types.FloatArray])

  • resolution (int)

  • name (Optional[str])
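An empirical prior can be driven through the unit cube by an inverse CDF built from the samples. The sketch below shows one common construction, linear interpolation between sorted samples; it is an illustration only, not jaxns's method, and the hypothetical `empirical_quantile` ignores support_min, support_max, resolution and batch dimensions.

```python
def empirical_quantile(samples, u):
    """Piecewise-linear empirical quantile function of 1D samples, for u in [0, 1].

    Interpolates linearly between order statistics; support bounds, resolution
    and batch dimensions are not modelled in this sketch.
    """
    xs = sorted(samples)
    if len(xs) == 1:
        return xs[0]
    pos = u * (len(xs) - 1)  # fractional index into the sorted samples
    i = min(int(pos), len(xs) - 2)  # keep i + 1 in range when u == 1
    frac = pos - i
    return xs[i] + frac * (xs[i + 1] - xs[i])
```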

class TruncationWrapper(prior, low, high, name=None)

Bases: SpecialPrior

Wraps another prior, truncating it to the interval [low, high].

For the truncated distribution, the quantile function becomes:

Q_truncated(p) = Q_untruncated(p * (F_untruncated(high) - F_untruncated(low)) + F_untruncated(low))

and the CDF becomes:

F_truncated(x) = (F_untruncated(x) - F_untruncated(low)) / (F_untruncated(high) - F_untruncated(low))
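The quantile and CDF relations can be checked numerically. The sketch below truncates a unit-rate exponential to [1, 2] and verifies that the two maps invert each other. The helper names are hypothetical; `cdf_low` and `cdf_diff` are named after the attributes listed for this class, on the assumption that they hold F_untruncated(low) and F_untruncated(high) - F_untruncated(low).

```python
import math


def exp_cdf(x):
    """CDF of the unit-rate exponential (the untruncated prior in this example)."""
    return 1.0 - math.exp(-x) if x > 0 else 0.0


def exp_quantile(p):
    """Quantile function of the unit-rate exponential."""
    return -math.log1p(-p)


def truncated_quantile(p, quantile, cdf, low, high):
    """Q_truncated(p) = Q_untruncated(p * (F(high) - F(low)) + F(low))."""
    cdf_low = cdf(low)
    cdf_diff = cdf(high) - cdf_low
    return quantile(p * cdf_diff + cdf_low)


def truncated_cdf(x, cdf, low, high):
    """F_truncated(x) = (F(x) - F(low)) / (F(high) - F(low)), clipped to [0, 1]."""
    cdf_low = cdf(low)
    cdf_diff = cdf(high) - cdf_low
    return min(max((cdf(x) - cdf_low) / cdf_diff, 0.0), 1.0)
```

At p = 0 the truncated quantile returns low and at p = 1 it returns high, so every uniform draw lands inside the truncation interval.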

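Parameters:
  • prior – the prior to truncate

  • low – lower truncation bound

  • high – upper truncation bound

  • name (Optional[str]) – optional name
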
prior
low
high
cdf_low
cdf_diff

class ExplicitDensityPrior(*, axes, density, regular_grid=False, name=None)

Bases: SpecialPrior

A prior defined by an explicitly supplied density over the given axes.

Parameters: