jaxns.nested_samplers

Package Contents

class ShardedStaticNestedSampler

Bases: jaxns.nested_samplers.abc.AbstractNestedSampler

A static nested sampler that uses a fixed number of live points. It first uses a uniform sampler to generate the initial set of samples until the sampling efficiency falls to init_efficiency_threshold, then uses the provided sampler to generate the remaining samples until the termination condition is met.

Parameters:
  • init_efficiency_threshold – the efficiency threshold to use for the initial uniform sampling. Set to 0 to disable the initial uniform sampling.

  • sampler – the sampler to use after the initial uniform sampling.

  • num_live_points – the number of live points to use.

  • model – the model to use.

  • max_samples – the maximum number of samples to take.

  • devices – the devices to use; by default a single device is used.

  • verbose – whether to log as we go.
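
The two-phase idea described above (uniform sampling until efficiency drops, then a hand-off to another sampler) can be illustrated with a toy, standard-library-only sketch. This is not the jaxns implementation: the function name uniform_phase, the 1-D unit-interval prior, and the choice to measure efficiency over one sweep of num_live_points replacements are all assumptions made for illustration. Only the parameter names are borrowed from the list above.

```python
import random


def uniform_phase(log_likelihood, num_live_points, init_efficiency_threshold,
                  max_samples, seed=0):
    """Toy 1-D uniform-sampling phase (hypothetical helper, not jaxns API)."""
    rng = random.Random(seed)
    # Live points are drawn uniformly from the unit interval (assumed prior).
    live = [rng.random() for _ in range(num_live_points)]
    live_log_L = [log_likelihood(u) for u in live]
    dead = []
    # A threshold of 0 disables the uniform phase entirely.
    if init_efficiency_threshold <= 0:
        return dead, live, live_log_L
    while len(dead) < max_samples:
        num_accepted = 0
        num_evals = 0
        # One sweep: replace the worst live point num_live_points times.
        for _ in range(num_live_points):
            if len(dead) >= max_samples:
                break
            worst = min(range(num_live_points), key=live_log_L.__getitem__)
            log_L_min = live_log_L[worst]
            # Rejection-sample the prior until the contour is exceeded.
            while True:
                u = rng.random()
                log_L = log_likelihood(u)
                num_evals += 1
                if log_L > log_L_min:
                    break
            dead.append((live[worst], log_L_min))
            live[worst], live_log_L[worst] = u, log_L
            num_accepted += 1
        # Efficiency = accepted samples / likelihood evaluations in this sweep.
        if num_accepted / num_evals < init_efficiency_threshold:
            break  # hand over to the provided sampler from here
    return dead, live, live_log_L
```

In this sketch, the returned live points would be the starting state for whichever sampler takes over. As the likelihood contour rises, fewer uniform draws land above it, so the measured efficiency falls and the loop exits.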

model: jaxns.framework.bases.BaseAbstractModel
max_samples: int
init_efficiency_threshold: float
sampler: jaxns.samplers.abc.AbstractSampler
num_live_points: int
shell_fraction: float | None = None
num_dynamic_refinement_iterations: int = 0
refine_threshold: float = 0.01
devices: List[jaxlib.xla_client.Device] | None = None
verbose: bool = False
__post_init__()

class TerminationCondition

Bases: NamedTuple

Contains the termination conditions for the nested sampling run.

Parameters:
  • ess – The effective sample size (Kish’s estimate). The run terminates once the ESS is greater than this value.

  • evidence_uncert – The uncertainty in the evidence. The run terminates once the uncertainty is less than this value.

  • live_evidence_frac – Deprecated; use dlogZ instead.

  • dlogZ – Terminate if log(Z_current + Z_remaining) - log(Z_current) < dlogZ. Default: log(1 + 1e-2).

  • max_samples – Terminate if the number of samples exceeds this.

  • max_num_likelihood_evaluations – Terminate if the number of likelihood evaluations exceeds this.

  • log_L_contour – Terminate if this log(L) contour is reached. A contour is reached if any dead point has log(L) > log_L_contour. Uncollected live points are not considered.

  • efficiency_threshold – Terminate if the efficiency (num_samples / num_likelihood_evaluations) is less than this, for the last shrinkage iteration.

  • rtol – Terminate when the relative spread 2*|log_L_max - log_L_min|/|log_L_max + log_L_min| < rtol.

  • atol – Terminate when the absolute spread |log_L_max - log_L_min| < atol.
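
Several of the criteria above are simple inequalities, and writing them out can make the thresholds easier to choose. The following is a minimal standard-library sketch of the formulas as stated in the list. The helper names (dlogZ_reached, rtol_reached, atol_reached, kish_ess) are hypothetical and not part of jaxns, and how the library itself evaluates or combines these checks is not shown here.

```python
import math


def dlogZ_reached(log_Z_current, log_Z_remaining, dlogZ=math.log(1 + 1e-2)):
    """log(Z_current + Z_remaining) - log(Z_current) < dlogZ (finite inputs)."""
    # Stable log-sum-exp for log(Z_current + Z_remaining).
    m = max(log_Z_current, log_Z_remaining)
    log_Z_total = m + math.log(
        math.exp(log_Z_current - m) + math.exp(log_Z_remaining - m)
    )
    return log_Z_total - log_Z_current < dlogZ


def rtol_reached(log_L_max, log_L_min, rtol):
    """2*|max - min| / |max + min| < rtol, rearranged to avoid dividing by 0."""
    return 2 * abs(log_L_max - log_L_min) < rtol * abs(log_L_max + log_L_min)


def atol_reached(log_L_max, log_L_min, atol):
    """|max - min| < atol."""
    return abs(log_L_max - log_L_min) < atol


def kish_ess(weights):
    """Kish's effective sample size: (sum w)^2 / sum(w^2)."""
    total = sum(weights)
    return total * total / sum(w * w for w in weights)
```

For example, with the default dlogZ the first check passes once the remaining evidence is below about 1% of the accumulated evidence, since log(1 + 1e-2) is the log-ratio for a 1% increase.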

ess: jaxns.internals.types.FloatArray | jaxns.internals.types.IntArray | None = None
evidence_uncert: jaxns.internals.types.FloatArray | None = None
live_evidence_frac: jaxns.internals.types.FloatArray | None = None
dlogZ: jaxns.internals.types.FloatArray | None = None
max_samples: jaxns.internals.types.FloatArray | jaxns.internals.types.IntArray | None = None
max_num_likelihood_evaluations: jaxns.internals.types.FloatArray | jaxns.internals.types.IntArray | None = None
log_L_contour: jaxns.internals.types.FloatArray | None = None
efficiency_threshold: jaxns.internals.types.FloatArray | None = None
rtol: jaxns.internals.types.FloatArray | None = None
atol: jaxns.internals.types.FloatArray | None = None
peak_XL_frac: jaxns.internals.types.FloatArray | None = None
__and__(other)
__or__(other)

class NestedSamplerResults

Bases: NamedTuple

Results of the nested sampling run.

log_Z_mean: jaxns.internals.types.FloatArray
log_Z_uncert: jaxns.internals.types.FloatArray
ESS: jaxns.internals.types.FloatArray
H_mean: jaxns.internals.types.FloatArray
samples: jaxns.internals.types.XType
parametrised_samples: jaxns.internals.types.XType
U_samples: jaxns.internals.types.UType
log_L_samples: jaxns.internals.types.FloatArray
log_dp_mean: jaxns.internals.types.FloatArray
log_X_mean: jaxns.internals.types.FloatArray
log_posterior_density: jaxns.internals.types.FloatArray
num_live_points_per_sample: jaxns.internals.types.IntArray
num_likelihood_evaluations_per_sample: jaxns.internals.types.IntArray
total_num_samples: jaxns.internals.types.IntArray
total_phantom_samples: jaxns.internals.types.IntArray
total_num_likelihood_evaluations: jaxns.internals.types.IntArray
log_efficiency: jaxns.internals.types.FloatArray
termination_reason: jaxns.internals.types.IntArray

class NestedSamplerState

Bases: NamedTuple

key: jaxns.internals.types.PRNGKey
next_sample_idx: jaxns.internals.types.IntArray
num_samples: jaxns.internals.types.IntArray
sample_collection: StaticStandardSampleCollection