jaxns
===============

.. py:module:: jaxns
   :synopsis: Nested sampling with JAX.

.. rubric:: :code:`jaxns`

.. autoapi-nested-parse::

   Nested sampling with JAX.

.. rubric:: Subpackages

.. toctree::
   :titlesonly:
   :maxdepth: 1

   experimental/index.rst
   framework/index.rst
   internals/index.rst
   nested_samplers/index.rst
   samplers/index.rst

.. rubric:: Submodules

.. toctree::
   :titlesonly:
   :maxdepth: 1

   plotting/index.rst
   public/index.rst
   utils/index.rst
   warnings/index.rst

.. rubric:: Package Contents

.. py:data:: PriorModelGen

.. py:data:: PriorModelType

.. py:function:: jaxify_likelihood(log_likelihood, vectorised = False)

   Wraps a non-JAX log-likelihood function.

   :param log_likelihood: a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar log-likelihood.
   :param vectorised: if True, then ``log_likelihood`` must handle batched inputs, i.e. each input will receive a common set of batched dimensions which the function must handle.

   :returns: A JAX-compatible log-likelihood function.

.. py:class:: Model(prior_model, log_likelihood, params = None)

   Bases: :py:obj:`jaxns.framework.bases.BaseAbstractModel`

   Represents a Bayesian model in terms of a generative prior and a likelihood function.

   .. py:method:: __repr__()

   .. py:property:: num_params
      :type: int

   .. py:property:: params

   .. py:method:: set_params(params)

      Create a new parametrised model with the given parameters.

      :param params: The parameters to use.

      :returns: A model with set parameters.

   .. py:method:: __call__(params)

      Create a new parametrised model with the given parameters.

      **This is (and must be) a pure function.**

      :param params: The parameters to use.

      :returns: A model with set parameters.

   .. py:method:: init_params(rng)

      Initialise the parameters of the model.

      :param rng: PRNGKey to initialise the parameters.

      :returns: The initialised parameters.

   .. py:method:: __hash__()

   .. py:method:: sample_U(key)

      Sample from the prior model.

      :param key: PRNGKey to use.

      :returns: The sampled U.

   .. py:method:: sample_W(key)

      Sample from the prior model.

      :param key: PRNGKey to use.

      :returns: The sampled W.

   .. py:method:: transform(U)

   .. py:method:: transform_parametrised(U)

   .. py:method:: forward(U, allow_nan = False)

   .. py:method:: log_prob_prior(U)

   .. py:method:: prepare_input(U)

   .. py:method:: sanity_check(key, S)

.. py:class:: Prior(dist_or_value, name = None)

   Bases: :py:obj:`jaxns.framework.bases.BaseAbstractPrior`

   Represents a generative prior.

   .. py:attribute:: name
      :value: None

   .. py:property:: dist
      :type: jaxns.framework.bases.BaseAbstractDistribution

   .. py:property:: value
      :type: jax.Array

   .. py:method:: parametrised(random_init = False)

      Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an associated ``log_prob``. The parameter is registered as a ``get_parameter`` with an added ``_param`` name suffix. The prior must have a name.

      :param random_init: whether to initialise the parameter randomly or at the median of the distribution.

      :returns: A singular prior.

      :raises ValueError: if the prior has no name.

.. py:exception:: InvalidPriorName(name = None)

   Bases: :py:obj:`Exception`

   Raised when a prior name is already taken.

.. py:class:: Bernoulli(*, logits=None, probs=None, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.

   .. py:attribute:: dist

.. py:class:: Beta(*, concentration0=None, concentration1=None, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.

.. py:class:: Categorical(parametrisation, *, logits=None, probs=None, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.

   Initialise a Categorical special prior.

   :param parametrisation: ``'cdf'`` suits discrete parameters with correlation between neighbouring categories; otherwise ``'gumbel'`` is better.
   :param logits: log-probability of each category
   :param probs: probability of each category
   :param name: optional name

   .. py:attribute:: dist

.. py:class:: ForcedIdentifiability(*, n, low=None, high=None, fix_left = False, fix_right = False, name = None)

   Bases: :py:obj:`SpecialPrior`

   Prior for a sequence of ``n`` random variables uniformly distributed on U[low, high] such that ``U[i,...] <= U[i+1,...]``. For broadcasting, the resulting random variable is sorted elementwise along the first dimension.

   :param n: number of samples within [low, high]
   :param low: minimum of distribution
   :param high: maximum of distribution
   :param fix_left: if True, the leftmost value is fixed to ``low``
   :param fix_right: if True, the rightmost value is fixed to ``high``

   .. py:attribute:: n

   .. py:attribute:: low

   .. py:attribute:: high

   .. py:attribute:: fix_left
      :value: False

   .. py:attribute:: fix_right
      :value: False

.. py:class:: Poisson(*, rate=None, log_rate=None, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.

   .. py:attribute:: dist

.. py:class:: UnnormalisedDirichlet(*, concentration, name = None)

   Bases: :py:obj:`SpecialPrior`

   Represents an unnormalised Dirichlet distribution of K classes. That is, the output is related to the K-simplex via normalisation:

   .. math::

      X \sim \mathrm{UnnormalisedDirichlet}(\alpha), \quad Y = \frac{X}{\sum_{k=1}^{K} X_k} \implies Y \sim \mathrm{Dirichlet}(\alpha)

.. py:class:: Empirical(*, samples, support_min = None, support_max = None, resolution = 100, name = None)

   Bases: :py:obj:`SpecialPrior`

   Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimension.

.. py:class:: TruncationWrapper(prior, low, high, name = None)

   Bases: :py:obj:`SpecialPrior`

   Wraps another prior to truncate it to [low, high]. Writing :math:`F` and :math:`Q` for the CDF and quantile function of the untruncated distribution, the quantile transforms to

   .. math::

      Q_{\mathrm{truncated}}(p) = Q\big(p\,[F(\mathrm{high}) - F(\mathrm{low})] + F(\mathrm{low})\big)

   and the CDF transforms to

   .. math::

      F_{\mathrm{truncated}}(x) = \frac{F(x) - F(\mathrm{low})}{F(\mathrm{high}) - F(\mathrm{low})}

   .. py:attribute:: prior

   .. py:attribute:: low

   .. py:attribute:: high

   .. py:attribute:: cdf_low

   .. py:attribute:: cdf_diff

.. py:class:: ExplicitDensityPrior(*, axes, density, regular_grid = False, name = None)

   Bases: :py:obj:`SpecialPrior`

   The base prior class with public methods.

.. py:function:: plot_diagnostics(results, save_name=None)

   Plot diagnostics of the nested sampling run.

   :param results: NestedSamplerResults
   :param save_name: file to save the figure to.

.. py:function:: plot_cornerplot(results, variables = None, with_parametrised = False, save_name = None, kde_overlay = False)

   Plots a cornerplot of the posterior samples.

   :param results: NestedSamplerResults
   :param variables: list of variable names to plot. Plots all collected samples by default.
   :param with_parametrised: whether to include parametrised samples.
   :param save_name: file to save the figure to.
   :param kde_overlay: whether to overlay a KDE on the histograms.

.. py:class:: NestedSampler

   A static nested sampler that uses a 1-dimensional slice sampler for the sampling step. It uses the phantom-powered algorithm, and robust defaults are provided for all parameters. The quantities s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330

   :param model: a model to perform nested sampling on
   :param max_samples: maximum number of samples to take
   :param num_live_points: approximate number of live points to use. Defaults to c * (k + 1).
   :param s: number of slices to use per dimension. Defaults to 4.
   :param k: number of phantom samples to use. Defaults to 0.
   :param c: number of parallel Markov chains to use. Defaults to 20 * D, where D is the number of model dimensions.
   :param devices: devices to use. Defaults to all available devices.
   :param difficult_model: if True, uses more robust default settings. Defaults to False.
   :param parameter_estimation: if True, uses more robust default settings for parameter estimation. Defaults to False.
   :param shell_fraction: fraction of the shell to use for the slice sampler. Defaults to 0.5.
   :param gradient_guided: if True, uses gradient-guided sampling. Defaults to False.
   :param init_efficiency_threshold: if > 0, first uses uniform sampling down to this acceptance efficiency; 0 turns it off. Defaults to 0.1.
   :param verbose: whether to log progress.

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: max_samples
      :type: Optional[Union[int, float]]
      :value: None

   .. py:attribute:: num_live_points
      :type: Optional[int]
      :value: None

   .. py:attribute:: num_slices
      :type: Optional[int]
      :value: None

   .. py:attribute:: s
      :type: Optional[Union[int, float]]
      :value: None

   .. py:attribute:: k
      :type: Optional[int]
      :value: None

   .. py:attribute:: c
      :type: Optional[int]
      :value: None

   .. py:attribute:: devices
      :type: Optional[List[jaxlib.xla_client.Device]]
      :value: None

   .. py:attribute:: difficult_model
      :type: bool
      :value: False

   .. py:attribute:: parameter_estimation
      :type: bool
      :value: False

   .. py:attribute:: shell_fraction
      :type: float
      :value: 0.5

   .. py:attribute:: gradient_guided
      :type: bool
      :value: False

   .. py:attribute:: init_efficiency_threshold
      :type: float
      :value: 0.1

   .. py:attribute:: verbose
      :type: bool
      :value: False

   .. py:method:: __post_init__()

   .. py:property:: nested_sampler
      :type: jaxns.nested_samplers.abc.AbstractNestedSampler

   .. py:method:: __call__(key, term_cond = None)

      Performs nested sampling with the given termination conditions.

      :param key: PRNGKey
      :param term_cond: termination conditions. If not given, see ``TerminationCondition`` for defaults.

      :returns: termination reason, state

   .. py:method:: to_results(termination_reason, state, trim = True)

      Convert the state to results.

      Note: Requires static context.

      :param termination_reason: termination reason
      :param state: state to convert
      :param trim: if True, trims the results to the number of samples taken; requires static context.

      :returns: results

   .. py:method:: trim_results(results)
      :staticmethod:

      Trims the results to the number of samples taken. Requires static context.

      :param results: results to trim

      :returns: trimmed results

.. py:function:: resample(key, samples, log_weights, S = None, replace = True)

   Resample the weighted samples into uniformly weighted samples.
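
   A minimal NumPy sketch of this kind of resampling follows. It is not the jaxns implementation. It normalises the weights in log-space and uses Kish's effective sample size, :math:`(\sum_i w_i)^2 / \sum_i w_i^2`, as the default ``S``. It then draws indices with ``numpy.random.Generator.choice``. The names ``kish_ess`` and ``resample_sketch`` are hypothetical.

   .. code-block:: python

      import numpy as np


      def kish_ess(log_weights):
          """Kish's effective sample size, (sum w)^2 / sum(w^2)."""
          # Subtract the maximum before exponentiating to avoid overflow.
          w = np.exp(log_weights - np.max(log_weights))
          return w.sum() ** 2 / np.sum(w ** 2)


      def resample_sketch(rng, samples, log_weights, S=None, replace=True):
          """Turn weighted samples (a dict of arrays) into S equally weighted samples."""
          p = np.exp(log_weights - np.max(log_weights))
          p /= p.sum()
          if S is None:
              S = int(kish_ess(log_weights))
          # Draw indices with probability proportional to the posterior weights.
          idx = rng.choice(p.shape[0], size=S, replace=replace, p=p)
          return {name: value[idx] for name, value in samples.items()}


      rng = np.random.default_rng(0)
      samples = {'x': np.linspace(-3.0, 3.0, 1000)}
      log_weights = -0.5 * samples['x'] ** 2  # unnormalised standard normal weights
      equal_samples = resample_sketch(rng, samples, log_weights)

   Because the output is equally weighted, it can be passed to tools that ignore weights, such as plain histograms. The cost is the extra Monte Carlo noise introduced by the random draw.
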

   :param key: PRNGKey
   :param samples: samples from nested sampled results
   :param log_weights: log-posterior weight
   :param S: number of samples to generate. Will use Kish's estimate of ESS if None.
   :param replace: whether to sample with replacement

   :returns: equally weighted samples

.. py:function:: marginalise_static_from_U(key, U_samples, model, log_weights, ESS, fun)

   Marginalises a function over posterior samples, where ESS is static.

   :param key: PRNG key
   :param U_samples: array of U samples
   :param model: model
   :param log_weights: log weights from nested sampling
   :param ESS: static effective sample size
   :param fun: function to marginalise
   :type fun: :code:`callable(**kwargs)`

   :returns: expectation over resampled samples

.. py:function:: marginalise_dynamic_from_U(key, U_samples, model, log_weights, ESS, fun)

   Marginalises a function over posterior samples, where ESS can be dynamic.

   :param key: PRNG key
   :param U_samples: array of U samples
   :param model: model
   :param log_weights: log weights from nested sampling
   :param ESS: dynamic effective sample size
   :param fun: function to marginalise
   :type fun: :code:`callable(**kwargs)`

   :returns: expectation of ``fun`` over resampled samples.

.. py:function:: marginalise_static(key, samples, log_weights, ESS, fun)

   Marginalises a function over posterior samples, where ESS is static.

   :param key: PRNG key
   :param samples: dict of batched arrays of nested sampling samples
   :type samples: dict
   :param log_weights: log weights from nested sampling
   :param ESS: static effective sample size
   :param fun: function to marginalise
   :type fun: :code:`callable(**kwargs)`

   :returns: expectation over resampled samples

.. py:function:: marginalise_dynamic(key, samples, log_weights, ESS, fun)

   Marginalises a function over posterior samples, where ESS can be dynamic.
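
   A minimal NumPy sketch of marginalising a function by resampling follows. It assumes that ``samples`` is a dict of arrays whose leading axis indexes the samples, and that ``fun`` accepts one keyword argument per variable. It is not the jaxns implementation and ignores the static/dynamic ESS distinction. ``marginalise_sketch`` is a hypothetical name.

   .. code-block:: python

      import numpy as np


      def marginalise_sketch(rng, samples, log_weights, ESS, fun):
          """Estimate E[fun(**x)] under the posterior by resampling ESS points."""
          p = np.exp(log_weights - np.max(log_weights))
          p /= p.sum()
          idx = rng.choice(p.shape[0], size=ESS, replace=True, p=p)
          # Evaluate fun on each resampled point, passing variables as keyword arguments.
          values = [fun(**{name: value[i] for name, value in samples.items()}) for i in idx]
          return np.mean(np.asarray(values), axis=0)


      rng = np.random.default_rng(42)
      samples = {'x': np.linspace(-5.0, 5.0, 2001)}
      log_weights = -0.5 * samples['x'] ** 2  # unnormalised standard normal weights
      # The second moment of a standard normal is 1.
      second_moment = marginalise_sketch(rng, samples, log_weights, ESS=4000, fun=lambda x: x ** 2)

   The estimate carries Monte Carlo noise that shrinks roughly as one over the square root of ``ESS``.
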

   :param key: PRNG key
   :param samples: dict of batched arrays of nested sampling samples
   :type samples: dict
   :param log_weights: log weights from nested sampling
   :param ESS: dynamic effective sample size
   :param fun: function to marginalise
   :type fun: :code:`callable(**kwargs)`

   :returns: expectation of ``fun`` over resampled samples.

.. py:function:: maximum_a_posteriori_point(results)

   Get the MAP point of a nested sampling result by choosing the point with the largest L(x) p(x).

   :param results: Nested sampler result
   :type results: NestedSamplerResults

   :returns: dict of samples at the MAP point.

.. py:function:: evaluate_map_estimate(results, fun)

   Evaluates a function at the maximum a posteriori (MAP) sample point.

   :param results: results from run
   :param fun: function to evaluate
   :type fun: :code:`callable(**kwargs)`

   :returns: estimate at MAP sample point

.. py:function:: summary(results, with_parametrised = False, f_obj = None)

   Gives a summary of the results of a nested sampling run.

   :param results: Nested sampler result
   :type results: NestedSamplerResults
   :param with_parametrised: whether to include parametrised samples
   :param f_obj: file-like object to write summary to. If None, prints to stdout.

.. py:function:: analytic_posterior_samples(model, S = 60)

   Compute the evidence with brute-force over a regular grid.

   :param model: model
   :param S: resolution of grid

   :returns: log(Z)

.. py:function:: sample_evidence(key, num_live_points_per_sample, log_L_samples, S = 100)

   Sample the evidence distribution by stochastically simulating the shrinkage distribution.

   Note: this produces approximate samples, since there is also uncertainty in the placement of the contours during shrinkage. Incorporating this stochasticity into the simulation would require running the entire nested sampling procedure many times.
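
   The shrinkage simulation can be sketched in NumPy as follows. This is an illustrative assumption, not the jaxns code. Each dead point :math:`i`, sampled with :math:`n_i` live points, shrinks the enclosed prior volume by :math:`t_i \sim \mathrm{Beta}(n_i, 1)`, so :math:`X_i = \prod_{j \le i} t_j`. The evidence is then :math:`Z = \sum_i L_i (X_{i-1} - X_i)` with :math:`X_0 = 1`. Like the note above, the sketch ignores uncertainty in contour placement. ``sample_evidence_sketch`` is a hypothetical name.

   .. code-block:: python

      import numpy as np


      def logsumexp(a):
          """Numerically stable log(sum(exp(a)))."""
          m = np.max(a)
          return m + np.log(np.sum(np.exp(a - m)))


      def sample_evidence_sketch(rng, num_live_points_per_sample, log_L_samples, S=100):
          """Return S samples of log(Z) by simulating the shrinkage of the prior volume."""
          n = np.asarray(num_live_points_per_sample, dtype=float)
          log_L = np.asarray(log_L_samples, dtype=float)
          log_Z = np.empty(S)
          for s in range(S):
              # Shrinkage factor t_i ~ Beta(n_i, 1): the largest of n_i uniform draws.
              log_t = np.log(rng.beta(n, 1.0))
              log_X = np.cumsum(log_t)  # log of the prior volume X_i after each dead point
              log_X_prev = np.concatenate([[0.0], log_X[:-1]])  # X_0 = 1
              # Shell volume: dX_i = X_{i-1} - X_i = X_{i-1} * (1 - t_i).
              log_dX = log_X_prev + np.log1p(-np.exp(log_t))
              log_Z[s] = logsumexp(log_L + log_dX)
          return log_Z


      rng = np.random.default_rng(0)
      num_live = np.full(1000, 50)
      log_L = np.zeros(1000)  # constant likelihood L = 1, so Z = 1 - X_N, which is close to 1
      log_Z_samples = sample_evidence_sketch(rng, num_live, log_L, S=100)

   With a constant likelihood the sum telescopes to :math:`Z = 1 - X_N`, so every sample of log(Z) should be very close to 0. This makes it a convenient sanity check.
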

   :param key: PRNGKey
   :param num_live_points_per_sample: the number of live points for each sample
   :param log_L_samples: the log-L of samples
   :param S: the number of samples to produce

   :returns: samples of log(Z)

.. py:function:: bruteforce_posterior_samples(model, S = 60)

   Compute the posterior with brute-force over a regular grid.

   :param model: model
   :param S: resolution of grid

   :returns: samples, and log-weight

.. py:function:: bruteforce_evidence(model, S = 60)

   Compute the evidence with brute-force over a regular grid.

   :param model: model
   :param S: resolution of grid

   :returns: log(Z)

.. py:function:: save_pytree(pytree, save_file)

   Saves a pytree, such as nested sampler results, to a JSON file.

   :param pytree: pytree to save, e.g. a nested sampler result
   :param save_file: filename

.. py:function:: save_results(results, save_file)

   Saves results of nested sampler in a JSON file.

   :param results: Nested sampler result
   :type results: NestedSamplerResults
   :param save_file: filename
   :type save_file: str

.. py:function:: load_pytree(save_file)

   Loads saved nested sampler results from a JSON file.

   :param save_file: filename
   :type save_file: str

   :returns: NestedSamplerResults

.. py:function:: load_results(save_file)

   Loads saved nested sampler results from a JSON file.

   :param save_file: filename
   :type save_file: str

   :returns: NestedSamplerResults

.. py:class:: ShardedStaticNestedSampler

   Bases: :py:obj:`jaxns.nested_samplers.abc.AbstractNestedSampler`

   A static nested sampler that uses a fixed number of live points. It first uses a uniform sampler to generate the initial set of samples down to an efficiency threshold, then uses the provided sampler to generate the remaining samples until the termination condition is met.

   :param init_efficiency_threshold: the efficiency threshold to use for the initial uniform sampling. If 0, it is turned off.
   :param sampler: the sampler to use after the initial uniform sampling.
   :param num_live_points: the number of live points to use.
   :param model: the model to use.
   :param max_samples: the maximum number of samples to take.
   :param devices: the devices to use; defaults to one device.
   :param verbose: whether to log progress.

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: max_samples
      :type: int

   .. py:attribute:: init_efficiency_threshold
      :type: float

   .. py:attribute:: sampler
      :type: jaxns.samplers.abc.AbstractSampler

   .. py:attribute:: num_live_points
      :type: int

   .. py:attribute:: shell_fraction
      :type: Optional[float]
      :value: None

   .. py:attribute:: num_dynamic_refinement_iterations
      :type: int
      :value: 0

   .. py:attribute:: refine_threshold
      :type: float
      :value: 0.01

   .. py:attribute:: devices
      :type: Optional[List[jaxlib.xla_client.Device]]
      :value: None

   .. py:attribute:: verbose
      :type: bool
      :value: False

   .. py:method:: __post_init__()

.. py:class:: TerminationCondition

   Bases: :py:obj:`NamedTuple`

   Contains the termination conditions for the nested sampling run.

   :param ess: terminate if the effective sample size (Kish's estimate) exceeds this.
   :param evidence_uncert: terminate if the uncertainty in the evidence falls below this.
   :param live_evidence_frac: Deprecated; use ``dlogZ`` instead.
   :param dlogZ: terminate if ``log(Z_current + Z_remaining) - log(Z_current) < dlogZ``. Defaults to log(1 + 1e-2).
   :param max_samples: terminate if the number of samples exceeds this.
   :param max_num_likelihood_evaluations: terminate if the number of likelihood evaluations exceeds this.
   :param log_L_contour: terminate if this log(L) contour is reached. A contour is reached if any dead point has log(L) > log_L_contour. Uncollected live points are not considered.
   :param efficiency_threshold: terminate if the efficiency (num_samples / num_likelihood_evaluations) of the last shrinkage iteration is less than this.
   :param rtol: terminate when the relative spread ``2 * |log_L_max - log_L_min| / |log_L_max + log_L_min| < rtol``.
   :param atol: terminate when the absolute spread ``|log_L_max - log_L_min| < atol``.

   .. py:attribute:: ess
      :type: Optional[Union[jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray]]
      :value: None

   .. py:attribute:: evidence_uncert
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: live_evidence_frac
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: dlogZ
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: max_samples
      :type: Optional[Union[jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray]]
      :value: None

   .. py:attribute:: max_num_likelihood_evaluations
      :type: Optional[Union[jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray]]
      :value: None

   .. py:attribute:: log_L_contour
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: efficiency_threshold
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: rtol
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: atol
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: peak_XL_frac
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:method:: __and__(other)

   .. py:method:: __or__(other)

.. py:class:: NestedSamplerResults

   Bases: :py:obj:`NamedTuple`

   Results of the nested sampling run.

   .. py:attribute:: log_Z_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_Z_uncert
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: ESS
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: H_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: samples
      :type: jaxns.internals.types.XType

   .. py:attribute:: parametrised_samples
      :type: jaxns.internals.types.XType

   .. py:attribute:: U_samples
      :type: jaxns.internals.types.UType

   .. py:attribute:: log_L_samples
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_dp_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_X_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_posterior_density
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: num_live_points_per_sample
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: num_likelihood_evaluations_per_sample
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: total_num_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: total_phantom_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: total_num_likelihood_evaluations
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: log_efficiency
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: termination_reason
      :type: jaxns.internals.types.IntArray

.. py:class:: NestedSamplerState

   Bases: :py:obj:`NamedTuple`

   .. py:attribute:: key
      :type: jaxns.internals.types.PRNGKey

   .. py:attribute:: next_sample_idx
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: num_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: sample_collection
      :type: StaticStandardSampleCollection