jaxns.public

Module Contents

class NestedSampler

A static nested sampler that uses a 1-dimensional slice sampler for the sampling step and implements the phantom-powered algorithm. Robust defaults are provided for all parameters.

The parameters s, k, and c are defined in the paper: https://arxiv.org/abs/2312.11330
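For illustration only, the following sketch shows one 1-dimensional slice-sampling move under a hard likelihood constraint. It assumes a uniform prior on the unit hypercube and uses a shrinkage-only bracket along a random direction. It is a generic textbook variant, not jaxns's implementation; the helper names slice_move and constrained_sample are hypothetical, and chaining s * D moves per new sample is an assumption about how "slices per dimension" is applied.

```python
import math
import random
from typing import Callable, List

Point = List[float]


def slice_move(x: Point, log_l: Callable[[Point], float], log_l_min: float,
               rng: random.Random) -> Point:
    """One 1-D slice move from x that stays in [0, 1]^D and above log_l_min.

    Hypothetical helper: assumes a uniform prior on the unit cube and log_l(x) > log_l_min.
    """
    # Draw a random unit direction.
    d = [rng.gauss(0.0, 1.0) for _ in x]
    norm = math.sqrt(sum(v * v for v in d))
    d = [v / norm for v in d]
    # Initial bracket [t_lo, t_hi]: the part of the line x + t * d inside the cube.
    t_lo, t_hi = -math.inf, math.inf
    for xi, di in zip(x, d):
        if di > 0:
            t_lo, t_hi = max(t_lo, -xi / di), min(t_hi, (1.0 - xi) / di)
        elif di < 0:
            t_lo, t_hi = max(t_lo, (1.0 - xi) / di), min(t_hi, -xi / di)
    # Sample uniformly on the bracket; on rejection, shrink it towards t = 0 (the current point).
    while True:
        t = rng.uniform(t_lo, t_hi)
        y = [xi + t * di for xi, di in zip(x, d)]
        if log_l(y) > log_l_min:
            return y
        if t < 0:
            t_lo = t
        else:
            t_hi = t


def constrained_sample(x: Point, log_l: Callable[[Point], float], log_l_min: float,
                       num_slices: int, rng: random.Random) -> Point:
    """Chain num_slices slice moves (e.g. s * D of them) starting from x."""
    for _ in range(num_slices):
        x = slice_move(x, log_l, log_l_min, rng)
    return x
```

Because the bracket always contains the current point, which already satisfies the constraint, each move terminates with probability 1 for a continuous likelihood, so no stepping-out phase is needed.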

Parameters:
  • model – a model to perform nested sampling on

  • max_samples – maximum number of samples to take

  • num_live_points – approximate number of live points to use. Defaults to c * (k + 1).

  • s – number of slices to use per dimension. Defaults to 4.

  • k – number of phantom samples to use. Defaults to 0.

  • c – number of parallel Markov chains to use. Defaults to 20 * D, where D is the dimensionality of the model.

  • devices – devices to use. Defaults to all available devices.

  • difficult_model – if True, uses more robust default settings. Defaults to False.

  • parameter_estimation – if True, uses more robust default settings for parameter estimation. Defaults to False.

  • shell_fraction – fraction of the shell to use for the slice sampler. Defaults to 0.5.

  • gradient_guided – if True, uses gradient-guided sampling. Defaults to False.

  • init_efficiency_threshold – if > 0, uniform sampling is used first, until the acceptance efficiency drops to this value; 0 turns this off. Defaults to 0.1.

  • verbose – whether to log progress. Defaults to False.
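The init_efficiency_threshold rule can be pictured with the hypothetical sketch below: the worst live point is repeatedly replaced by uniform rejection sampling from a unit-hypercube prior until the acceptance efficiency falls below the threshold, at which point slice sampling would take over. Efficiency is measured here as 1 / (trials per replacement); how jaxns actually measures it is not stated above, so that choice is an assumption, as is the function name uniform_phase.

```python
import math
import random
from typing import Callable, List, Tuple

Point = List[float]


def uniform_phase(log_l: Callable[[Point], float], ndims: int, num_live: int,
                  threshold: float, rng: random.Random,
                  max_iters: int = 10_000) -> Tuple[List[Point], List[float], List[float]]:
    """Replace the worst live point by uniform rejection sampling while the
    acceptance efficiency stays at or above `threshold`.

    Returns (live points, their log-likelihoods, dead-point log-likelihoods).
    """
    live = [[rng.random() for _ in range(ndims)] for _ in range(num_live)]
    live_logl = [log_l(x) for x in live]
    dead_logl: List[float] = []
    if threshold <= 0:  # 0 turns the uniform phase off
        return live, live_logl, dead_logl
    # Efficiency 1 / trials >= threshold  <=>  trials <= 1 / threshold.
    max_trials = math.floor(1.0 / threshold)
    for _ in range(max_iters):
        worst = min(range(num_live), key=live_logl.__getitem__)
        log_l_min = live_logl[worst]
        for _ in range(max_trials):
            x = [rng.random() for _ in range(ndims)]
            x_logl = log_l(x)
            if x_logl > log_l_min:
                break
        else:
            # Efficiency fell below the threshold: hand over to the slice sampler.
            return live, live_logl, dead_logl
        dead_logl.append(log_l_min)
        live[worst], live_logl[worst] = x, x_logl
    return live, live_logl, dead_logl
```

Uniform sampling is cheap while the constrained region is still a large fraction of the prior, which is why a phase like this can precede the Markov-chain step.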

model: jaxns.framework.bases.BaseAbstractModel
max_samples: int | float | None = None
num_live_points: int | None = None
num_slices: int | None = None
s: int | float | None = None
k: int | None = None
c: int | None = None
devices: List[jaxlib.xla_client.Device] | None = None
difficult_model: bool = False
parameter_estimation: bool = False
shell_fraction: float = 0.5
gradient_guided: bool = False
init_efficiency_threshold: float = 0.1
verbose: bool = False
__post_init__()
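Several attributes above default to None while the parameter descriptions give concrete defaults. The hypothetical dataclass below sketches how those documented defaults fit together in a __post_init__ hook. It covers only the documented values: num_slices = s * D is an assumption, and the effect of an explicit num_live_points on c, or of difficult_model and parameter_estimation, is deliberately not modeled.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplerSettings:
    """Hypothetical stand-in illustrating the documented defaults."""
    num_dims: int  # D, the dimensionality of the model
    s: Optional[float] = None
    k: Optional[int] = None
    c: Optional[int] = None
    num_live_points: Optional[int] = None
    num_slices: Optional[int] = None

    def __post_init__(self) -> None:
        if self.s is None:
            self.s = 4  # documented default: 4 slices per dimension
        if self.k is None:
            self.k = 0  # documented default: no phantom samples
        if self.c is None:
            self.c = 20 * self.num_dims  # documented default: 20 * D chains
        if self.num_live_points is None:
            self.num_live_points = self.c * (self.k + 1)  # documented default
        if self.num_slices is None:
            # Assumption: total slices per new sample = slices per dimension * D.
            self.num_slices = int(self.s * self.num_dims)
```

For a 3-dimensional model this gives c = 60 and, with k = 0, 60 live points.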
property nested_sampler: jaxns.nested_samplers.abc.AbstractNestedSampler
Return type:

jaxns.nested_samplers.abc.AbstractNestedSampler

__call__(key, term_cond=None)

Performs nested sampling with the given termination conditions.

Parameters:

  • key – a JAX PRNG key.

  • term_cond – the termination conditions to use (optional).

Returns:

termination reason, state

Return type:

Tuple[jaxns.internals.types.IntArray, jaxns.nested_samplers.common.types.NestedSamplerState]
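As a rough picture of what a call computes, here is a toy nested-sampling loop that returns a (termination reason, state) pair. It assumes a 1-D uniform prior on [0, 1] and a Gaussian likelihood, so that the constrained prior can be sampled exactly rather than with slice sampling. The termination codes, ToyState, and the dlogz stopping rule are hypothetical and do not reflect jaxns's TerminationCondition or NestedSamplerState.

```python
import math
import random
from dataclasses import dataclass, field
from typing import List, Tuple

SIGMA = 0.1  # width of the toy Gaussian likelihood centred at 0.5

# Hypothetical termination codes (not jaxns's encoding).
MAX_SAMPLES_REACHED = 1
EVIDENCE_CONVERGED = 2


@dataclass
class ToyState:
    dead_logl: List[float] = field(default_factory=list)
    log_z: float = -math.inf


def log_l(x: float) -> float:
    return -0.5 * ((x - 0.5) / SIGMA) ** 2


def sample_above(log_l_min: float, rng: random.Random) -> float:
    # {x in [0, 1] : log_l(x) > log_l_min} is an interval here, so sample it exactly.
    half_width = min(0.5, SIGMA * math.sqrt(-2.0 * log_l_min))
    return rng.uniform(0.5 - half_width, 0.5 + half_width)


def logaddexp(a: float, b: float) -> float:
    hi, lo = max(a, b), min(a, b)
    if lo == -math.inf:
        return hi
    return hi + math.log1p(math.exp(lo - hi))


def run(num_live: int, max_samples: int, dlogz: float,
        rng: random.Random) -> Tuple[int, ToyState]:
    live = [rng.random() for _ in range(num_live)]
    state = ToyState()
    # dX_i = X_i - X_{i+1} = exp(-i / n) * (1 - exp(-1 / n)).
    log_shrink = math.log1p(-math.exp(-1.0 / num_live))
    while True:
        i = len(state.dead_logl)
        if i >= max_samples:
            return MAX_SAMPLES_REACHED, state
        live_logl = [log_l(x) for x in live]
        worst = min(range(num_live), key=live_logl.__getitem__)
        log_l_min = live_logl[worst]
        # The dead point contributes L_min * dX_i to the evidence.
        state.log_z = logaddexp(state.log_z, log_l_min - i / num_live + log_shrink)
        state.dead_logl.append(log_l_min)
        live[worst] = sample_above(log_l_min, rng)
        # Remaining evidence is bounded by max(L_live) * X_{i+1}.
        log_rest = max(log_l(x) for x in live) - (i + 1) / num_live
        if logaddexp(state.log_z, log_rest) - state.log_z < dlogz:
            return EVIDENCE_CONVERGED, state
```

Since almost all of the Gaussian mass lies inside [0, 1], the true evidence is close to SIGMA * sqrt(2 * pi), which gives a simple sanity check on the estimate.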

to_results(termination_reason, state, trim=True)

Convert the state to results.

Note: requires a static context, i.e. it cannot be called inside a function traced by jax.jit.

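Parameters:

  • termination_reason – the termination reason returned by __call__.

  • state – the NestedSamplerState returned by __call__.

  • trim – if True, trims the results to the number of samples taken. Defaults to True.

Returns: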

results

Return type:

jaxns.nested_samplers.common.types.NestedSamplerResults
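Conceptually, converting a run into results means giving each sample a prior-volume weight and normalizing. The minimal sketch below assumes the expected shrinkage X_i = exp(-i / n) for dead points and an equal split of the final volume among the remaining live points; to_weights is hypothetical and does not reproduce the fields of NestedSamplerResults. The effective sample size uses the Kish formula 1 / sum(w_i^2).

```python
import math
from typing import List, Tuple


def logsumexp(values: List[float]) -> float:
    m = max(values)
    if m == -math.inf:
        return m
    return m + math.log(sum(math.exp(v - m) for v in values))


def to_weights(dead_logl: List[float], live_logl: List[float],
               num_live: int) -> Tuple[float, List[float], float]:
    """Return (log evidence, normalized posterior weights, Kish effective sample size)."""
    log_shrink = math.log1p(-math.exp(-1.0 / num_live))
    # Dead point i covers dX_i = exp(-i / n) - exp(-(i + 1) / n).
    log_w = [logl - i / num_live + log_shrink for i, logl in enumerate(dead_logl)]
    if live_logl:
        # The final volume X_N = exp(-N / n) is shared equally by the remaining live points.
        log_x_final = -len(dead_logl) / num_live
        log_w += [logl + log_x_final - math.log(len(live_logl)) for logl in live_logl]
    log_z = logsumexp(log_w)
    weights = [math.exp(lw - log_z) for lw in log_w]
    ess = 1.0 / sum(w * w for w in weights)
    return log_z, weights, ess
```

With a constant likelihood the weights reduce to the prior volumes, which sum to 1, so the log evidence is 0.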

static trim_results(results)

Trims the results to the number of samples taken. Requires static context.

Parameters:

results (jaxns.nested_samplers.common.types.NestedSamplerResults) – results to trim

Returns:

trimmed results

Return type:

jaxns.nested_samplers.common.types.NestedSamplerResults
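One way to see why trimming needs a static context: under JAX, arrays carried through a compiled loop must have fixed shapes, so per-sample outputs are presumably preallocated to max_samples, with only the first num_samples entries meaningful. Slicing to num_samples produces a shape that depends on a runtime value, which cannot be traced under jax.jit. The sketch below illustrates this with NumPy under that assumption; ToyResults, its fields, and trim are hypothetical and are not the fields of NestedSamplerResults.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyResults:
    num_samples: int      # number of samples actually taken
    log_l: np.ndarray     # shape [max_samples]; entries past num_samples are padding
    samples: np.ndarray   # shape [max_samples, D]


def trim(results: ToyResults) -> ToyResults:
    # The output shape depends on the runtime value num_samples, hence "static context".
    n = int(results.num_samples)
    return replace(results, log_l=results.log_l[:n], samples=results.samples[:n])
```

The untrimmed buffers keep their fixed shape, so the same compiled sampler can be reused across runs that stop after different numbers of samples.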