samplers
==================

.. py:module:: jaxns.samplers

.. rubric:: :code:`jaxns.samplers`

.. rubric:: Subpackages

.. toctree::
   :titlesonly:
   :maxdepth: 1

   multi_ellipsoid/index.rst

.. rubric:: Submodules

.. toctree::
   :titlesonly:
   :maxdepth: 1

   abc/index.rst
   bases/index.rst
   multi_ellipsoidal_samplers/index.rst
   multi_slice_sampler/index.rst
   uni_slice_sampler/index.rst
   uniform_samplers/index.rst

.. rubric:: Package Contents

.. py:class:: MultiEllipsoidalSampler

   Bases: :py:obj:`jaxns.samplers.bases.BaseAbstractRejectionSampler`\ [\ :py:obj:`jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils.MultEllipsoidState`\ ]

   Uses a multi-ellipsoidal decomposition of the live points to construct a bound around
   the regions to sample from. Inefficient for high-dimensional problems, but can be very
   efficient for low-dimensional ones.

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: depth
      :type: int

   .. py:attribute:: expansion_factor
      :type: float

   .. py:method:: __post_init__()

   .. py:method:: num_phantom()

   .. py:property:: max_num_ellipsoids

.. py:class:: MultiDimSliceSampler

   Bases: :py:obj:`jaxns.samplers.bases.BaseAbstractMarkovSampler`\ [\ :py:obj:`jaxns.nested_samplers.common.types.SampleCollection`\ ]

   Multi-dimensional slice sampler with exponential shrinkage. Produces correlated
   (non-i.i.d.) samples.

   Notes: Not very efficient.

   :param model: AbstractModel
   :param num_slices: number of slices between acceptance, in units of 1, unlike other
       software, which uses units of the prior dimension.
   :param num_phantom_save: number of phantom samples to save. Phantom samples are samples
       that meet the constraint but are not accepted. They can be used for numerous things,
       e.g. to estimate the evidence uncertainty.
   :param num_restrict_dims: size of the subspace to slice along. Setting this to 1 behaves
       like UniDimSliceSampler, but is far less efficient.

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: num_slices
      :type: int

   .. py:attribute:: num_phantom_save
      :type: int

   .. py:attribute:: num_restrict_dims
      :type: Optional[int]
      :value: None

   .. py:method:: __post_init__()

   .. py:method:: num_phantom()

   .. py:method:: get_seed_point(key, sampler_state, log_L_constraint)

   .. py:method:: get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)

.. py:class:: UniDimSliceSampler

   Bases: :py:obj:`jaxns.samplers.bases.BaseAbstractMarkovSampler`\ [\ :py:obj:`jaxns.nested_samplers.common.types.SampleCollection`\ ]

   Slice sampler for a single dimension.

   :param model: AbstractModel
   :param num_slices: number of slices between acceptance. Note: some other software uses
       units of the prior dimension.
   :param midpoint_shrink: if true, contract to the midpoint of the interval on rejection;
       otherwise, contract to the rejection point. Speeds up convergence, but introduces
       minor autocorrelation.
   :param num_phantom_save: number of phantom samples to save. Phantom samples are samples
       that meet the constraint but are not accepted. They can be used for numerous things,
       e.g. to estimate the evidence uncertainty.
   :param perfect: if true, perform exponential shrinkage from the maximal bounds, requiring
       no step-out procedure. Otherwise, use a doubling procedure (finding the bracket
       exponentially). Note: "perfect" is a misnomer, as perfection also depends on the
       number of slices between acceptance.
   :param gradient_slice: if true, always slice along the increasing gradient direction.
   :param adaptive_shrink: if true, shrink the interval to a random point in the interval
       rather than to the midpoint.
   :param gradient_guided: if true, perform Householder reflections between proposals with
       50% probability.

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: num_slices
      :type: int

   .. py:attribute:: num_phantom_save
      :type: int

   .. py:attribute:: midpoint_shrink
      :type: bool

   .. py:attribute:: perfect
      :type: bool

   .. py:attribute:: gradient_slice
      :type: bool
      :value: False

   .. py:attribute:: adaptive_shrink
      :type: bool
      :value: False

   .. py:attribute:: gradient_guided
      :type: bool
      :value: False

   .. py:method:: __post_init__()

   .. py:method:: num_phantom()

   .. py:method:: get_seed_point(key, sampler_state, log_L_constraint)

   .. py:method:: get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)

.. py:class:: UniformSampler

   Bases: :py:obj:`jaxns.samplers.bases.BaseAbstractRejectionSampler`\ [\ :py:obj:`Tuple`\ ]

   A sampler that produces uniform samples from the model within the log_L_constraint.

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: max_likelihood_evals
      :type: int
      :value: 100

   .. py:method:: __post_init__()

   .. py:method:: num_phantom()

.. py:class:: ShardedStaticNestedSampler

   Bases: :py:obj:`jaxns.nested_samplers.abc.AbstractNestedSampler`

   A static nested sampler that uses a fixed number of live points. It uses a uniform
   sampler to generate the initial set of samples until the sampling efficiency drops to a
   threshold, then uses the provided sampler to generate the remaining samples until the
   termination condition is met.

   :param init_efficiency_threshold: the efficiency threshold for the initial uniform
       sampling. If 0, the initial uniform sampling is turned off.
   :param sampler: the sampler to use after the initial uniform sampling.
   :param num_live_points: the number of live points to use.
   :param model: the model to use.
   :param max_samples: the maximum number of samples to take.
   :param devices: the devices to use (default: 1).
   :param verbose: whether to log progress during the run.

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: max_samples
      :type: int

   .. py:attribute:: init_efficiency_threshold
      :type: float

   .. py:attribute:: sampler
      :type: jaxns.samplers.abc.AbstractSampler

   .. py:attribute:: num_live_points
      :type: int

   .. py:attribute:: shell_fraction
      :type: Optional[float]
      :value: None

   .. py:attribute:: num_dynamic_refinement_iterations
      :type: int
      :value: 0

   .. py:attribute:: refine_threshold
      :type: float
      :value: 0.01

   .. py:attribute:: devices
      :type: Optional[List[jaxlib.xla_client.Device]]
      :value: None

   .. py:attribute:: verbose
      :type: bool
      :value: False

   .. py:method:: __post_init__()

.. py:class:: TerminationCondition

   Bases: :py:obj:`NamedTuple`

   Contains the termination conditions for the nested sampling run.

   :param ess: The effective sample size. If the ESS (Kish's estimate) is greater than this,
       the run terminates.
   :param evidence_uncert: The uncertainty in the evidence. If the uncertainty is less than
       this, the run terminates.
   :param live_evidence_frac: Deprecated; use dlogZ.
   :param dlogZ: Terminate if ``log(Z_current + Z_remaining) - log(Z_current) < dlogZ``.
       Default: ``log(1 + 1e-2)``.
   :param max_samples: Terminate if the number of samples exceeds this.
   :param max_num_likelihood_evaluations: Terminate if the number of likelihood evaluations
       exceeds this.
   :param log_L_contour: Terminate if this log(L) contour is reached. A contour is reached
       if any dead point has ``log(L) > log_L_contour``. Uncollected live points are not
       considered.
   :param efficiency_threshold: Terminate if the efficiency
       (``num_samples / num_likelihood_evaluations``) over the last shrinkage iteration is
       less than this.
   :param rtol: Terminate when the relative spread
       ``2*|log_L_max - log_L_min| / |log_L_max + log_L_min| < rtol``.
   :param atol: Terminate when the absolute spread ``|log_L_max - log_L_min| < atol``.

   .. py:attribute:: ess
      :type: Optional[Union[jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray]]
      :value: None

   .. py:attribute:: evidence_uncert
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: live_evidence_frac
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: dlogZ
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: max_samples
      :type: Optional[Union[jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray]]
      :value: None

   .. py:attribute:: max_num_likelihood_evaluations
      :type: Optional[Union[jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray]]
      :value: None

   .. py:attribute:: log_L_contour
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: efficiency_threshold
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: rtol
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: atol
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: peak_XL_frac
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:method:: __and__(other)

   .. py:method:: __or__(other)

.. py:class:: NestedSamplerResults

   Bases: :py:obj:`NamedTuple`

   Results of the nested sampling run.

   .. py:attribute:: log_Z_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_Z_uncert
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: ESS
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: H_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: samples
      :type: jaxns.internals.types.XType

   .. py:attribute:: parametrised_samples
      :type: jaxns.internals.types.XType

   .. py:attribute:: U_samples
      :type: jaxns.internals.types.UType

   .. py:attribute:: log_L_samples
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_dp_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_X_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_posterior_density
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: num_live_points_per_sample
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: num_likelihood_evaluations_per_sample
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: total_num_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: total_phantom_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: total_num_likelihood_evaluations
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: log_efficiency
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: termination_reason
      :type: jaxns.internals.types.IntArray

.. py:class:: NestedSamplerState

   Bases: :py:obj:`NamedTuple`

   .. py:attribute:: key
      :type: jaxns.internals.types.PRNGKey

   .. py:attribute:: next_sample_idx
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: num_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: sample_collection
      :type: StaticStandardSampleCollection
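MultiEllipsoidalSampler draws candidates from inside ellipsoids that bound the live points. The core step of that idea, drawing a point uniformly inside a single ellipsoid, can be sketched as follows. This is a minimal illustration under stated assumptions, not the jaxns implementation. The helper name ``sample_in_ellipsoid`` is hypothetical, and the ellipsoid is assumed to be given by a centre and a lower-triangular (Cholesky) factor ``L``. Ellipsoid fitting, ``expansion_factor``, and the choice among several overlapping ellipsoids are not shown.

```python
import math
import random
from typing import List


def sample_in_ellipsoid(center: List[float], chol: List[List[float]],
                        rng: random.Random) -> List[float]:
    """Draw a point uniformly from {center + L z : |z| <= 1}.

    Hypothetical helper, not part of the jaxns API. `chol` is assumed to be
    a lower-triangular matrix L given as a list of rows.
    """
    d = len(center)
    # Uniform direction on the unit sphere: normalise a standard Gaussian vector.
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(x * x for x in g))
    # Radius U**(1/d) makes the point uniform in the volume of the unit ball.
    r = rng.random() ** (1.0 / d)
    z = [r * x / norm for x in g]
    # Map the unit ball onto the ellipsoid: x = center + L z (L lower-triangular).
    return [center[i] + sum(chol[i][j] * z[j] for j in range(i + 1))
            for i in range(d)]
```

A linear map scales every volume by the same factor, so a point that is uniform in the unit ball stays uniform after being mapped onto the ellipsoid.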
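The slice samplers above move a seed point that already satisfies the likelihood constraint to a new, correlated point that also satisfies it. Below is a minimal sketch of univariate slice updates in the unit cube. It assumes the maximal bracket [0, 1] is used with no step-out, in the spirit of ``perfect=True``, and that a rejected proposal becomes the new bracket bound (shrinking toward the seed rather than to the midpoint). For simplicity it slices along a randomly chosen coordinate axis. The function names are hypothetical, and the library may choose slice directions differently.

```python
import random
from typing import Callable, List, Tuple


def slice_along_axis(log_L: Callable[[List[float]], float], seed_point: List[float],
                     log_L_constraint: float, dim: int,
                     rng: random.Random) -> Tuple[List[float], int]:
    """One slice step along coordinate `dim` (hypothetical helper).

    Starts from the maximal bracket [0, 1], proposes uniformly inside it and,
    on rejection, moves the bracket bound to the rejected point so the seed
    always stays inside. Returns the accepted point and the evaluation count.
    """
    lower, upper = 0.0, 1.0
    num_evals = 0
    while True:
        proposal = list(seed_point)
        proposal[dim] = rng.uniform(lower, upper)
        num_evals += 1
        if log_L(proposal) > log_L_constraint:
            return proposal, num_evals
        # Shrink toward the seed: the rejected coordinate becomes the new bound.
        if proposal[dim] < seed_point[dim]:
            lower = proposal[dim]
        else:
            upper = proposal[dim]


def uni_slice_sample(log_L: Callable[[List[float]], float], seed_point: List[float],
                     log_L_constraint: float, num_slices: int,
                     rng: random.Random) -> Tuple[List[float], int]:
    """Apply `num_slices` slice steps along random axes (hypothetical helper)."""
    point = list(seed_point)
    total_evals = 0
    for _ in range(num_slices):
        dim = rng.randrange(len(point))
        point, evals = slice_along_axis(log_L, point, log_L_constraint, dim, rng)
        total_evals += evals
    return point, total_evals
```

Because the seed satisfies the constraint and each rejection shrinks the bracket toward it, the loop terminates for any likelihood whose constrained region contains a neighbourhood of the seed along the slice.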
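UniformSampler, and the initial phase of ShardedStaticNestedSampler, rest on plain rejection sampling: propose uniformly and keep the proposal only if its log-likelihood exceeds log_L_constraint. Below is a sketch with an evaluation budget analogous to ``max_likelihood_evals``. The function name, the unit-cube proposal space, and the choice to return ``None`` when the budget runs out are assumptions for illustration, not the library's behaviour.

```python
import random
from typing import Callable, List, Optional, Tuple


def uniform_constrained_sample(log_L: Callable[[List[float]], float], num_dims: int,
                               log_L_constraint: float, max_likelihood_evals: int,
                               rng: random.Random) -> Tuple[Optional[List[float]], int]:
    """Rejection-sample a point uniformly from the unit cube with log_L > constraint.

    Hypothetical helper. Returns (point, num_evals); point is None if the
    evaluation budget is exhausted without an acceptance.
    """
    for num_evals in range(1, max_likelihood_evals + 1):
        u = [rng.random() for _ in range(num_dims)]
        if log_L(u) > log_L_constraint:
            return u, num_evals
    return None, max_likelihood_evals
```

In the sense used by TerminationCondition (samples per likelihood evaluation), one accepted draw from this loop has efficiency ``1 / num_evals``. This is why rejection sampling becomes wasteful as the constrained region shrinks.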
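The dlogZ, rtol and atol criteria of TerminationCondition reduce to simple scalar tests. The sketch below writes them out in plain Python using the definitions given in the parameter list. The function names are hypothetical, and how jaxns obtains Z_remaining, log_L_max and log_L_min during a run is not shown. The dlogZ test works in log space with a stable log-add-exp, so very small evidences do not underflow.

```python
import math


def log_add_exp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    if m == -math.inf:
        return -math.inf
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def dlogZ_reached(log_Z_current: float, log_Z_remaining: float, dlogZ: float) -> bool:
    """True when log(Z_current + Z_remaining) - log(Z_current) < dlogZ (hypothetical)."""
    return log_add_exp(log_Z_current, log_Z_remaining) - log_Z_current < dlogZ


def rtol_reached(log_L_max: float, log_L_min: float, rtol: float) -> bool:
    """True when 2*|log_L_max - log_L_min| / |log_L_max + log_L_min| < rtol (hypothetical)."""
    denom = abs(log_L_max + log_L_min)
    if denom == 0.0:
        # Relative spread is undefined here; this sketch chooses not to stop.
        return False
    return 2.0 * abs(log_L_max - log_L_min) / denom < rtol


def atol_reached(log_L_max: float, log_L_min: float, atol: float) -> bool:
    """True when |log_L_max - log_L_min| < atol (hypothetical)."""
    return abs(log_L_max - log_L_min) < atol
```

With the default dlogZ of log(1 + 1e-2), the run stops once the remaining evidence falls below 1% of the evidence accumulated so far. For example, Z_current = 1 and Z_remaining = 0.005 give log(1.005) < log(1.01).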
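The ess criterion uses Kish's effective sample size, ESS = (Σ w_i)² / Σ w_i², where w_i are the sample weights. A log-space sketch follows. The function name is hypothetical, and treating the weights as the per-sample posterior weights (for instance ``exp(log_dp_mean)`` in NestedSamplerResults) is an assumption.

```python
import math
from typing import Sequence


def kish_ess(log_weights: Sequence[float]) -> float:
    """Kish's effective sample size (sum w)^2 / sum(w^2) from log-weights.

    Hypothetical helper; needs at least one finite log-weight.
    """
    m = max(log_weights)
    # Shift by the maximum so exp() cannot overflow; the ratio is invariant to the shift.
    w = [math.exp(lw - m) for lw in log_weights]
    return sum(w) ** 2 / sum(x * x for x in w)
```

Equal weights give ESS = n, while a single dominant weight gives an ESS close to 1.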