jaxns.framework.special_priors
Module Contents
- class Bernoulli(*, logits=None, probs=None, name=None)
Bases: SpecialPrior
Bernoulli-distributed special prior, parametrised by logits or probs.
- Parameters:
name (Optional[str])
- class Beta(*, concentration0=None, concentration1=None, name=None)
Bases: SpecialPrior
Beta-distributed special prior, parametrised by concentration0 and concentration1.
- Parameters:
name (Optional[str])
- class Categorical(parametrisation, *, logits=None, probs=None, name=None)
Bases: SpecialPrior
Initialises a Categorical special prior.
- Parameters:
parametrisation (Literal['gumbel_max', 'cdf']) – 'cdf' suits discrete parameters whose neighbouring categories are correlated; otherwise 'gumbel_max' is usually the better choice.
logits – log-probability of each category
probs – probability of each category
name (Optional[str]) – optional name
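The two parametrisations correspond to two standard ways of turning uniform draws into a category. The NumPy sketch below uses the hypothetical helpers `categorical_gumbel_max` and `categorical_cdf` to illustrate the textbook techniques; it is not taken from the jaxns source. Gumbel-max consumes one uniform per category, while CDF inversion consumes a single uniform and maps nearby values of it to the same or neighbouring categories, which is why it suits categories with an ordering.

```python
import numpy as np

def categorical_gumbel_max(u, logits):
    """Gumbel-max trick: u has shape (..., K) with entries in (0, 1)."""
    gumbel = -np.log(-np.log(u))  # standard Gumbel noise, one per category
    return np.argmax(logits + gumbel, axis=-1)

def categorical_cdf(u, probs):
    """Inverse-CDF sampling: u is a scalar or array of uniforms in [0, 1)."""
    cdf = np.cumsum(probs)
    idx = np.searchsorted(cdf, u, side="right")  # first k with cdf[k] > u
    return np.minimum(idx, len(probs) - 1)       # guard against round-off in cdf[-1]

rng = np.random.default_rng(0)
probs = np.array([0.2, 0.5, 0.3])
k_gumbel = categorical_gumbel_max(rng.random((100_000, 3)), np.log(probs))
k_cdf = categorical_cdf(rng.random(100_000), probs)
# Both empirical frequencies approach probs.
print(np.bincount(k_gumbel, minlength=3) / k_gumbel.size)
print(np.bincount(k_cdf, minlength=3) / k_cdf.size)
```

Because `categorical_cdf` is monotone in `u`, a small move in the unit cube changes the category by at most one step, whereas the Gumbel-max mapping has no such locality.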
- class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)
Bases: SpecialPrior
Prior for a sequence of n random variables, each uniformly distributed on U[low, high], constrained so that U[i,…] <= U[i+1,…]. For batched (broadcast) shapes, the resulting random variable is sorted elementwise along the first dimension.
- Parameters:
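One standard way to obtain n sorted uniforms from n independent unit-cube coordinates, without an explicit sort, is the order-statistics recursion: the largest of n uniforms on [0, 1] is distributed as V**(1/n), and each next-smaller value is the previous one times an independent V**(1/k). The NumPy sketch below, built around the hypothetical helper `sorted_uniform_from_unit_cube`, illustrates that construction. It is an assumption that jaxns uses this exact mapping, and the sketch ignores `fix_left` and `fix_right`.

```python
import numpy as np

def sorted_uniform_from_unit_cube(u, low, high):
    """Map u of shape (n, ...) with iid U(0, 1) entries to values on [low, high]
    that are sorted in non-decreasing order along the first axis."""
    u = np.asarray(u, dtype=float)
    n = u.shape[0]
    x = np.empty_like(u)
    current = np.ones_like(u[0])
    for k in range(n, 0, -1):  # fill from the largest (k = n) down to the smallest (k = 1)
        current = current * u[k - 1] ** (1.0 / k)
        x[k - 1] = current
    return low + (high - low) * x

rng = np.random.default_rng(1)
samples = sorted_uniform_from_unit_cube(rng.random((3, 200_000)), low=0.0, high=1.0)
# Means approach the uniform order-statistic means k / (n + 1) = [0.25, 0.5, 0.75].
print(samples.mean(axis=1))
```

Trailing dimensions are carried through elementwise, so a batch of sequences is sorted independently along the first axis, matching the broadcasting behaviour described above.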
- class Poisson(*, rate=None, log_rate=None, name=None)
Bases: SpecialPrior
Poisson-distributed special prior, parametrised by rate or log_rate.
- Parameters:
name (Optional[str])
- class UnnormalisedDirichlet(*, concentration, name=None)
Bases: SpecialPrior
Represents an unnormalised Dirichlet distribution over K classes: normalising a draw maps it onto the K-simplex.
X ~ UnnormalisedDirichlet(alpha)
Y = X / sum(X)
==> Y ~ Dirichlet(alpha)
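The normalisation relation holds, for instance, when the components of X are independent Gamma(alpha_k, 1) variables, the standard gamma construction of the Dirichlet. The NumPy sketch below assumes that construction (it does not claim jaxns samples X this way) and checks that normalised draws have the Dirichlet mean alpha / sum(alpha).

```python
import numpy as np

rng = np.random.default_rng(2)
alpha = np.array([1.0, 2.0, 3.0])

# Assumption: each unnormalised component X_k is an independent Gamma(alpha_k, 1) draw.
x = rng.gamma(shape=alpha, scale=1.0, size=(100_000, alpha.size))
y = x / x.sum(axis=-1, keepdims=True)  # normalise each draw onto the simplex

print(y.mean(axis=0))  # approaches alpha / alpha.sum() = [1/6, 1/3, 1/2]
```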
- Parameters:
name (Optional[str])
- class Empirical(*, samples, support_min=None, support_max=None, resolution=100, name=None)
Bases: SpecialPrior
Represents the empirical distribution of a set of 1D samples, with an arbitrary batch dimension.
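A distribution defined by samples can be drawn from by inverse transform: tabulate empirical quantiles on a grid of probability levels and interpolate between them. The sketch below is a hypothetical NumPy illustration of that idea. The helper name and the use of `resolution` as the number of quantile levels are assumptions; it does not describe how `Empirical` actually uses its `resolution`, `support_min`, or `support_max` arguments.

```python
import numpy as np

def empirical_quantile_fn(samples, resolution=100):
    """Return an approximate quantile function Q(p) for 1D samples."""
    levels = np.linspace(0.0, 1.0, resolution)  # probability grid
    knots = np.quantile(samples, levels)         # empirical quantiles at those levels

    def quantile(p):
        return np.interp(p, levels, knots)       # piecewise-linear interpolation

    return quantile

rng = np.random.default_rng(3)
data = rng.normal(loc=2.0, scale=0.5, size=10_000)
q = empirical_quantile_fn(data)
draws = q(rng.random(50_000))  # inverse-transform sampling from the empirical distribution
print(draws.mean(), draws.std())  # close to the mean and std of data
```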
- class TruncationWrapper(prior, low, high, name=None)
Bases: SpecialPrior
Wraps another prior so that it is truncated to [low, high].
For a truncated distribution, the quantile transforms to:
Q_truncated(p) = Q_untruncated(p * (F_untruncated(high) - F_untruncated(low)) + F_untruncated(low))
and the CDF transforms to:
F_truncated(x) = (F_untruncated(x) - F_untruncated(low)) / (F_untruncated(high) - F_untruncated(low))
- Parameters:
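The quantile and CDF relations for truncation can be checked numerically with any distribution that exposes a CDF and an inverse CDF. The sketch below uses the standard library's `statistics.NormalDist` as the untruncated distribution together with two hypothetical helper functions; it illustrates the formulas only and is not the jaxns implementation.

```python
from statistics import NormalDist

def truncated_quantile(p, dist, low, high):
    """Q_truncated(p) = Q_untruncated(p * (F(high) - F(low)) + F(low)), F untruncated."""
    f_low, f_high = dist.cdf(low), dist.cdf(high)
    return dist.inv_cdf(p * (f_high - f_low) + f_low)

def truncated_cdf(x, dist, low, high):
    """F_truncated(x) = (F(x) - F(low)) / (F(high) - F(low)) for low <= x <= high."""
    f_low, f_high = dist.cdf(low), dist.cdf(high)
    return (dist.cdf(x) - f_low) / (f_high - f_low)

base = NormalDist(mu=0.0, sigma=1.0)
for p in (0.1, 0.5, 0.9):
    x = truncated_quantile(p, base, low=0.0, high=2.0)
    # x lies in [0, 2] and the truncated CDF maps it back to p.
    print(p, round(x, 4), round(truncated_cdf(x, base, 0.0, 2.0), 12))
```

Rescaling p into [F(low), F(high)] before applying the untruncated quantile is what guarantees every output lands inside the truncation bounds.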