samplers

jaxns.samplers

Package Contents

class MultiEllipsoidalSampler

Bases: jaxns.samplers.bases.BaseAbstractRejectionSampler[jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils.MultEllipsoidState]

Uses a multi-ellipsoidal decomposition of the live points to construct a bound around the regions to sample from.

Inefficient for high-dimensional problems, but can be very efficient for low-dimensional ones.
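To illustrate the idea (this is not the jaxns implementation, which is written in JAX and is configured through depth and expansion_factor), the following pure-Python sketch draws one point uniformly from a union of ellipsoids and rejects it unless it satisfies the likelihood constraint. For brevity it assumes axis-aligned ellipsoids given as hypothetical (center, radii) pairs and skips the check that points stay inside the unit cube; all function names are illustrative.

```python
import math
import random


def sample_unit_ball(dim, rng):
    # Uniform direction: a normalised standard-normal vector.
    z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(v * v for v in z))
    # Radius u**(1/dim) makes the point uniform in volume, not just in radius.
    r = rng.random() ** (1.0 / dim)
    return [r * v / norm for v in z]


def in_ellipsoid(x, center, radii):
    # Axis-aligned ellipsoid test: sum_i ((x_i - c_i) / a_i)^2 <= 1.
    return sum(((xi - c) / a) ** 2 for xi, c, a in zip(x, center, radii)) <= 1.0


def sample_union_of_ellipsoids(log_L, log_L_constraint, ellipsoids, rng, max_evals=10_000):
    """Draw a point uniformly from {x in union of ellipsoids : log_L(x) > constraint}.

    `ellipsoids` is a list of hypothetical (center, radii) pairs.
    Returns (point, num_likelihood_evals).
    """
    # Pick an ellipsoid with probability proportional to its volume (product of radii).
    volumes = [math.prod(radii) for _, radii in ellipsoids]
    num_evals = 0
    while num_evals < max_evals:
        center, radii = rng.choices(ellipsoids, weights=volumes, k=1)[0]
        u = sample_unit_ball(len(center), rng)
        x = [c + a * v for c, a, v in zip(center, radii, u)]
        # Points in overlaps would be over-sampled; keep them with probability 1/k,
        # where k is the number of ellipsoids containing the point.
        k = max(1, sum(in_ellipsoid(x, c, a) for c, a in ellipsoids))
        if rng.random() >= 1.0 / k:
            continue
        num_evals += 1
        if log_L(x) > log_L_constraint:
            return x, num_evals
    raise RuntimeError("max_evals reached without satisfying the constraint")
```

Choosing an ellipsoid in proportion to its volume and then keeping a point with probability 1/k makes the accepted points uniform over the union instead of over-weighting the overlaps.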

model: jaxns.framework.bases.BaseAbstractModel
depth: int
expansion_factor: float
__post_init__()
num_phantom()
Return type:

int

property max_num_ellipsoids
class MultiDimSliceSampler

Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler[jaxns.nested_samplers.common.types.SampleCollection]

Multi-dimensional slice sampler, with exponential shrinkage. Produces correlated (non-i.i.d.) samples.

Notes: Not very efficient.

Parameters:
  • model – AbstractModel

  • num_slices – number of slices between acceptances, counted in units of 1 (unlike some other software, which counts in units of the prior dimension).

  • num_phantom_save – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for numerous things, e.g. to estimate the evidence uncertainty.

  • num_restrict_dims – size of subspace to slice along. Setting to 1 would be like UniDimSliceSampler, but far less efficient.
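A minimal pure-Python sketch of the idea behind num_slices and num_restrict_dims: each slice picks a random direction inside a random subspace of num_restrict_dims coordinates, brackets it by the unit cube, and shrinks the bracket toward the current point after every rejection. The function names are hypothetical and this is not the jaxns code; passing len(x0) as num_restrict_dims slices in the full space.

```python
import random


def slice_along_direction(log_L, log_L_constraint, x0, direction, rng):
    """One slice step from x0 along `direction`, bracketed by the unit cube [0, 1]^d.

    Assumes log_L(x0) > log_L_constraint. Returns (new_point, num_likelihood_evals).
    """
    # Largest interval of t keeping x0 + t * direction inside the cube (no step-out).
    t_min, t_max = -float("inf"), float("inf")
    for xi, di in zip(x0, direction):
        if di > 0:
            t_min, t_max = max(t_min, -xi / di), min(t_max, (1.0 - xi) / di)
        elif di < 0:
            t_min, t_max = max(t_min, (1.0 - xi) / di), min(t_max, -xi / di)
    num_evals = 0
    while True:
        t = rng.uniform(t_min, t_max)
        x = [xi + t * di for xi, di in zip(x0, direction)]
        num_evals += 1
        if log_L(x) > log_L_constraint:
            return x, num_evals
        # Shrinkage: the rejected t becomes the new bracket end on its side of t = 0.
        if t < 0:
            t_min = t
        else:
            t_max = t


def multi_dim_slice(log_L, log_L_constraint, x0, num_slices, num_restrict_dims, rng):
    """Chain `num_slices` slice steps, each along a random direction confined to a
    random subset of `num_restrict_dims` coordinates."""
    d = len(x0)
    x, total_evals = list(x0), 0
    for _ in range(num_slices):
        dims = rng.sample(range(d), num_restrict_dims)
        direction = [0.0] * d
        for i in dims:
            direction[i] = rng.gauss(0.0, 1.0)  # isotropic within the chosen subspace
        x, n = slice_along_direction(log_L, log_L_constraint, x, direction, rng)
        total_evals += n
    return x, total_evals
```

Successive points share most of their history with the seed, which is why the samples are correlated rather than i.i.d.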

model: jaxns.framework.bases.BaseAbstractModel
num_slices: int
num_phantom_save: int
num_restrict_dims: int | None = None
__post_init__()
num_phantom()
Return type:

int

get_seed_point(key, sampler_state, log_L_constraint)
Parameters:
  • key (jaxns.internals.types.PRNGKey)

  • sampler_state (jaxns.nested_samplers.common.types.LivePointCollection)

  • log_L_constraint (jaxns.internals.types.FloatArray)

Return type:

jaxns.samplers.bases.SeedPoint

get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
Parameters:
  • key (jaxns.internals.types.PRNGKey)

  • seed_point (jaxns.samplers.bases.SeedPoint)

  • log_L_constraint (jaxns.internals.types.FloatArray)

  • sampler_state (jaxns.nested_samplers.common.types.SampleCollection)

Return type:

Tuple[jaxns.nested_samplers.common.types.Sample, jaxns.nested_samplers.common.types.Sample]

class UniDimSliceSampler

Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler[jaxns.nested_samplers.common.types.SampleCollection]

Slice sampler for a single dimension.

Parameters:
  • model – AbstractModel

  • num_slices – number of slices between acceptances. Note: some other software counts slices in units of the prior dimension.

  • midpoint_shrink – if true, contract to the midpoint of the interval on rejection; otherwise, contract to the rejection point. Midpoint contraction speeds up convergence but introduces minor autocorrelation.

  • num_phantom_save – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for numerous things, e.g. to estimate the evidence uncertainty.

  • perfect – if true, perform exponential shrinkage from the maximal bounds, requiring no step-out procedure. Otherwise, use a doubling procedure (finding a bracket by exponential expansion). Note: "perfect" is a misnomer, as perfection also depends on the number of slices between acceptances.

  • gradient_slice – if true, always slice along the direction of increasing gradient.

  • adaptive_shrink – if true, shrink the interval to a random point in the interval rather than to the midpoint.

  • gradient_guided – if true, perform Householder reflections between proposals with 50% probability.
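The sketch below illustrates perfect-style bracketing (start from the full unit interval, so no step-out is needed) combined with a midpoint_shrink option, one coordinate at a time. How jaxns applies the midpoint rule is an assumption here: this version jumps to the interval midpoint only when that still keeps the current point inside the bracket, and otherwise falls back to the rejection point. The function name is hypothetical and the gradient options are not modelled.

```python
import random


def uni_dim_slice(log_L, log_L_constraint, x0, num_slices, rng, midpoint_shrink=False):
    """Chain `num_slices` axis-aligned slice steps in [0, 1]^d, starting from x0.

    Assumes log_L(x0) > log_L_constraint. Returns (new_point, num_likelihood_evals).
    """
    x, num_evals = list(x0), 0
    for _ in range(num_slices):
        i = rng.randrange(len(x))  # slice along one randomly chosen axis
        lo, hi = 0.0, 1.0          # maximal bracket: no step-out needed
        while True:
            proposal = list(x)
            proposal[i] = rng.uniform(lo, hi)
            num_evals += 1
            if log_L(proposal) > log_L_constraint:
                x = proposal
                break
            r, mid = proposal[i], 0.5 * (lo + hi)
            if r < x[i]:
                # Midpoint contraction discards more, but only while x[i] stays inside.
                lo = mid if (midpoint_shrink and r < mid < x[i]) else r
            else:
                hi = mid if (midpoint_shrink and x[i] < mid < r) else r
    return x, num_evals
```

Discarding the segment between the rejection point and the midpoint can also discard acceptable points, which is the price paid for faster contraction.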

model: jaxns.framework.bases.BaseAbstractModel
num_slices: int
num_phantom_save: int
midpoint_shrink: bool
perfect: bool
gradient_slice: bool = False
adaptive_shrink: bool = False
gradient_guided: bool = False
__post_init__()
num_phantom()
Return type:

int

get_seed_point(key, sampler_state, log_L_constraint)
Parameters:
  • key (jaxns.internals.types.PRNGKey)

  • sampler_state (jaxns.nested_samplers.common.types.LivePointCollection)

  • log_L_constraint (jaxns.internals.types.FloatArray)

Return type:

jaxns.samplers.bases.SeedPoint

get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
Parameters:
  • key (jaxns.internals.types.PRNGKey)

  • seed_point (jaxns.samplers.bases.SeedPoint)

  • log_L_constraint (jaxns.internals.types.FloatArray)

  • sampler_state (jaxns.nested_samplers.common.types.SampleCollection)

Return type:

Tuple[jaxns.nested_samplers.common.types.Sample, jaxns.nested_samplers.common.types.Sample]

class UniformSampler

Bases: jaxns.samplers.bases.BaseAbstractRejectionSampler[Tuple]

A sampler that produces uniform samples from the model within the log_L_constraint.
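A pure-Python sketch of what uniform sampling within a likelihood constraint amounts to in the unit hypercube: draw U uniformly and keep it only if log_L(U) exceeds the constraint, giving up after max_likelihood_evals attempts. What jaxns does when that cap is hit is not described here; this hypothetical uniform_sample function simply reports failure.

```python
import random


def uniform_sample(log_L, log_L_constraint, dim, rng, max_likelihood_evals=100):
    """Rejection-sample U ~ Uniform([0, 1]^dim) until log_L(U) > log_L_constraint.

    Returns (U, log_L(U), num_evals, accepted); gives up after max_likelihood_evals.
    """
    for num_evals in range(1, max_likelihood_evals + 1):
        U = [rng.random() for _ in range(dim)]
        log_L_U = log_L(U)
        if log_L_U > log_L_constraint:
            return U, log_L_U, num_evals, True
    # Cap reached: report the last (rejected) draw and flag the failure.
    return U, log_L_U, max_likelihood_evals, False
```

Because every draw comes from the whole prior, the acceptance rate equals the prior volume still enclosed by the constraint, so this approach becomes steadily less efficient as the constraint tightens.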

model: jaxns.framework.bases.BaseAbstractModel
max_likelihood_evals: int = 100
__post_init__()
num_phantom()
Return type:

int

class ShardedStaticNestedSampler

Bases: jaxns.nested_samplers.abc.AbstractNestedSampler

A static nested sampler that uses a fixed number of live points. It first generates samples with a uniform sampler until the sampling efficiency drops to an efficiency threshold, then uses the provided sampler to generate the remaining samples until the termination condition is met.

Parameters:
  • init_efficiency_threshold – the efficiency threshold to use for the initial uniform sampling. Set to 0 to turn the uniform phase off.

  • sampler – the sampler to use after the initial uniform sampling.

  • num_live_points – the number of live points to use.

  • model – the model to use.

  • max_samples – the maximum number of samples to take.

  • devices – the devices to use, default is 1.

  • verbose – whether to log progress during the run.
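To illustrate the two-phase strategy, here is a toy, single-device, one-dimensional static nested sampler in pure Python (sharding, shell_fraction and the refinement options are ignored). It replaces the worst live point with uniform draws until the per-replacement efficiency falls below init_efficiency_threshold, then switches to a simple slice sampler seeded from another live point, and stops with the dlogZ rule described under TerminationCondition. It uses the standard expected shrinkage X_{i+1} = X_i exp(-1/n); all names are hypothetical and this is not the jaxns implementation.

```python
import math
import random


def log_add_exp(a, b):
    # Numerically stable log(exp(a) + exp(b)).
    if a == -math.inf:
        return b
    m = max(a, b)
    return m + math.log1p(math.exp(min(a, b) - m))


def slice_1d(log_L, log_L_min, u0, rng, num_slices=5):
    # Second-phase sampler: 1-D slice sampling on [0, 1] from a seed inside the constraint.
    u = u0
    for _ in range(num_slices):
        lo, hi = 0.0, 1.0
        while True:
            v = rng.uniform(lo, hi)
            if log_L(v) > log_L_min:
                u = v
                break
            lo, hi = (v, hi) if v < u else (lo, v)  # shrink toward u
    return u


def static_nested_sampling(log_L, num_live_points, init_efficiency_threshold, rng,
                           dlogZ=math.log(1 + 1e-2), max_iters=100_000):
    """Toy static nested sampler for U in [0, 1]; returns (log_Z, switch_iteration)."""
    n = num_live_points
    live = [rng.random() for _ in range(n)]
    live_log_L = [log_L(u) for u in live]
    log_Z, log_X = -math.inf, 0.0
    use_uniform, switch_iter = init_efficiency_threshold > 0, 0
    for it in range(max_iters):
        # Stop once the live points could raise log Z by less than dlogZ.
        if log_add_exp(log_Z, max(live_log_L) + log_X) - log_Z < dlogZ:
            break
        worst = min(range(n), key=live_log_L.__getitem__)
        log_L_min = live_log_L[worst]
        # The dead point gets prior mass X_i - X_{i+1} = X_i * (1 - exp(-1/n)).
        log_Z = log_add_exp(log_Z, log_L_min + log_X + math.log1p(-math.exp(-1.0 / n)))
        log_X -= 1.0 / n
        if use_uniform:
            num_evals, u = 1, rng.random()
            while log_L(u) <= log_L_min:
                u, num_evals = rng.random(), num_evals + 1
            # Single-replacement efficiency estimate, kept simple for the sketch.
            if 1.0 / num_evals < init_efficiency_threshold:
                use_uniform, switch_iter = False, it
        else:
            seed = live[rng.choice([j for j in range(n) if j != worst])]
            u = slice_1d(log_L, log_L_min, seed, rng)
        live[worst], live_log_L[worst] = u, log_L(u)
    # The remaining live points share the leftover prior volume X equally.
    for l in live_log_L:
        log_Z = log_add_exp(log_Z, l + log_X - math.log(n))
    return log_Z, switch_iter
```

With a Gaussian likelihood of width 0.05 centred in [0, 1], the true log-evidence is log(0.05 * sqrt(2*pi)), about -2.08, and the estimate from 100 live points should land close to it.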

model: jaxns.framework.bases.BaseAbstractModel
max_samples: int
init_efficiency_threshold: float
sampler: jaxns.samplers.abc.AbstractSampler
num_live_points: int
shell_fraction: float | None = None
num_dynamic_refinement_iterations: int = 0
refine_threshold: float = 0.01
devices: List[jaxlib.xla_client.Device] | None = None
verbose: bool = False
__post_init__()
class TerminationCondition

Bases: NamedTuple

Contains the termination conditions for the nested sampling run.

Parameters:
  • ess – The effective sample size; the run terminates once the ESS (Kish’s estimate) exceeds this.

  • evidence_uncert – The uncertainty in the evidence; the run terminates once the uncertainty falls below this.

  • live_evidence_frac – Deprecated; use dlogZ instead.

  • dlogZ – Terminate if log(Z_current + Z_remaining) - log(Z_current) < dlogZ. Default: log(1 + 1e-2).

  • max_samples – Terminate if the number of samples exceeds this.

  • max_num_likelihood_evaluations – Terminate if the number of likelihood evaluations exceeds this.

  • log_L_contour – Terminate if this log(L) contour is reached. A contour is reached if any dead point has log(L) > log_L_contour. Uncollected live points are not considered.

  • efficiency_threshold – Terminate if the efficiency (num_samples / num_likelihood_evaluations) is less than this, for the last shrinkage iteration.

  • rtol – finish when the relative value 2*|log_L_max - log_L_min|/|log_L_max + log_L_min| < rtol

  • atol – finish when the absolute |log_L_max - log_L_min| < atol
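The dlogZ, rtol and atol rules above can be written as plain predicates. This is a minimal pure-Python sketch with hypothetical function names, assuming log_L_max and log_L_min are taken over the current live points; log_add_exp computes log(exp(a) + exp(b)) stably. With the default dlogZ = log(1 + 1e-2), the run stops once the remaining evidence is less than 1% of the evidence accumulated so far.

```python
import math


def log_add_exp(a, b):
    # Numerically stable log(exp(a) + exp(b)).
    m = max(a, b)
    if m == -math.inf:
        return -math.inf
    return m + math.log1p(math.exp(-abs(a - b)))


def dlogZ_reached(log_Z_current, log_Z_remaining, dlogZ=math.log(1 + 1e-2)):
    # log(Z_current + Z_remaining) - log(Z_current) < dlogZ,
    # i.e. Z_remaining / Z_current < exp(dlogZ) - 1 (1% with the default).
    return log_add_exp(log_Z_current, log_Z_remaining) - log_Z_current < dlogZ


def rtol_reached(log_L_max, log_L_min, rtol):
    # 2 |log_L_max - log_L_min| / |log_L_max + log_L_min| < rtol
    return 2 * abs(log_L_max - log_L_min) / abs(log_L_max + log_L_min) < rtol


def atol_reached(log_L_max, log_L_min, atol):
    # |log_L_max - log_L_min| < atol
    return abs(log_L_max - log_L_min) < atol
```

For instance, Z_remaining = 0.005 * Z_current satisfies the default dlogZ rule, while Z_remaining = 0.02 * Z_current does not.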

ess: jaxns.internals.types.FloatArray | jaxns.internals.types.IntArray | None = None
evidence_uncert: jaxns.internals.types.FloatArray | None = None
live_evidence_frac: jaxns.internals.types.FloatArray | None = None
dlogZ: jaxns.internals.types.FloatArray | None = None
max_samples: jaxns.internals.types.FloatArray | jaxns.internals.types.IntArray | None = None
max_num_likelihood_evaluations: jaxns.internals.types.FloatArray | jaxns.internals.types.IntArray | None = None
log_L_contour: jaxns.internals.types.FloatArray | None = None
efficiency_threshold: jaxns.internals.types.FloatArray | None = None
rtol: jaxns.internals.types.FloatArray | None = None
atol: jaxns.internals.types.FloatArray | None = None
peak_XL_frac: jaxns.internals.types.FloatArray | None = None
__and__(other)
__or__(other)
class NestedSamplerResults

Bases: NamedTuple

Results of the nested sampling run.

log_Z_mean: jaxns.internals.types.FloatArray
log_Z_uncert: jaxns.internals.types.FloatArray
ESS: jaxns.internals.types.FloatArray
H_mean: jaxns.internals.types.FloatArray
samples: jaxns.internals.types.XType
parametrised_samples: jaxns.internals.types.XType
U_samples: jaxns.internals.types.UType
log_L_samples: jaxns.internals.types.FloatArray
log_dp_mean: jaxns.internals.types.FloatArray
log_X_mean: jaxns.internals.types.FloatArray
log_posterior_density: jaxns.internals.types.FloatArray
num_live_points_per_sample: jaxns.internals.types.IntArray
num_likelihood_evaluations_per_sample: jaxns.internals.types.IntArray
total_num_samples: jaxns.internals.types.IntArray
total_phantom_samples: jaxns.internals.types.IntArray
total_num_likelihood_evaluations: jaxns.internals.types.IntArray
log_efficiency: jaxns.internals.types.FloatArray
termination_reason: jaxns.internals.types.IntArray
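A common post-processing step is to turn per-sample log-weights into normalised posterior weights, then compute weighted means and Kish’s effective sample size, (Σw)² / Σw². The sketch below assumes log_dp_mean holds per-sample log posterior weights; re-normalising with log-sum-exp is harmless if they already sum to one. The helper names are hypothetical.

```python
import math


def posterior_weights(log_dp):
    # Normalise per-sample log weights with log-sum-exp, then exponentiate.
    m = max(log_dp)
    log_norm = m + math.log(sum(math.exp(v - m) for v in log_dp))
    return [math.exp(v - log_norm) for v in log_dp]


def kish_ess(weights):
    # Kish's effective sample size: (sum w)^2 / sum w^2.
    return sum(weights) ** 2 / sum(w * w for w in weights)


def weighted_mean(values, weights):
    # Posterior mean estimate from weighted samples.
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)
```

Equal weights give an ESS equal to the number of samples, while a single dominant weight gives an ESS of 1. Applied to a run, the weights from posterior_weights(list(results.log_dp_mean)) would be paired with the per-sample entries of results.samples, assuming both are stored in the same order.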
class NestedSamplerState

Bases: NamedTuple

key: jaxns.internals.types.PRNGKey
next_sample_idx: jaxns.internals.types.IntArray
num_samples: jaxns.internals.types.IntArray
sample_collection: StaticStandardSampleCollection