jaxns.public
Module Contents
- class TerminationCondition
Bases: NamedTuple
Contains the termination conditions for the nested sampling run.
- Parameters:
ess – The effective sample size. If the ESS (Kish’s estimate) is greater than this, the run will terminate.
evidence_uncert – The uncertainty in the evidence. If the uncertainty is less than this, the run will terminate.
live_evidence_frac – Deprecated; use dlogZ instead.
dlogZ – Terminate if log(Z_current + Z_remaining) - log(Z_current) < dlogZ. Defaults to log(1 + 1e-2).
max_samples – Terminate if the number of samples exceeds this.
max_num_likelihood_evaluations – Terminate if the number of likelihood evaluations exceeds this.
log_L_contour – Terminate if this log(L) contour is reached. A contour is reached if any dead point has log(L) > log_L_contour. Uncollected live points are not considered.
efficiency_threshold – Terminate if the efficiency (num_samples / num_likelihood_evaluations) of the last shrinkage iteration is less than this.
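A few of the criteria above can be illustrated with a small pure-Python sketch that evaluates them on given quantities. This is not jaxns code: `ToyTermination`, `kish_ess`, `logaddexp` and `triggered_conditions` are hypothetical names, and the inputs (sample weights, current and remaining log-evidence, last-iteration counts) are assumed to be available. Kish's estimate is taken as (Σw)² / Σw², and the dlogZ test follows the formula given for `dlogZ`.

```python
import math
from typing import List, NamedTuple, Optional, Sequence


class ToyTermination(NamedTuple):
    # Hypothetical stand-in for a subset of TerminationCondition's fields.
    ess: Optional[float] = None
    dlogZ: Optional[float] = math.log(1.0 + 1e-2)
    max_samples: Optional[int] = None
    efficiency_threshold: Optional[float] = None


def kish_ess(weights: Sequence[float]) -> float:
    # Kish's effective sample size: (sum w)^2 / sum(w^2).
    total = sum(weights)
    return total * total / sum(w * w for w in weights)


def logaddexp(a: float, b: float) -> float:
    # Numerically stable log(exp(a) + exp(b)).
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def triggered_conditions(cond: ToyTermination, weights: Sequence[float],
                         log_Z_current: float, log_Z_remaining: float,
                         num_samples: int, last_iter_samples: int,
                         last_iter_evaluations: int) -> List[str]:
    # Returns the names of every (hypothetical) criterion that fires.
    reasons = []
    if cond.ess is not None and kish_ess(weights) > cond.ess:
        reasons.append("ess")
    if cond.dlogZ is not None:
        # log(Z_current + Z_remaining) - log(Z_current)
        delta = logaddexp(log_Z_current, log_Z_remaining) - log_Z_current
        if delta < cond.dlogZ:
            reasons.append("dlogZ")
    if cond.max_samples is not None and num_samples > cond.max_samples:
        reasons.append("max_samples")
    if cond.efficiency_threshold is not None:
        # Efficiency of the last shrinkage iteration.
        efficiency = last_iter_samples / last_iter_evaluations
        if efficiency < cond.efficiency_threshold:
            reasons.append("efficiency_threshold")
    return reasons


cond = ToyTermination(ess=50.0, efficiency_threshold=0.1)
print(triggered_conditions(cond, [1.0] * 100, 0.0, math.log(1e-3), 10, 5, 100))
# ['ess', 'dlogZ', 'efficiency_threshold']
```

Equal weights give an ESS equal to the number of samples, and a remaining evidence of 0.1% of the current evidence is already below the default dlogZ of log(1.01).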
- class DefaultNestedSampler(model, max_samples, num_live_points=None, s=None, k=None, c=None, num_parallel_workers=1, difficult_model=False, parameter_estimation=False, init_efficiency_threshold=0.1, verbose=False)
A static nested sampler that uses a 1-dimensional slice sampler for the sampling step, based on the phantom-powered algorithm. Robust defaults are provided for all parameters.
Initialises the nested sampler.
s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – a model to perform nested sampling on
max_samples (Union[int, float]) – maximum number of samples to take
num_live_points (Optional[int]) – approximate number of live points to use. Defaults to c * (k + 1).
s (Optional[int]) – number of slices to use per dimension. Defaults to 4.
k (Optional[int]) – number of phantom samples to use. Defaults to 0.
c (Optional[int]) – number of parallel Markov chains to use. Defaults to 20 * D, where D is the number of model dimensions.
num_parallel_workers (int) – number of parallel workers to use. Defaults to 1. Experimental feature.
difficult_model (bool) – if True, uses more robust default settings. Defaults to False.
parameter_estimation (bool) – if True, uses more robust default settings for parameter estimation. Defaults to False.
init_efficiency_threshold (float) – if > 0, uniform sampling is used first, until the acceptance efficiency drops to this value. 0 turns it off.
verbose (bool) – whether to use JAX printing
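Assuming only the defaults listed above, the default run sizes can be worked out as below. `resolve_defaults` and `SamplerSettings` are hypothetical; `D` is taken to be the number of model dimensions, and `num_slices = s * D` is simply "s slices per dimension" multiplied out. The sketch ignores `difficult_model` and `parameter_estimation`, which may change these defaults in ways not described here, and it keeps an explicit `num_live_points` as-is, whereas the real sampler treats it as approximate.

```python
from typing import NamedTuple, Optional


class SamplerSettings(NamedTuple):
    s: int                # slices per dimension
    k: int                # phantom samples
    c: int                # parallel Markov chains
    num_live_points: int
    num_slices: int       # total slices per sampling step, s * D


def resolve_defaults(D: int, s: Optional[int] = None, k: Optional[int] = None,
                     c: Optional[int] = None,
                     num_live_points: Optional[int] = None) -> SamplerSettings:
    # Hypothetical helper applying only the documented defaults.
    s = 4 if s is None else s
    k = 0 if k is None else k
    c = 20 * D if c is None else c
    if num_live_points is None:
        num_live_points = c * (k + 1)
    return SamplerSettings(s=s, k=k, c=c, num_live_points=num_live_points,
                           num_slices=s * D)


print(resolve_defaults(D=3))
# SamplerSettings(s=4, k=0, c=60, num_live_points=60, num_slices=12)
print(resolve_defaults(D=3, k=2).num_live_points)  # 180
```

For a 3-dimensional model this gives c = 60 chains and 60 live points; setting k = 2 triples the live points to 180, since each chain then contributes phantom samples as well.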
- property nested_sampler: jaxns.nested_sampler.bases.BaseAbstractNestedSampler
- Return type:
jaxns.nested_sampler.bases.BaseAbstractNestedSampler
- __call__(key, term_cond=None)
Performs nested sampling with the given termination conditions.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey
term_cond (Optional[jaxns.internals.types.TerminationCondition]) – termination conditions. If not given, see TerminationCondition for defaults.
- Returns:
termination reason, state
- Return type:
Tuple[jaxns.internals.types.IntArray, jaxns.internals.types.StaticStandardNestedSamplerState]
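To show the shape of what `__call__` returns, a (termination reason, state) pair, here is a toy nested sampler in pure Python. It is a sketch under strong assumptions: a 1-dimensional Uniform(0, 1) prior, a normalised Gaussian likelihood (so log Z ≈ 0), the deterministic shrinkage X_i = exp(-i / n), and plain rejection sampling from the prior in place of the slice sampling jaxns uses. `ToyState`, `toy_nested_sampling` and the string reasons are hypothetical; jaxns reports the reason as an IntArray.

```python
import math
import random
from typing import List, NamedTuple


class ToyState(NamedTuple):
    # Hypothetical container, not jaxns's StaticStandardNestedSamplerState.
    dead_log_L: List[float]
    log_Z: float
    num_samples: int
    num_likelihood_evaluations: int


def log_likelihood(x: float, mu: float = 0.5, sigma: float = 0.05) -> float:
    # Normalised Gaussian: its integral over the Uniform(0, 1) prior is ~1.
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(math.sqrt(2.0 * math.pi) * sigma)


def toy_nested_sampling(seed: int, num_live: int = 100,
                        dlogZ: float = math.log(1.0 + 1e-2),
                        max_samples: int = 10_000):
    rng = random.Random(seed)
    n_evals = 0
    live = []
    for _ in range(num_live):
        x = rng.random()
        live.append((log_likelihood(x), x))
        n_evals += 1
    Z, X_prev, dead = 0.0, 1.0, []
    reason = "max_samples"
    while len(dead) < max_samples:
        # Remove the worst live point; its shell has prior mass X_prev - X.
        worst = min(range(num_live), key=lambda j: live[j][0])
        log_L_star = live[worst][0]
        X = math.exp(-(len(dead) + 1) / num_live)
        Z += (X_prev - X) * math.exp(log_L_star)
        dead.append(log_L_star)
        X_prev = X
        # Replace it by rejection sampling from the prior above the contour
        # (a swapped-in technique, not the phantom-powered slice sampler).
        while True:
            x = rng.random()
            n_evals += 1
            log_L = log_likelihood(x)
            if log_L > log_L_star:
                live[worst] = (log_L, x)
                break
        # Remaining evidence estimate: X times the mean live-point likelihood.
        Z_remaining = X * sum(math.exp(ll) for ll, _ in live) / num_live
        if math.log(Z + Z_remaining) - math.log(Z) < dlogZ:
            reason = "dlogZ"
            break
    # Add the remaining live points' contribution before returning.
    Z += X_prev * sum(math.exp(ll) for ll, _ in live) / num_live
    return reason, ToyState(dead, math.log(Z), len(dead), n_evals)


reason, state = toy_nested_sampling(seed=0)
print(reason, round(state.log_Z, 2), state.num_samples)
```

With the default settings the run should stop on "dlogZ" with log Z near 0; passing `max_samples=50` makes it stop on "max_samples" instead, mirroring how either condition can end a real run.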
- to_results(termination_reason, state, trim=True)
Convert the state to results.
Note: Requires static context.
- Parameters:
termination_reason (jaxns.internals.types.IntArray) – termination reason
state (jaxns.internals.types.StaticStandardNestedSamplerState) – state to convert
trim (bool) – if True, trims the results to the number of samples taken (requires static context).
- Returns:
results
- Return type:
jaxns.internals.types.NestedSamplerResults
- static trim_results(results)
Trims the results to the number of samples taken. Requires static context.
- Parameters:
results (jaxns.internals.types.NestedSamplerResults) – results to trim
- Returns:
trimmed results
- Return type:
jaxns.internals.types.NestedSamplerResults
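The following sketch illustrates what trimming to the number of samples taken can look like when results are stored in buffers pre-allocated to `max_samples` entries. `ToyResults`, its fields and `toy_trim_results` are hypothetical. One plausible reason trimming needs a static context is that slicing to `num_samples` produces an array whose shape depends on data, which JAX cannot trace inside `jax.jit`; the sketch therefore uses plain Python lists and a concrete integer count.

```python
import math
from typing import List, NamedTuple


class ToyResults(NamedTuple):
    # Hypothetical results whose lists were pre-allocated to max_samples entries.
    num_samples: int
    log_L: List[float]
    log_weight: List[float]


def toy_trim_results(results: ToyResults) -> ToyResults:
    # num_samples must be a concrete Python int here. Under jax.jit it would be
    # a traced value, and slicing by it would give a data-dependent shape.
    n = int(results.num_samples)
    return results._replace(log_L=results.log_L[:n],
                            log_weight=results.log_weight[:n])


padded = ToyResults(num_samples=2,
                    log_L=[-3.0, -1.0, -math.inf, -math.inf],
                    log_weight=[-0.5, -0.9, -math.inf, -math.inf])
print(toy_trim_results(padded))
# ToyResults(num_samples=2, log_L=[-3.0, -1.0], log_weight=[-0.5, -0.9])
```

`_replace` returns a new tuple, so the padded buffers are left untouched.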
- class ApproximateNestedSampler(*args, **kwargs)
Bases: DefaultNestedSampler
Shares the description and parameters of DefaultNestedSampler.
- class ExactNestedSampler(*args, **kwargs)
Bases: ApproximateNestedSampler
Shares the description and parameters of DefaultNestedSampler.