evidence_maximisation

jaxns.experimental.evidence_maximisation

Module Contents

class EvidenceMaximisation

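Evidence maximisation class that implements the E and M steps of an expectation-maximisation (EM) scheme. It iteratively computes the evidence with nested sampling (E-step) and maximises it with respect to the model parameters (M-step).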

Parameters:
  • model – The model to train.

  • ns_kwargs – The keyword arguments to pass to the nested sampler. Must include at least max_samples.

  • max_num_epochs – The maximum number of epochs to run the M-step for.

  • gtol – The parameter tolerance for the M-step. The M-step ends when all parameters change by less than gtol.

  • log_Z_ftol – The relative tolerance on the change in the evidence, in units of log_Z_uncert. Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).

  • log_Z_atol – The absolute tolerance on the change in the evidence. It sets a lower bound on the termination threshold max(log_Z_ftol * log_Z_uncert, log_Z_atol).

  • batch_size – The batch size to use for the M-step.

  • momentum – The momentum to use for the M-step.

  • termination_cond – The termination condition to use for the nested sampler.

  • verbose – Whether to print progress verbosely.
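The evidence-based termination rule described for log_Z_ftol and log_Z_atol can be sketched as a small predicate. This is a minimal illustration, not the library's code: the function name evidence_converged is hypothetical, and it assumes log_Z_uncert is the uncertainty of the most recent evidence estimate. The defaults match the attribute defaults listed below.

```python
def evidence_converged(
    log_Z_prev: float,
    log_Z: float,
    log_Z_uncert: float,
    log_Z_ftol: float = 1.0,
    log_Z_atol: float = 1e-4,
) -> bool:
    """Hypothetical helper: True if the change in log-evidence is within tolerance.

    Assumes log_Z_uncert belongs to the latest estimate log_Z.
    """
    # Relative tolerance (scaled by the evidence uncertainty), floored by the absolute one.
    tol = max(log_Z_ftol * log_Z_uncert, log_Z_atol)
    return abs(log_Z - log_Z_prev) < tol
```

For example, with log_Z_uncert = 0.3 a change of 0.2 in log_Z counts as converged, while a change of 0.5 does not; when the uncertainty is zero, log_Z_atol alone decides.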

model: jaxns.Model
ns_kwargs: Dict[str, Any]
max_num_epochs: int = 50
gtol: float = 0.01
log_Z_ftol: float = 1.0
log_Z_atol: float = 0.0001
batch_size: int = 128
momentum: float = 0.0
termination_cond: jaxns.internals.types.TerminationCondition | None
verbose: bool = False
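The gtol stopping rule for the M-step ("all parameters change by less than gtol") can be illustrated on Haiku-style parameters, which are two-level mappings of the form {module_name: {param_name: array}}. This sketch assumes the change is measured as an absolute elementwise difference; the helper name params_converged is hypothetical, and numpy stands in for JAX arrays to keep it self-contained.

```python
import numpy as np


def params_converged(old_params, new_params, gtol: float = 0.01) -> bool:
    """Hypothetical helper: True if every parameter entry changed by less than gtol.

    Assumes haiku-style nested dicts {module_name: {param_name: array}} with
    identical structure, and an absolute elementwise change criterion.
    """
    for module_name, module_params in new_params.items():
        for name, value in module_params.items():
            old_value = old_params[module_name][name]
            change = np.abs(np.asarray(value) - np.asarray(old_value))
            # A single entry moving by gtol or more means the M-step continues.
            if np.any(change >= gtol):
                return False
    return True
```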
__post_init__()
e_step(key, params, p_bar)

The E-step runs nested sampling on the model with the given parameters, producing an estimate of the evidence.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – The random number generator key.

  • params (haiku.MutableParams) – The parameters to use.

  • p_bar (tqdm.tqdm) – The progress bar.

Returns:

The nested sampling results.

Return type:

jaxns.internals.types.NestedSamplerResults
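To show what the E-step's evidence estimate is, here is the textbook nested sampling estimator, not jaxns's own implementation: with n live points, the prior volume shrinks on average as X_i = exp(-i/n), and the evidence is Z ≈ Σ L_i (X_{i-1} - X_i). The sketch works in log space for stability and, as a simplification, ignores the contribution of the remaining live points. The function name log_evidence_estimate is hypothetical.

```python
import numpy as np


def log_evidence_estimate(log_L: np.ndarray, num_live_points: int) -> float:
    """Hypothetical helper: simple nested sampling estimate of log Z.

    log_L holds the log-likelihoods of the dead points in order of increasing
    likelihood. Uses the deterministic shrinkage X_i = exp(-i / num_live_points).
    """
    n = len(log_L)
    log_X = -np.arange(n + 1) / num_live_points  # log prior volume, X_0 = 1
    # log(X_{i-1} - X_i) = log X_{i-1} + log(1 - exp(-1 / num_live_points))
    log_dX = log_X[:-1] + np.log1p(-np.exp(-1.0 / num_live_points))
    log_w = log_L + log_dX
    # Stable log-sum-exp over the weighted dead points.
    m = np.max(log_w)
    return float(m + np.log(np.sum(np.exp(log_w - m))))
```

For a constant likelihood L = e^c, the prior-volume weights sum to almost 1, so the estimate is close to c.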

m_step(key, params, ns_results, p_bar)

The M-step maximises the evidence with respect to the parameters. The data is padded to the next power of 2 so that JIT recompilation happens less often: compiled code is reused for every data size that maps to the same padded length.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – The random number generator key.

  • params (haiku.MutableParams) – The parameters to use.

  • ns_results (jaxns.internals.types.NestedSamplerResults) – The nested sampling results to use.

  • p_bar (tqdm.tqdm) – The progress bar.

Returns:

A tuple whose first element is the updated parameters.

Return type:

Tuple[haiku.MutableParams, Any]
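The power-of-2 padding used by the M-step can be sketched as follows. A jitted function recompiles for each new input shape, so rounding the sample count up to a power of 2 means, for example, that 5, 6, 7 and 8 samples all share one compiled shape of 8. Returning a boolean mask so that padded rows can be excluded (for example, given zero weight) is an assumption of this sketch, as are the helper names next_power_of_2 and pad_to_power_of_2.

```python
import numpy as np


def next_power_of_2(n: int) -> int:
    """Smallest power of 2 that is >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def pad_to_power_of_2(x: np.ndarray):
    """Hypothetical helper: zero-pad x along axis 0 to the next power of 2.

    Returns (padded, mask), where mask is True for real rows and False for padding.
    """
    n = x.shape[0]
    size = next_power_of_2(n)
    pad_width = [(0, size - n)] + [(0, 0)] * (x.ndim - 1)
    padded = np.pad(x, pad_width)  # constant mode pads with zeros by default
    mask = np.arange(size) < n
    return padded, mask
```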

train(num_steps=10, params=None)

Train the model using expectation maximisation (EM) for up to num_steps steps, stopping early if training converges.

Parameters:
  • num_steps (int) – The maximum number of EM steps to train for; training stops earlier if it converges.

  • params (Optional[haiku.MutableParams]) – The initial parameters to use. If None, then the model’s params are used.

Returns:

A tuple of the nested sampling results and the trained parameters.

Return type:

Tuple[jaxns.internals.types.NestedSamplerResults, haiku.MutableParams]
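The alternation performed by train can be sketched as a generic EM loop with early stopping on the evidence tolerance. This is a simplified illustration, not the method's actual body: the hypothetical e_step here returns only (log_Z, log_Z_uncert) instead of full nested sampling results, the hypothetical m_step takes only the parameters, and keys and progress bars are omitted.

```python
from typing import Any, Callable, Optional, Tuple


def em_train(
    e_step: Callable[[Any], Tuple[float, float]],
    m_step: Callable[[Any], Any],
    params: Any,
    num_steps: int = 10,
    log_Z_ftol: float = 1.0,
    log_Z_atol: float = 1e-4,
) -> Tuple[Optional[float], Any]:
    """Hypothetical sketch: alternate E- and M-steps for up to num_steps."""
    log_Z_prev: Optional[float] = None
    log_Z: Optional[float] = None
    for _ in range(num_steps):
        # E-step: estimate the evidence and its uncertainty at the current params.
        log_Z, log_Z_uncert = e_step(params)
        tol = max(log_Z_ftol * log_Z_uncert, log_Z_atol)
        if log_Z_prev is not None and abs(log_Z - log_Z_prev) < tol:
            break  # evidence has stopped changing: converged
        log_Z_prev = log_Z
        # M-step: update the parameters to increase the evidence.
        params = m_step(params)
    # log_Z is the most recent E-step estimate.
    return log_Z, params
```

With a toy evidence log_Z(p) = -(p - 3)^2 and an M-step that moves p halfway towards 3, the loop stops after a few steps with p close to 3.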