jaxns.public
Module Contents
- class NestedSampler
A static nested sampler that uses a 1-dimensional slice sampler for the sampling step and implements the phantom-powered algorithm. Robust defaults are provided for all parameters.
The parameters s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330
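To make the slice-sampling step concrete, here is a minimal pure-Python sketch of one likelihood-constrained slice move inside the unit hypercube. It is a generic shrinkage-only slice sampler along a random direction, not jaxns's implementation, and it leaves out phantom points and parallel chains. The names slice_move and new_point are hypothetical, and applying s × D moves per new point is our reading of "number of slices to use per dimension", not something this page states.

```python
import math
import random
from typing import Callable, List

Point = List[float]


def slice_move(x: Point, log_L: Callable[[Point], float], log_L_min: float,
               rng: random.Random) -> Point:
    """One shrinkage-only slice move targeting the uniform distribution on
    {y in [0, 1]^D : log_L(y) > log_L_min}. Assumes x satisfies the constraint."""
    # Random direction, uniform on the unit sphere.
    n = [rng.gauss(0.0, 1.0) for _ in x]
    norm = math.sqrt(sum(v * v for v in n))
    n = [v / norm for v in n]
    # Initial bracket: the part of the line x + t * n that lies inside the unit cube.
    t_lo, t_hi = -math.inf, math.inf
    for xi, ni in zip(x, n):
        if ni > 0:
            t_lo, t_hi = max(t_lo, -xi / ni), min(t_hi, (1.0 - xi) / ni)
        elif ni < 0:
            t_lo, t_hi = max(t_lo, (1.0 - xi) / ni), min(t_hi, -xi / ni)
    while True:
        t = rng.uniform(t_lo, t_hi)
        y = [xi + t * ni for xi, ni in zip(x, n)]
        if log_L(y) > log_L_min:
            return y
        # Rejected: shrink the bracket towards the current point at t = 0.
        if t < 0:
            t_lo = t
        else:
            t_hi = t


def new_point(x: Point, log_L: Callable[[Point], float], log_L_min: float,
              s: int, rng: random.Random) -> Point:
    """Apply s * D slice moves (hypothetical reading of `s` as slices per dimension)."""
    for _ in range(s * len(x)):
        x = slice_move(x, log_L, log_L_min, rng)
    return x


def log_L(p: Point) -> float:
    # Narrow 2-D Gaussian likelihood centred in the unit square.
    return -sum((v - 0.5) ** 2 for v in p) / (2 * 0.1 ** 2)


# Usage: wander inside the region log_L > -2 starting from its centre.
rng = random.Random(0)
x = [0.5, 0.5]
for _ in range(20):
    x = new_point(x, log_L, log_L_min=-2.0, s=4, rng=rng)
```

Drawing the initial bracket over the whole segment inside the cube avoids a stepping-out phase; every rejection shrinks the bracket towards the current point, so the move always terminates inside the constrained region.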
- Parameters:
model – a model to perform nested sampling on
max_samples – maximum number of samples to take
num_live_points – approximate number of live points to use. Defaults to c * (k + 1).
s – number of slices to use per dimension. Defaults to 4.
k – number of phantom samples to use. Defaults to 0.
c – number of parallel Markov chains to use. Defaults to 20 * D, where D is the number of model dimensions.
devices – devices to use. Defaults to all available devices.
difficult_model – if True, uses more robust default settings. Defaults to False.
parameter_estimation – if True, uses more robust default settings for parameter estimation. Defaults to False.
shell_fraction – fraction of the shell to use for the slice sampler. Defaults to 0.5.
gradient_guided – if True, uses gradient guided sampling. Defaults to False.
init_efficiency_threshold – if > 0, sample uniformly first, until the acceptance efficiency drops to this threshold. 0 turns it off.
verbose – whether to log progress.
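The default sizes above are simple arithmetic; the hypothetical helper below just applies the documented formulas (s = 4, k = 0, c = 20 * D, num_live_points = c * (k + 1)). It ignores difficult_model and parameter_estimation, which the docstring says change the defaults, and the sampler itself treats num_live_points as approximate, so the value it actually uses may differ. Reading the factor (k + 1) as "each chain contributes one new point plus k phantom points" is our interpretation.

```python
from typing import Dict, Optional


def default_sizes(D: int, k: int = 0, c: Optional[int] = None, s: int = 4) -> Dict[str, int]:
    """Hypothetical helper: apply the documented defaults s = 4, k = 0, c = 20 * D."""
    if c is None:
        c = 20 * D
    # Documented default num_live_points = c * (k + 1); interpreted here as each of the
    # c chains contributing one new point plus k phantom points.
    return {"s": s, "k": k, "c": c, "num_live_points": c * (k + 1)}


print(default_sizes(D=5))       # {'s': 4, 'k': 0, 'c': 100, 'num_live_points': 100}
print(default_sizes(D=5, k=4))  # {'s': 4, 'k': 4, 'c': 100, 'num_live_points': 500}
```

Raising k multiplies the number of live points without adding chains, which is why k and c are tuned separately.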
- property nested_sampler: jaxns.nested_samplers.abc.AbstractNestedSampler
- Return type:
jaxns.nested_samplers.abc.AbstractNestedSampler
- __call__(key, term_cond=None)
Performs nested sampling with the given termination conditions.
- Parameters:
key (jaxns.internals.types.PRNGKey) – JAX random key that drives the sampler
term_cond (Optional[jaxns.nested_samplers.common.types.TerminationCondition]) – termination conditions. If not given, the defaults of TerminationCondition are used.
- Returns:
termination reason, state
- Return type:
Tuple[jaxns.internals.types.IntArray, jaxns.nested_samplers.common.types.NestedSamplerState]
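What a call with termination conditions does can be illustrated with a toy static nested sampler on a 1-D uniform prior. This is a textbook sketch, not jaxns's code: it replaces points by plain rejection sampling instead of slice sampling, and the termination codes, ToyState fields, the dlogZ criterion and the function names are all hypothetical. Like __call__, it returns a (termination reason, state) pair.

```python
import math
import random
from typing import Callable, List, NamedTuple, Tuple

# Hypothetical termination codes for this toy only; jaxns's encoding may differ.
MAX_SAMPLES_REACHED = 1
EVIDENCE_CONVERGED = 2


class ToyState(NamedTuple):
    dead_log_L: List[float]  # log-likelihood contours, in the order they were removed
    log_Z: float             # evidence estimate, including the final live points
    num_likelihood_evals: int


def logaddexp(a: float, b: float) -> float:
    """log(exp(a) + exp(b)), tolerating -inf."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def toy_nested_sampling(log_L: Callable[[float], float], num_live: int,
                        max_samples: int, dlogZ: float,
                        rng: random.Random) -> Tuple[int, ToyState]:
    """Static nested sampling on a uniform prior over [0, 1]."""
    live = [rng.random() for _ in range(num_live)]
    live_log_L = [log_L(u) for u in live]
    evals, dead, log_Z, log_X = num_live, [], -math.inf, 0.0
    log_shrink = math.log1p(-math.exp(-1.0 / num_live))  # log(1 - e^{-1/N})
    while True:
        worst = min(range(num_live), key=live_log_L.__getitem__)
        log_L_min = live_log_L[worst]
        # X_i = exp(-i / N): the dead point gets weight L_min * (X_{i-1} - X_i).
        log_Z = logaddexp(log_Z, log_L_min + log_X + log_shrink)
        log_X -= 1.0 / num_live
        dead.append(log_L_min)
        # Replace the worst point by a uniform draw above the contour (rejection sampling).
        while True:
            u = rng.random()
            l = log_L(u)
            evals += 1
            if l > log_L_min:
                live[worst], live_log_L[worst] = u, l
                break
        if len(dead) >= max_samples:
            reason = MAX_SAMPLES_REACHED
        elif logaddexp(log_Z, max(live_log_L) + log_X) - log_Z < dlogZ:
            reason = EVIDENCE_CONVERGED
        else:
            continue
        # Each remaining live point carries prior volume X / N.
        for l in live_log_L:
            log_Z = logaddexp(log_Z, l + log_X - math.log(num_live))
        return reason, ToyState(dead, log_Z, evals)


def gauss_log_L(x: float) -> float:
    return -((x - 0.5) ** 2) / (2 * 0.1 ** 2)


reason, state = toy_nested_sampling(gauss_log_L, num_live=100, max_samples=10_000,
                                    dlogZ=0.01, rng=random.Random(42))
# The exact log-evidence for this likelihood is about log(0.1 * sqrt(2 * pi)), i.e. -1.384.
```

The run stops as soon as either condition fires, and the reason code records which one did, mirroring the way __call__ reports a termination reason alongside the final state.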
- to_results(termination_reason, state, trim=True)
Convert the state to results.
Note: Requires static context.
- Parameters:
termination_reason (jaxns.internals.types.IntArray) – termination reason
state (jaxns.nested_samplers.common.types.NestedSamplerState) – state to convert
trim (bool) – if True, trims the results to the number of samples taken (requires static context).
- Returns:
results
- Return type:
jaxns.nested_samplers.common.types.NestedSamplerResults
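Part of turning a run into results is converting the ordered dead points into an evidence estimate and posterior weights. The sketch below shows the textbook estimator under a constant number N of live points (X_i = exp(-i / N)); it is not jaxns's conversion, ignores the points still alive at termination, and the function name is hypothetical. The real results object is richer than this tuple.

```python
import math
from typing import List, Tuple


def logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def dead_points_to_weights(dead_log_L: List[float],
                           num_live: int) -> Tuple[float, List[float], float]:
    """Return (log_Z, normalised posterior weights, Kish effective sample size).

    Assumes a constant number N of live points, so X_i = exp(-i / N), and ignores
    the contribution of the points still alive at termination.
    """
    log_shrink = math.log1p(-math.exp(-1.0 / num_live))  # log(1 - e^{-1/N})
    # log(L_i * (X_{i-1} - X_i)) = log L_i - (i - 1) / N + log(1 - e^{-1/N})
    log_w = [l - (i - 1) / num_live + log_shrink
             for i, l in enumerate(dead_log_L, start=1)]
    log_Z = logsumexp(log_w)
    weights = [math.exp(lw - log_Z) for lw in log_w]
    ess = 1.0 / sum(w * w for w in weights)  # Kish ESS; weights already sum to 1
    return log_Z, weights, ess


# Sanity check: with L = 1 everywhere, Z is the prior volume removed, 1 - exp(-n / N).
log_Z, weights, ess = dead_points_to_weights([0.0] * 100, num_live=10)
```

Working in log space keeps the weights finite even when log-likelihoods span hundreds of nats.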
- static trim_results(results)
Trims the results to the number of samples taken. Requires static context.
- Parameters:
results (jaxns.nested_samplers.common.types.NestedSamplerResults) – results to trim
- Returns:
trimmed results
- Return type:
jaxns.nested_samplers.common.types.NestedSamplerResults
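Why trimming needs static context can be seen with a padded buffer. Under jax.jit, array shapes must be known at trace time, so a sampler would plausibly keep fixed-size arrays (for example of length max_samples) plus a count of filled slots; slicing to that count gives an output whose shape depends on a runtime value, which is only possible outside jit. The ToyResults container and its fields are hypothetical and NumPy stands in for JAX arrays.

```python
from typing import NamedTuple

import numpy as np


class ToyResults(NamedTuple):
    """Hypothetical container whose arrays are preallocated to max_samples."""
    num_samples: int          # number of slots actually filled
    log_L: np.ndarray         # shape (max_samples,)
    log_weights: np.ndarray   # shape (max_samples,)


def trim(results: ToyResults) -> ToyResults:
    # The output shape depends on the *value* of num_samples, so it must be a concrete
    # Python int. If this function were traced by jax.jit, int() on a tracer would fail.
    n = int(results.num_samples)
    return results._replace(log_L=results.log_L[:n], log_weights=results.log_weights[:n])


max_samples = 8
padded = ToyResults(num_samples=3,
                    log_L=np.full(max_samples, -np.inf),
                    log_weights=np.full(max_samples, -np.inf))
padded.log_L[:3] = [-5.0, -2.0, -1.0]
padded.log_weights[:3] = [-3.0, -1.0, -0.5]
trimmed = trim(padded)
print(trimmed.log_L)  # [-5. -2. -1.]
```

Keeping the padded form inside compiled code and trimming once at the end avoids recompiling for every possible sample count.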