evidence_maximisation

jaxns.experimental.evidence_maximisation

Module Contents

class EvidenceMaximisation

Evidence maximisation class that implements the E and M steps. It iteratively computes the evidence and maximises it using stochastic minibatching over the samples produced by the E-step.
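The stochastic minibatching over E-step samples can be pictured as shuffling the sample indices once per epoch and slicing them into batches of batch_size. This is a minimal, library-independent sketch; the function name minibatches is hypothetical, and treating batch_size=None as "one full batch" is an assumption, not documented jaxns behaviour.

```python
import random


def minibatches(num_samples, batch_size, seed=0):
    """Yield one epoch of shuffled index batches over num_samples E-step samples.

    Hypothetical helper: batch_size=None is assumed to mean a single full batch.
    """
    if batch_size is None:
        batch_size = max(num_samples, 1)
    rng = random.Random(seed)
    indices = list(range(num_samples))
    rng.shuffle(indices)  # new sample order for this epoch
    for start in range(0, num_samples, batch_size):
        # The last batch may be smaller than batch_size.
        yield indices[start:start + batch_size]
```

For example, 300 samples with batch_size=128 give batches of 128, 128 and 44 indices, and every sample appears exactly once per epoch.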

Parameters:
  • model – The model to train.

  • ns_kwargs – The keyword arguments to pass to the nested sampler. Must include at least max_samples.

  • max_num_epochs – The maximum number of epochs to run M-step for.

  • gtol – The parameter tolerance for the M-step. The M-step ends when all parameters change by less than gtol.

  • log_Z_ftol – The relative tolerance on the change in the evidence, in units of log_Z_uncert. EM terminates if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).

  • log_Z_atol – The absolute tolerance on the change in log_Z. It acts as the floor in the same termination criterion, max(log_Z_ftol * log_Z_uncert, log_Z_atol).

  • batch_size – The batch size to use for the M-step.

  • termination_cond – The termination condition to use for the nested sampler.

  • verbose – Whether to print progress verbosely.
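The two tolerance rules above can be written as small predicates. This is a sketch: the defaults mirror the documented attribute defaults, but the function names are hypothetical, and measuring the gtol change as an absolute per-parameter difference over a flat list of scalars is an assumption about how "change" is measured.

```python
def log_z_converged(prev_log_Z, log_Z, log_Z_uncert, log_Z_ftol=1.0, log_Z_atol=1e-4):
    """True if |change in log_Z| < max(log_Z_ftol * log_Z_uncert, log_Z_atol)."""
    threshold = max(log_Z_ftol * log_Z_uncert, log_Z_atol)
    return abs(log_Z - prev_log_Z) < threshold


def params_converged(old_params, new_params, gtol=0.01):
    """True if every parameter moved by less than gtol.

    Assumption: absolute change per scalar parameter; jaxns may measure it differently.
    """
    return all(abs(new - old) < gtol for old, new in zip(old_params, new_params))
```

With log_Z_uncert = 0.1 the relative term dominates (threshold 0.1); when the uncertainty is tiny, log_Z_atol keeps the threshold from collapsing to zero.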

model: jaxns.Model
ns_kwargs: Dict[str, Any] | None = None
max_num_epochs: int = 50
gtol: float = 0.01
log_Z_ftol: float = 1.0
log_Z_atol: float = 0.0001
batch_size: int | None = 128
termination_cond: jaxns.nested_samplers.common.types.TerminationCondition | None = None
verbose: bool = False
__post_init__()
e_step(key, params, desc)

The E-step runs nested sampling with the given parameters.
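To illustrate what nested sampling produces, here is the textbook evidence estimate from a sequence of dead points: each dead point's likelihood is weighted by the prior volume removed when it was discarded, using the deterministic shrinkage X_i ≈ exp(-i / n_live). This toy ignores the live-point remainder and uncertainty estimation and is not the jaxns implementation; the function names are hypothetical.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def nested_sampling_log_z(log_likelihoods, num_live):
    """Toy log-evidence from dead-point log-likelihoods (in the order they died)."""
    log_terms = []
    x_prev = 1.0  # prior volume enclosed before the first dead point
    for i, log_l in enumerate(log_likelihoods, start=1):
        x_i = math.exp(-i / num_live)  # expected shrinkage after i deaths
        # Contribution L_i * (X_{i-1} - X_i), kept in log space.
        log_terms.append(log_l + math.log(x_prev - x_i))
        x_prev = x_i
    return logsumexp(log_terms)
```

As a sanity check, a constant likelihood L = 1 with 1000 dead points and 100 live points gives Z = 1 - exp(-10), the total prior volume removed.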

Parameters:
  • key (jaxns.internals.types.PRNGKey) – The random number generator key.

  • params (jaxns.framework.context.MutableParams) – The parameters to use.

  • desc – The progress bar description.

Returns:

The nested sampling results.

Return type:

jaxns.nested_samplers.common.types.NestedSamplerResults

m_step(key, params, ns_results, desc)

The M-step maximises the evidence with respect to the parameters. The data are padded to the next power of 2 so that JIT recompilation happens less frequently.
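JIT-compiled functions are retraced for every new input shape, so bucketing the number of samples into powers of 2 means N samples produce only about log2(N) distinct shapes. A minimal sketch of that padding, assuming (hypothetically) that padded slots are excluded via a boolean mask; the helper names are not part of jaxns.

```python
def next_power_of_2(n):
    """Smallest power of 2 that is >= n (returns 1 for n <= 1)."""
    if n <= 1:
        return 1
    return 1 << (n - 1).bit_length()


def pad_to_power_of_2(values, fill=0.0):
    """Pad values to a power-of-2 length; mask marks the real (unpadded) entries."""
    n = len(values)
    size = next_power_of_2(n)
    padded = list(values) + [fill] * (size - n)
    mask = [True] * n + [False] * (size - n)  # False entries should get zero weight
    return padded, mask
```

For instance, 5, 6, 7 and 8 samples all pad to length 8, so they share one compiled shape.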

Parameters:
  • key (jaxns.internals.types.PRNGKey) – The random number generator key.

  • params (jaxns.framework.context.MutableParams) – The parameters to use.

  • ns_results (jaxns.nested_samplers.common.types.NestedSamplerResults) – The nested sampling results to use.

  • desc (str) – The progress bar description.

Returns:

A tuple whose first element is the updated parameters.

Return type:

Tuple[jaxns.framework.context.MutableParams, Any]
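One way to picture "maximising the evidence over samples from the E-step" is to hold the E-step samples and their log prior-volume weights fixed, re-evaluate only the likelihood at a candidate parameter theta, and ascend the resulting log-evidence estimate. This is an assumption-laden toy for a single scalar theta: it uses full-batch gradient ascent with a finite-difference gradient standing in for autodiff, and all names are hypothetical rather than the jaxns M-step.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_z_estimate(theta, samples, log_dx, log_likelihood):
    # E-step output stays fixed (sample locations and log prior-volume weights);
    # only the likelihood is re-evaluated at the candidate parameter theta.
    return logsumexp([ldx + log_likelihood(x, theta) for x, ldx in zip(samples, log_dx)])


def toy_m_step(theta, samples, log_dx, log_likelihood,
               lr=0.5, gtol=1e-6, max_iters=1000, eps=1e-5):
    """Gradient ascent on the fixed-sample log-evidence estimate (scalar theta)."""
    for _ in range(max_iters):
        # Central finite difference in place of autodiff (e.g. jax.grad).
        grad = (log_z_estimate(theta + eps, samples, log_dx, log_likelihood)
                - log_z_estimate(theta - eps, samples, log_dx, log_likelihood)) / (2 * eps)
        step = lr * grad
        theta += step
        if abs(step) < gtol:  # parameter change below tolerance
            break
    return theta
```

With two equally weighted samples at -0.5 and 0.5 and a unit-variance Gaussian likelihood centred on theta, the estimate is maximised at theta = 0, and the toy M-step started from theta = 2 converges there.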

train(num_steps=10, params=None)

Train the model using EM for up to num_steps steps, stopping early on convergence.

Parameters:
  • num_steps (int) – The maximum number of EM steps to run; training stops earlier if it converges.

  • params (Optional[jaxns.framework.context.MutableParams]) – The initial parameters to use. If None, then the model’s params are used.

Returns:

A tuple of the final nested sampling results and the trained parameters.

Return type:

Tuple[jaxns.nested_samplers.common.types.NestedSamplerResults, jaxns.framework.context.MutableParams]
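The overall EM loop can be sketched as alternating E and M steps until the log_Z change falls below max(log_Z_ftol * log_Z_uncert, log_Z_atol). This is a generic sketch, not the jaxns source: e_step and m_step are hypothetical callables, the PRNG key and progress-bar arguments are omitted, and a (log_Z, log_Z_uncert) tuple stands in for NestedSamplerResults.

```python
def train_em(e_step, m_step, params, num_steps=10, log_Z_ftol=1.0, log_Z_atol=1e-4):
    """Alternate E and M steps; return (results, params) like train()."""
    prev_log_Z = None
    results = None
    for _ in range(num_steps):
        results = e_step(params)  # hypothetical: returns (log_Z, log_Z_uncert)
        log_Z, log_Z_uncert = results
        if prev_log_Z is not None:
            threshold = max(log_Z_ftol * log_Z_uncert, log_Z_atol)
            if abs(log_Z - prev_log_Z) < threshold:
                break  # evidence has stopped improving meaningfully
        params = m_step(params, results)
        prev_log_Z = log_Z
    return results, params
```

Breaking before the M-step means the returned results were computed at the returned parameters, matching the (results, params) pairing in the return type above.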