
jaxns.experimental

Package Contents

class EvidenceMaximisation

Evidence maximisation class that implements the E and M steps. It iteratively computes the evidence and maximises it using stochastic minibatching over samples from the E-step.
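To make the minibatching concrete, here is a minimal NumPy sketch that shuffles E-step samples into batches of at most batch_size for one epoch. It assumes the E-step output can be treated as an array of samples with per-sample log-weights; the helper name `iterate_minibatches` is hypothetical and this is not the class's actual implementation.

```python
import numpy as np


def iterate_minibatches(samples, log_weights, batch_size, rng):
    """Yield shuffled (samples, log_weights) minibatches for one epoch.

    Hypothetical helper: E-step output is modelled as plain arrays here.
    """
    n = samples.shape[0]
    perm = rng.permutation(n)  # fresh random order every epoch
    for start in range(0, n, batch_size):
        idx = perm[start:start + batch_size]
        yield samples[idx], log_weights[idx]


rng = np.random.default_rng(0)
samples = np.arange(10.0).reshape(10, 1)  # 10 toy samples, 1 dimension
log_weights = np.zeros(10)                # equal weights for the toy example
batches = list(iterate_minibatches(samples, log_weights, batch_size=4, rng=rng))
print([b[0].shape[0] for b in batches])  # [4, 4, 2]
```

The final batch is smaller than batch_size; every sample is visited exactly once per epoch.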

Parameters:
  • model – The model to train.

  • ns_kwargs – The keyword arguments to pass to the nested sampler. Must include at least max_samples.

  • max_num_epochs – The maximum number of epochs to run the M-step for.

  • gtol – The parameter tolerance for the M-step. The M-step ends when all parameters change by less than gtol.

  • log_Z_ftol – The relative tolerance on the change in the evidence, in units of log_Z_uncert. Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).

  • log_Z_atol – The absolute tolerance on the change in the evidence. Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).

  • batch_size – The batch size to use for the M-step.

  • termination_cond – The termination condition to use for the nested sampler.

  • verbose – Whether to print progress verbosely.
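The two convergence tests described by these parameters can be sketched in plain Python as below. The helper names are hypothetical, the defaults mirror the attribute defaults of this class, and treating gtol as an absolute per-parameter change is an assumption.

```python
def log_Z_converged(log_Z_prev, log_Z, log_Z_uncert, log_Z_ftol=1.0, log_Z_atol=1e-4):
    """Evidence test: |change in log_Z| < max(log_Z_ftol * log_Z_uncert, log_Z_atol)."""
    threshold = max(log_Z_ftol * log_Z_uncert, log_Z_atol)
    return abs(log_Z - log_Z_prev) < threshold


def params_converged(old_params, new_params, gtol=0.01):
    """M-step test: every parameter changed by less than gtol (absolute change assumed)."""
    return all(abs(new - old) < gtol for old, new in zip(old_params, new_params))


print(log_Z_converged(-10.0, -10.3, log_Z_uncert=0.5))  # True: 0.3 < max(0.5, 1e-4)
print(params_converged([1.0, 2.0], [1.02, 2.0]))        # False: 0.02 >= 0.01
```

With log_Z_ftol=1.0 the evidence loop stops once the change is within one standard error of log_Z; log_Z_atol only matters when that uncertainty is tiny.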

model: jaxns.Model
ns_kwargs: Dict[str, Any] | None = None
max_num_epochs: int = 50
gtol: float = 0.01
log_Z_ftol: float = 1.0
log_Z_atol: float = 0.0001
batch_size: int | None = 128
termination_cond: jaxns.nested_samplers.common.types.TerminationCondition | None = None
verbose: bool = False
__post_init__()
e_step(key, params, desc)

The E-step runs nested sampling.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – The random number generator key.

  • params (jaxns.framework.context.MutableParams) – The parameters to use.

  • desc – The progress bar description.

Returns:

The nested sampling results.

Return type:

jaxns.nested_samplers.common.types.NestedSamplerResults

m_step(key, params, ns_results, desc)

The M-step performs evidence maximisation. The data are padded to the next power of 2 so that JIT recompilation happens less frequently.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – The random number generator key.

  • params (jaxns.framework.context.MutableParams) – The parameters to use.

  • ns_results (jaxns.nested_samplers.common.types.NestedSamplerResults) – The nested sampling results to use.

  • desc (str) – The progress bar description.

Returns:

The updated parameters.

Return type:

Tuple[jaxns.framework.context.MutableParams, Any]
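The padding trick can be sketched with NumPy as follows. Under `jax.jit`, every new input shape triggers a fresh trace and compilation, so rounding the number of samples up to a power of two bounds the number of distinct shapes. The helper names are hypothetical, and zero padding plus a boolean mask is one assumed way to keep padded rows from affecting the result.

```python
import numpy as np


def next_power_of_two(n: int) -> int:
    """Return the smallest power of two that is >= n (requires n >= 1)."""
    return 1 << (n - 1).bit_length()


def pad_to_power_of_two(x: np.ndarray):
    """Zero-pad x along axis 0 to the next power of two.

    Returns (padded, mask) where mask is True for real rows, False for padding.
    """
    n = x.shape[0]
    size = next_power_of_two(n)
    pad_width = [(0, size - n)] + [(0, 0)] * (x.ndim - 1)
    padded = np.pad(x, pad_width)  # default "constant" mode pads with zeros
    mask = np.arange(size) < n
    return padded, mask


# Sample counts 5..8 all map to the same padded shape, so a jitted
# function would only need to be compiled once for all of them.
for n in (5, 6, 7, 8):
    padded, mask = pad_to_power_of_two(np.ones((n, 3)))
    print(n, padded.shape)  # padded shape is (8, 3) every time

# Reductions must use the mask so that padding does not bias the result.
padded, mask = pad_to_power_of_two(np.arange(5.0))
masked_mean = np.sum(np.where(mask, padded, 0.0)) / np.sum(mask)
print(masked_mean)  # 2.0
```

Using `np.where` rather than boolean indexing keeps the array shape fixed, which is what a jitted JAX function would need.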

train(num_steps=10, params=None)

Trains the model using EM for up to num_steps steps.

Parameters:
  • num_steps (int) – The maximum number of EM steps to train for; training stops earlier if convergence is reached.

  • params (Optional[jaxns.framework.context.MutableParams]) – The initial parameters to use. If None, then the model’s params are used.

Returns:

The final nested sampling results and the trained parameters.

Return type:

Tuple[jaxns.nested_samplers.common.types.NestedSamplerResults, jaxns.framework.context.MutableParams]
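A generic EM loop using the log_Z termination rule from the class parameters might look like the sketch below. It is written against hypothetical `e_step`/`m_step` callables on plain Python values, not the real method signatures (which take PRNG keys and MutableParams), and the toy model at the end is purely illustrative.

```python
from typing import Any, Callable, Optional, Tuple


def em_train(
    e_step: Callable[[Any], Tuple[float, float, Any]],
    m_step: Callable[[Any, Any], Any],
    params: Any,
    num_steps: int = 10,
    log_Z_ftol: float = 1.0,
    log_Z_atol: float = 1e-4,
) -> Tuple[Any, Any]:
    """Alternate E and M steps for at most num_steps iterations.

    Hypothetical signatures:
      e_step(params) -> (log_Z, log_Z_uncert, results)
      m_step(params, results) -> new_params
    Stops early once |change in log_Z| < max(log_Z_ftol * log_Z_uncert, log_Z_atol).
    """
    log_Z_prev: Optional[float] = None
    results = None
    for _ in range(num_steps):
        # E-step: estimate the evidence at the current parameters.
        log_Z, log_Z_uncert, results = e_step(params)
        if log_Z_prev is not None:
            if abs(log_Z - log_Z_prev) < max(log_Z_ftol * log_Z_uncert, log_Z_atol):
                break  # evidence no longer changing meaningfully
        # M-step: move the parameters to increase the evidence.
        params = m_step(params, results)
        log_Z_prev = log_Z
    return results, params


# Toy problem: the "evidence" peaks at theta = 3 and each M-step moves halfway there.
def toy_e_step(theta):
    return -(theta - 3.0) ** 2, 0.01, theta


def toy_m_step(theta, results):
    return theta + 0.5 * (3.0 - theta)


results, params = em_train(toy_e_step, toy_m_step, params=0.0, num_steps=10)
print(params)  # 2.953125 (stops after the 7th E-step)
```

The return order (results, params) follows the documented return type of train.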

class GlobalOptimisationResults

Bases: NamedTuple

U_solution: jaxns.internals.types.UType
X_solution: jaxns.internals.types.XType
solution: jaxns.internals.types.LikelihoodInputType
log_L_solution: jaxns.internals.types.FloatArray
log_L_progress: jaxns.internals.types.FloatArray
num_likelihood_evaluations: jaxns.internals.types.IntArray
num_samples: jaxns.internals.types.IntArray
termination_reason: jaxns.internals.types.IntArray
relative_spread: jaxns.internals.types.FloatArray
absolute_spread: jaxns.internals.types.FloatArray
class GlobalOptimisationTerminationCondition

Bases: NamedTuple

max_likelihood_evaluations: jaxns.internals.types.IntArray | None = None
log_likelihood_contour: jaxns.internals.types.FloatArray | None = None
rtol: jaxns.internals.types.FloatArray | None = None
atol: jaxns.internals.types.FloatArray | None = None
min_efficiency: jaxns.internals.types.FloatArray | None = None
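The optional fields of GlobalOptimisationTerminationCondition suggest that several stopping criteria are combined, with unset (None) criteria ignored. The sketch below illustrates that pattern with a local stand-in NamedTuple (`TermCond`, hypothetical); the exact meaning of each field, and the definitions of spread and efficiency used here, are assumptions rather than documented behaviour.

```python
from typing import List, NamedTuple, Optional, Sequence


class TermCond(NamedTuple):
    # Hypothetical stand-in mirroring the field names above.
    max_likelihood_evaluations: Optional[int] = None
    log_likelihood_contour: Optional[float] = None
    rtol: Optional[float] = None
    atol: Optional[float] = None
    min_efficiency: Optional[float] = None


def termination_reasons(cond: TermCond, log_L: Sequence[float],
                        num_likelihood_evaluations: int, num_samples: int) -> List[str]:
    """Return the names of all criteria that are met; None criteria are skipped."""
    best, worst = max(log_L), min(log_L)
    absolute_spread = best - worst                               # assumed definition
    relative_spread = absolute_spread / max(abs(best), 1e-30)    # assumed definition
    efficiency = num_samples / max(num_likelihood_evaluations, 1)  # assumed definition
    reasons = []
    if (cond.max_likelihood_evaluations is not None
            and num_likelihood_evaluations >= cond.max_likelihood_evaluations):
        reasons.append("max_likelihood_evaluations")
    if cond.log_likelihood_contour is not None and best >= cond.log_likelihood_contour:
        reasons.append("log_likelihood_contour")
    if cond.atol is not None and absolute_spread < cond.atol:
        reasons.append("atol")
    if cond.rtol is not None and relative_spread < cond.rtol:
        reasons.append("rtol")
    if cond.min_efficiency is not None and efficiency < cond.min_efficiency:
        reasons.append("min_efficiency")
    return reasons


print(termination_reasons(TermCond(atol=0.1, max_likelihood_evaluations=1000),
                          [-1.02, -1.0, -1.05],
                          num_likelihood_evaluations=500, num_samples=100))  # ['atol']
```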
class GlobalOptimisationState

Bases: NamedTuple

key: jaxns.internals.types.PRNGKey
samples: jaxns.nested_samplers.common.types.SampleCollection
num_samples: jaxns.internals.types.IntArray
relative_spread: jaxns.internals.types.FloatArray
absolute_spread: jaxns.internals.types.FloatArray
num_likelihood_evaluations: jaxns.internals.types.IntArray
class SimpleGlobalOptimisation

Simple global optimisation leveraging building blocks of nested sampling.

sampler: jaxns.samplers.abc.AbstractSampler
num_search_chains: int
model: jaxns.framework.bases.BaseAbstractModel
shell_frac: float = 0.5
devices: jaxlib.xla_client.Device | None = None
verbose: bool = False
__post_init__()
class GlobalOptimisation

A global optimiser that uses nested sampling as its core algorithm. It can globally optimise complex models with curving degeneracies and multimodal structure, and it is highly parallelisable. Using gradient information, by setting gradient_slice=True, is recommended.

Note that the log-likelihood of the model is maximised, NOT the posterior. The prior acts as the search-space prior: it constrains the search space and gives search preference to regions of high prior probability. Thus, the prior should encode your prior belief about where the global maximum is located.
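To illustrate the difference between maximising the likelihood and maximising the posterior, the sketch below uses plain random search over candidates drawn from the prior (not the nested-sampling algorithm GlobalOptimisation uses). The prior only decides where candidates are drawn; the selected point is the arg-max of the log-likelihood alone. The Gaussian prior and likelihood are hypothetical toy choices.

```python
import numpy as np


def log_likelihood(x):
    return -0.5 * (x - 2.0) ** 2  # likelihood peaks at x = 2


def log_prior(x):
    return -0.5 * x ** 2  # standard normal prior, up to a constant


rng = np.random.default_rng(42)
x = rng.normal(size=100_000)  # candidates drawn from the prior (the search space)

x_ml = x[np.argmax(log_likelihood(x))]                  # what the optimiser targets
x_map = x[np.argmax(log_likelihood(x) + log_prior(x))]  # posterior mode, NOT targeted
print(round(x_ml, 1), round(x_map, 1))
```

With this setup x_ml lands near 2, the likelihood peak, while x_map lands near 1, the posterior mode. The prior still matters: a region it gives no mass to is never searched.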

Parameters:
  • model – The model to globally optimise over the sample space.

  • num_search_chains – The number of search chains to use.

  • s – The number of slices to use per dimension.

  • k – The number of phantom samples to use.

  • gradient_slice – If True, use gradient information to improve the search. Default True.

  • shell_frac – The fraction of the shell to discard in parallel.

  • devices – The devices to use for parallel sharded computation. Default: all available devices.

  • verbose – Whether to print verbose output. Default False.
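One way to read shell_frac is as the fraction of the search chains whose likelihood shell is discarded in a single parallel step. The sketch below implements that reading, using rejection sampling from the prior as a simple stand-in for the slice sampling GlobalOptimisation uses; the function `shrink_step` and the toy problem are hypothetical.

```python
import numpy as np


def shrink_step(points, log_L, log_likelihood, sample_prior, shell_frac, rng,
                max_tries=10_000):
    """One parallel shell update (hypothetical sketch).

    Discards the worst ceil(shell_frac * n) points at once and replaces each
    with a prior draw whose log-likelihood exceeds the new contour.
    """
    n = len(log_L)
    num_discard = max(1, int(np.ceil(shell_frac * n)))
    discard = np.argsort(log_L)[:num_discard]  # ascending order: worst first
    contour = log_L[discard].max()             # new likelihood constraint
    for i in discard:
        for _ in range(max_tries):  # rejection sampling; a point is kept if all tries fail
            x = sample_prior(rng)
            lx = log_likelihood(x)
            if lx > contour:
                points[i], log_L[i] = x, lx
                break
    return points, log_L, contour


def toy_log_likelihood(x):
    return -(x - 1.0) ** 2  # maximum at x = 1


def uniform_prior(rng):
    return rng.uniform(-5.0, 5.0)  # search space [-5, 5]


rng = np.random.default_rng(0)
points = np.array([uniform_prior(rng) for _ in range(20)])  # 20 search chains
log_L = toy_log_likelihood(points)
contours = []
for _ in range(8):
    points, log_L, contour = shrink_step(points, log_L, toy_log_likelihood,
                                         uniform_prior, shell_frac=0.5, rng=rng)
    contours.append(contour)
print(points[np.argmax(log_L)])  # best point, close to 1
```

A larger shell_frac replaces more chains per step (more parallel work, faster contraction) at the cost of a coarser shrinkage of the likelihood contour.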

model: jaxns.framework.bases.BaseAbstractModel
num_search_chains: int | None = None
s: int | None = None
k: int | None = None
gradient_slice: bool = True
shell_frac: float | None = None
devices: jaxlib.xla_client.Device | None = None
verbose: bool = False
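__post_init__()
__call__(key, term_cond=None, finetune=False)

Runs the global optimisation.

Parameters:
  • key – The random number generator key.

  • term_cond – The termination condition. Default None.

  • finetune (bool) – Whether to fine-tune the solution. Default False.

Returns:

The results of the global optimisation.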

Return type:

jaxns.experimental.GlobalOptimisationResults

DefaultGlobalOptimisation