experimental

jaxns.experimental

Package Contents

class EvidenceMaximisation[source]

Evidence Maximisation class that implements the E and M steps. It iteratively computes the evidence and maximises it.

Parameters:
  • model – The model to train.

  • ns_kwargs – The keyword arguments to pass to the nested sampler. Needs at least max_samples.

  • max_num_epochs – The maximum number of epochs to run M-step for.

  • gtol – The parameter tolerance for the M-step. Terminate when all parameters change by less than gtol.

  • log_Z_ftol – The relative tolerance for the change in the evidence, as a function of log_Z_uncert. Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).

  • log_Z_atol – The absolute tolerance for the change in the evidence. Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol); see the sketch after this parameter list.

  • batch_size – The batch size to use for the M-step.

  • momentum – The momentum to use for the M-step.

  • termination_cond – The termination condition to use for the nested sampler.

  • verbose – Whether to print progress verbosely.
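
The evidence termination test can be written out explicitly. A minimal sketch, where log_Z_old, log_Z_new, and log_Z_uncert are hypothetical names for the previous evidence estimate, the current one, and its reported uncertainty:

    # Hypothetical names; this only illustrates the criterion stated above.
    log_Z_change = abs(log_Z_new - log_Z_old)
    converged = log_Z_change < max(log_Z_ftol * log_Z_uncert, log_Z_atol)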

model: jaxns.Model
ns_kwargs: Dict[str, Any]
max_num_epochs: int = 50
gtol: float = 0.01
log_Z_ftol: float = 1.0
log_Z_atol: float = 0.0001
batch_size: int = 128
momentum: float = 0.0
termination_cond: jaxns.internals.types.TerminationCondition | None
verbose: bool = False
__post_init__()[source]
e_step(key, params, p_bar)[source]

The E-step is just nested sampling.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – The random number generator key.

  • params (haiku.MutableParams) – The parameters to use.

  • p_bar (tqdm.tqdm) – The progress bar.

Returns:

The nested sampling results.

Return type:

jaxns.internals.types.NestedSamplerResults

m_step(key, params, ns_results, p_bar)[source]

The M-step is just evidence maximisation. We pad the data to the next power of 2 so that JIT recompilation happens less frequently.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – The random number generator key.

  • params (haiku.MutableParams) – The parameters to use.

  • ns_results (jaxns.internals.types.NestedSamplerResults) – The nested sampling results to use.

  • p_bar (tqdm.tqdm) – The progress bar.

Returns:

The updated parameters.

Return type:

Tuple[haiku.MutableParams, Any]
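
The power-of-two padding described in m_step keeps the number of distinct array shapes, and hence JIT recompilations, small. A minimal sketch of that rounding, using a hypothetical helper name:

    import numpy as np

    def next_power_of_two(n: int) -> int:
        # Smallest power of two >= n. Padding sample arrays to such sizes limits
        # the number of distinct shapes that trigger recompilation.
        return 1 if n <= 1 else int(2 ** int(np.ceil(np.log2(n))))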

train(num_steps=10, params=None)[source]

Train the model using EM for num_steps.

Parameters:
  • num_steps (int) – The number of EM steps to train for; training stops earlier if convergence is reached.

  • params (Optional[haiku.MutableParams]) – The initial parameters to use. If None, then the model’s params are used.

Returns:

The final nested sampling results and the trained parameters.

Return type:

Tuple[jaxns.internals.types.NestedSamplerResults, haiku.MutableParams]
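
A minimal usage sketch of the EM loop, assuming a jaxns Model named model has already been constructed with at least one trainable (parametrised) prior for the M-step to optimise; the ns_kwargs values shown are illustrative:

    from jaxns.experimental import EvidenceMaximisation

    # `model` is assumed to be an existing jaxns.Model with trainable parameters.
    em = EvidenceMaximisation(
        model=model,
        ns_kwargs=dict(max_samples=100000),  # ns_kwargs must at least set max_samples
    )
    # Each step runs nested sampling (E-step) and then evidence maximisation (M-step).
    ns_results, params = em.train(num_steps=10)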

class GlobalOptimisationResults[source]

Bases: NamedTuple

U_solution: jaxns.internals.types.UType
X_solution: jaxns.internals.types.XType
solution: jaxns.internals.types.LikelihoodInputType
log_L_solution: jaxns.internals.types.FloatArray
num_likelihood_evaluations: jaxns.internals.types.IntArray
num_samples: jaxns.internals.types.IntArray
termination_reason: jaxns.internals.types.IntArray
relative_spread: jaxns.internals.types.FloatArray
absolute_spread: jaxns.internals.types.FloatArray
class GlobalOptimisationTerminationCondition[source]

Bases: NamedTuple

max_likelihood_evaluations: jaxns.internals.types.IntArray | int | None
log_likelihood_contour: jaxns.internals.types.FloatArray | float | None
rtol: jaxns.internals.types.FloatArray | float | None
atol: jaxns.internals.types.FloatArray | float | None
min_efficiency: jaxns.internals.types.FloatArray | float | None
__and__(other)[source]
__or__(other)[source]
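
The __and__ and __or__ methods let termination conditions be combined with the & and | operators. A short sketch, assuming the unset fields default to None:

    from jaxns.experimental import GlobalOptimisationTerminationCondition

    # Terminate when either the likelihood-evaluation budget is exhausted
    # or the requested tolerances are reached.
    term_cond = (
        GlobalOptimisationTerminationCondition(max_likelihood_evaluations=100000)
        | GlobalOptimisationTerminationCondition(rtol=1e-5, atol=1e-8)
    )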
class GlobalOptimisationState[source]

Bases: NamedTuple

key: jaxns.internals.types.PRNGKey
samples: jaxns.internals.types.Sample
num_likelihood_evaluations: jaxns.internals.types.IntArray
num_samples: jaxns.internals.types.IntArray
class SimpleGlobalOptimisation(sampler, num_search_chains, model, num_parallel_workers=1)[source]

Simple global optimisation leveraging building blocks of nested sampling.

Parameters: