framework
===================

.. py:module:: jaxns.framework

.. rubric:: :code:`jaxns.framework`

.. rubric:: Submodules

.. toctree::
   :titlesonly:
   :maxdepth: 1

   abc/index.rst
   bases/index.rst
   context/index.rst
   jaxify/index.rst
   model/index.rst
   ops/index.rst
   prior/index.rst
   special_priors/index.rst
   wrapped_tfp_distribution/index.rst

.. rubric:: Package Contents

.. py:data:: PriorModelGen

.. py:data:: PriorModelType

.. py:function:: jaxify_likelihood(log_likelihood, vectorised = False)

   Wraps a non-JAX log-likelihood function.

   :param log_likelihood: a non-JAX log-likelihood function, which accepts a number of arguments and returns a
                          scalar log-likelihood.
   :param vectorised: if True, `log_likelihood` must handle batched inputs, i.e. every input carries a
                      common set of batch dimensions that the function must handle.

   :returns: A JAX-compatible log-likelihood function.


.. py:class:: Model(prior_model, log_likelihood, params = None)

   Bases: :py:obj:`jaxns.framework.bases.BaseAbstractModel`

   Represents a Bayesian model in terms of a generative prior and a likelihood function.

   .. py:method:: __repr__()

   .. py:property:: num_params
      :type: int

   .. py:property:: params

   .. py:method:: set_params(params)

      Create a new parametrised model with the given parameters.

      :param params: The parameters to use.

      :returns: A model with set parameters.

   .. py:method:: __call__(params)

      Create a new parametrised model with the given parameters.

      **This is (and must be) a pure function.**

      :param params: The parameters to use.

      :returns: A model with set parameters.

   .. py:method:: init_params(rng)

      Initialise the parameters of the model.

      :param rng: PRNGKey used to initialise the parameters.

      :returns: The initialised parameters.

   .. py:method:: __hash__()

   .. py:method:: sample_U(key)

      Sample from the prior model.

      :param key: PRNGKey to use.

      :returns: The sampled U.

   .. py:method:: sample_W(key)

      Sample from the prior model.

      :param key: PRNGKey to use.

      :returns: The sampled W.

   .. py:method:: transform(U)

   .. py:method:: transform_parametrised(U)

   .. py:method:: forward(U, allow_nan = False)

   .. py:method:: log_prob_prior(U)

   .. py:method:: prepare_input(U)

   .. py:method:: sanity_check(key, S)


.. py:class:: Prior(dist_or_value, name = None)

   Bases: :py:obj:`jaxns.framework.bases.BaseAbstractPrior`

   Represents a generative prior.

   .. py:attribute:: name
      :value: None

   .. py:property:: dist
      :type: jaxns.framework.bases.BaseAbstractDistribution

   .. py:property:: value
      :type: jax.Array

   .. py:method:: parametrised(random_init = False)

      Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an
      associated log_prob. The parameter is registered as a `get_parameter` with the suffix `_param` added to
      its name.

      The prior must have a name.

      :param random_init: whether to initialise the parameter randomly or at the median of the distribution.

      :returns: A singular prior.

      :raises ValueError: if the prior has no name.


.. py:exception:: InvalidPriorName(name = None)

   Bases: :py:obj:`Exception`

   Raised when a prior name is already taken.


.. py:class:: Bernoulli(*, logits=None, probs=None, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.

   .. py:attribute:: dist


.. py:class:: Beta(*, concentration0=None, concentration1=None, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.


.. py:class:: Categorical(parametrisation, *, logits=None, probs=None, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.

   Initialises the Categorical special prior.

   :param parametrisation: 'cdf' is good for discrete params with correlation between neighbouring categories;
                           otherwise 'gumbel' is better.
   :param logits: log-prob of each category
   :param probs: prob of each category
   :param name: optional name

   .. py:attribute:: dist


.. py:class:: ForcedIdentifiability(*, n, low=None, high=None, fix_left = False, fix_right = False, name = None)

   Bases: :py:obj:`SpecialPrior`

   Prior for a sequence of `n` random variables uniformly distributed on U[low, high] such that
   U[i,...] <= U[i+1,...]. For broadcasting, the resulting random variable is sorted elementwise along the
   first dimension.

   :param n: number of samples within [low, high]
   :param low: minimum of distribution
   :param high: maximum of distribution
   :param fix_left: if True, the leftmost value is fixed to `low`
   :param fix_right: if True, the rightmost value is fixed to `high`

   .. py:attribute:: n

   .. py:attribute:: low

   .. py:attribute:: high

   .. py:attribute:: fix_left
      :value: False

   .. py:attribute:: fix_right
      :value: False


.. py:class:: Poisson(*, rate=None, log_rate=None, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.

   .. py:attribute:: dist


.. py:class:: UnnormalisedDirichlet(*, concentration, name = None)

   Bases: :py:obj:`SpecialPrior`

   Represents an unnormalised Dirichlet distribution of K classes.
   That is, the output is related to the K-simplex via normalisation::

      X ~ UnnormalisedDirichlet(alpha)
      Y = X / sum(X)
      ==> Y ~ Dirichlet(alpha)


.. py:class:: Empirical(*, samples, support_min = None, support_max = None, resolution = 100, name = None)

   Bases: :py:obj:`SpecialPrior`

   Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimension.


.. py:class:: TruncationWrapper(prior, low, high, name = None)

   Bases: :py:obj:`SpecialPrior`

   Wraps another prior to make it truncated.

   For a truncated distribution, the quantile function becomes:

   .. math::

      Q_{\mathrm{truncated}}(p) = Q_{\mathrm{untruncated}}\bigl(p \, (F_{\mathrm{untruncated}}(\mathrm{high}) - F_{\mathrm{untruncated}}(\mathrm{low})) + F_{\mathrm{untruncated}}(\mathrm{low})\bigr)

   and the CDF becomes:

   .. math::

      F_{\mathrm{truncated}}(x) = \frac{F_{\mathrm{untruncated}}(x) - F_{\mathrm{untruncated}}(\mathrm{low})}{F_{\mathrm{untruncated}}(\mathrm{high}) - F_{\mathrm{untruncated}}(\mathrm{low})}

   .. py:attribute:: prior

   .. py:attribute:: low

   .. py:attribute:: high

   .. py:attribute:: cdf_low

   .. py:attribute:: cdf_diff


.. py:class:: ExplicitDensityPrior(*, axes, density, regular_grid = False, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.
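

.. rubric:: Example: truncated quantile and CDF (sketch)

The quantile and CDF transforms described for ``TruncationWrapper`` can be illustrated without jaxns. The sketch below is not the library's implementation. It uses ``statistics.NormalDist`` from the Python standard library as a stand-in for the wrapped (untruncated) prior, and the helper names ``truncated_quantile`` and ``truncated_cdf`` are hypothetical. The local variables ``cdf_low`` and ``cdf_diff`` are named after the attributes listed above, on the assumption that those attributes hold the untruncated CDF at ``low`` and the CDF mass between ``low`` and ``high``.

```python
from statistics import NormalDist


def truncated_quantile(base, low, high, p):
    """Quantile of `base` truncated to [low, high], for p in (0, 1)."""
    cdf_low = base.cdf(low)               # untruncated CDF at the lower bound
    cdf_diff = base.cdf(high) - cdf_low   # untruncated mass inside [low, high]
    # Rescale p into the untruncated CDF range, then invert the untruncated CDF.
    return base.inv_cdf(p * cdf_diff + cdf_low)


def truncated_cdf(base, low, high, x):
    """CDF of `base` truncated to [low, high], for x in [low, high]."""
    cdf_low = base.cdf(low)
    cdf_diff = base.cdf(high) - cdf_low
    return (base.cdf(x) - cdf_low) / cdf_diff


# Stand-in for the wrapped prior: a standard normal (an assumption for this sketch).
base = NormalDist(mu=0.0, sigma=1.0)
low, high = -1.0, 2.0

for p in (0.1, 0.5, 0.9):
    x = truncated_quantile(base, low, high, p)
    # Every quantile lies inside the truncation interval, and the truncated
    # CDF inverts the truncated quantile.
    assert low <= x <= high
    assert abs(truncated_cdf(base, low, high, x) - p) < 1e-9
```

Because the scaling uses the untruncated CDF at the two bounds, a uniform ``p`` in (0, 1) always maps into ``[low, high]``. This is why the quantile form is convenient for samplers that work in the unit cube.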