nested_samplers
=========================

.. py:module:: jaxns.nested_samplers

.. rubric:: :code:`jaxns.nested_samplers`

.. rubric:: Subpackages

.. toctree::
   :titlesonly:
   :maxdepth: 1

   common/index.rst
   sharded/index.rst

.. rubric:: Submodules

.. toctree::
   :titlesonly:
   :maxdepth: 1

   abc/index.rst

.. rubric:: Package Contents

.. py:class:: ShardedStaticNestedSampler

   Bases: :py:obj:`jaxns.nested_samplers.abc.AbstractNestedSampler`

   A static nested sampler that uses a fixed number of live points. It first uses a uniform sampler to generate
   the initial set of samples until the sampling efficiency falls to a threshold, then uses the provided sampler
   to generate the remaining samples until the termination condition is met.

   :param init_efficiency_threshold: the efficiency threshold for the initial uniform sampling. If 0, the initial
       uniform sampling is turned off.
   :param sampler: the sampler to use after the initial uniform sampling.
   :param num_live_points: the number of live points to use.
   :param model: the model to use.
   :param max_samples: the maximum number of samples to take.
   :param devices: the devices to use; defaults to one device.
   :param verbose: whether to log progress during the run.

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: max_samples
      :type: int

   .. py:attribute:: init_efficiency_threshold
      :type: float

   .. py:attribute:: sampler
      :type: jaxns.samplers.abc.AbstractSampler

   .. py:attribute:: num_live_points
      :type: int

   .. py:attribute:: shell_fraction
      :type: Optional[float]
      :value: None

   .. py:attribute:: num_dynamic_refinement_iterations
      :type: int
      :value: 0

   .. py:attribute:: refine_threshold
      :type: float
      :value: 0.01

   .. py:attribute:: devices
      :type: Optional[List[jaxlib.xla_client.Device]]
      :value: None

   .. py:attribute:: verbose
      :type: bool
      :value: False

   .. py:method:: __post_init__()

.. py:class:: TerminationCondition

   Bases: :py:obj:`NamedTuple`

   Contains the termination conditions for the nested sampling run.

   :param ess: Terminate if the effective sample size (Kish's estimate) exceeds this.
   :param evidence_uncert: Terminate if the uncertainty in the evidence is less than this.
   :param live_evidence_frac: Deprecated; use ``dlogZ`` instead.
   :param dlogZ: Terminate if ``log(Z_current + Z_remaining) - log(Z_current) < dlogZ``.
       Defaults to ``log(1 + 1e-2)``.
   :param max_samples: Terminate if the number of samples exceeds this.
   :param max_num_likelihood_evaluations: Terminate if the number of likelihood evaluations exceeds this.
   :param log_L_contour: Terminate if this log(L) contour is reached. A contour is reached if any dead point has
       ``log(L) > log_L_contour``. Uncollected live points are not considered.
   :param efficiency_threshold: Terminate if the efficiency (``num_samples / num_likelihood_evaluations``) of the
       last shrinkage iteration is less than this.
   :param rtol: Terminate when the relative spread
       ``2*|log_L_max - log_L_min| / |log_L_max + log_L_min|`` is less than ``rtol``.
   :param atol: Terminate when the absolute spread ``|log_L_max - log_L_min|`` is less than ``atol``.

   .. py:attribute:: ess
      :type: Optional[Union[jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray]]
      :value: None

   .. py:attribute:: evidence_uncert
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: live_evidence_frac
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: dlogZ
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: max_samples
      :type: Optional[Union[jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray]]
      :value: None

   .. py:attribute:: max_num_likelihood_evaluations
      :type: Optional[Union[jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray]]
      :value: None

   .. py:attribute:: log_L_contour
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: efficiency_threshold
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: rtol
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: atol
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: peak_XL_frac
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:method:: __and__(other)

   .. py:method:: __or__(other)

.. py:class:: NestedSamplerResults

   Bases: :py:obj:`NamedTuple`

   Results of the nested sampling run.

   .. py:attribute:: log_Z_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_Z_uncert
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: ESS
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: H_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: samples
      :type: jaxns.internals.types.XType

   .. py:attribute:: parametrised_samples
      :type: jaxns.internals.types.XType

   .. py:attribute:: U_samples
      :type: jaxns.internals.types.UType

   .. py:attribute:: log_L_samples
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_dp_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_X_mean
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_posterior_density
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: num_live_points_per_sample
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: num_likelihood_evaluations_per_sample
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: total_num_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: total_phantom_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: total_num_likelihood_evaluations
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: log_efficiency
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: termination_reason
      :type: jaxns.internals.types.IntArray

.. py:class:: NestedSamplerState

   Bases: :py:obj:`NamedTuple`

   .. py:attribute:: key
      :type: jaxns.internals.types.PRNGKey

   .. py:attribute:: next_sample_idx
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: num_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: sample_collection
      :type: StaticStandardSampleCollection
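
The scalar criteria described for ``TerminationCondition`` can be checked by hand when choosing thresholds. The
sketch below is a minimal, standalone illustration in plain Python, not the jaxns implementation: the helper names
(``log_sum_exp``, ``kish_ess``, ``dlogz_reached``, ``log_L_converged``, ``efficiency_below``) are hypothetical, and it
assumes that ``atol`` and ``rtol`` each trigger termination independently. It evaluates Kish's effective sample size
from log-weights, the ``dlogZ`` evidence criterion, the ``rtol``/``atol`` log-likelihood spread, and the
``efficiency_threshold`` test, using the formulas from the ``TerminationCondition`` parameter descriptions.

.. code-block:: python

   import math
   from typing import Optional, Sequence

   # NOTE: hypothetical helpers for illustration only; these are not jaxns functions.


   def log_sum_exp(xs: Sequence[float]) -> float:
       """Numerically stable log(sum(exp(x) for x in xs))."""
       m = max(xs)
       return m + math.log(sum(math.exp(x - m) for x in xs))


   def kish_ess(log_weights: Sequence[float]) -> float:
       """Kish's effective sample size (sum w)^2 / sum(w^2), computed in log space."""
       return math.exp(2.0 * log_sum_exp(log_weights) - log_sum_exp([2.0 * lw for lw in log_weights]))


   def dlogz_reached(log_Z_current: float, log_Z_remaining: float,
                     dlogZ: float = math.log1p(1e-2)) -> bool:
       """True if log(Z_current + Z_remaining) - log(Z_current) < dlogZ."""
       log_Z_total = log_sum_exp([log_Z_current, log_Z_remaining])
       return log_Z_total - log_Z_current < dlogZ


   def log_L_converged(log_L_max: float, log_L_min: float,
                       rtol: Optional[float] = None, atol: Optional[float] = None) -> bool:
       """True if the log-likelihood spread meets the atol or rtol criterion (assumed to be OR-ed)."""
       spread = abs(log_L_max - log_L_min)
       if atol is not None and spread < atol:
           return True
       if rtol is not None:
           denom = abs(log_L_max + log_L_min)
           # Guard against division by zero when log_L_max == -log_L_min.
           if denom > 0.0 and 2.0 * spread / denom < rtol:
               return True
       return False


   def efficiency_below(num_samples: int, num_likelihood_evaluations: int,
                        efficiency_threshold: float) -> bool:
       """True if num_samples / num_likelihood_evaluations < efficiency_threshold."""
       return num_samples / num_likelihood_evaluations < efficiency_threshold


   if __name__ == "__main__":
       print(kish_ess([0.0, 0.0, 0.0, 0.0]))  # equal weights: ESS is (numerically) 4
       print(dlogz_reached(0.0, math.log(0.005)))  # remaining evidence is 0.5% of current: True
       print(log_L_converged(-10.0, -10.05, atol=0.1))  # spread 0.05 < 0.1: True
       print(efficiency_below(100, 10_000, 0.05))  # efficiency 0.01 < 0.05: True

Working in log space (``log_sum_exp``) avoids overflow and underflow when weights or evidences span many orders of
magnitude, which is the usual situation in nested sampling.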