evidence_maximisation
jaxns.experimental.evidence_maximisation
Module Contents
- class EvidenceMaximisation
Evidence maximisation class that implements the E and M steps. It alternates between computing the evidence (E-step) and maximising it with respect to the parameters (M-step).
- Parameters:
model – The model to train.
ns_kwargs – The keyword arguments to pass to the nested sampler. Needs at least max_samples.
max_num_epochs – The maximum number of epochs to run M-step for.
gtol – The parameter tolerance for the M-step. End when all parameters change by less than gtol.
log_Z_ftol – The relative tolerance for the change in the evidence, expressed as a multiple of log_Z_uncert. Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).
log_Z_atol – The absolute tolerance for the change in the evidence. It acts as a floor on the threshold max(log_Z_ftol * log_Z_uncert, log_Z_atol) when log_Z_uncert is small.
batch_size – The batch size to use for the M-step.
momentum – The momentum to use for the M-step.
termination_cond – The termination condition to use for the nested sampler.
verbose – Whether to print progress verbosely.
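The way log_Z_ftol and log_Z_atol combine can be illustrated with a small sketch. The helper name `log_Z_converged`, the default values, and the use of an absolute difference between successive estimates are assumptions made for illustration; they are not taken from the jaxns source.

```python
def log_Z_converged(log_Z, prev_log_Z, log_Z_uncert,
                    log_Z_ftol=1.0, log_Z_atol=1e-4):
    """Return True when the change in log_Z is below the combined tolerance.

    Hypothetical helper: the defaults are illustrative, not the library's.
    """
    # The threshold scales with the evidence uncertainty, but never drops
    # below the absolute tolerance.
    threshold = max(log_Z_ftol * log_Z_uncert, log_Z_atol)
    # Assumption: the "change" is the absolute difference between
    # successive log_Z estimates.
    return abs(log_Z - prev_log_Z) < threshold


# A change of 0.05 with an uncertainty of 0.1 counts as converged.
small_change = log_Z_converged(-10.05, -10.0, log_Z_uncert=0.1)
# A change of 0.5 with the same uncertainty does not.
large_change = log_Z_converged(-10.5, -10.0, log_Z_uncert=0.1)
```

With a very small log_Z_uncert the relative term vanishes, and log_Z_atol alone decides termination.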
- model: jaxns.Model
- termination_cond: jaxns.internals.types.TerminationCondition | None
- e_step(key, params, p_bar)
The E-step is just nested sampling.
- Parameters:
key (jaxns.internals.types.PRNGKey) – The random number generator key.
params (haiku.MutableParams) – The parameters to use.
p_bar (tqdm.tqdm) – The progress bar.
- Returns:
The nested sampling results.
- Return type:
jaxns.internals.types.NestedSamplerResults
- m_step(key, params, ns_results, p_bar)
The M-step is just evidence maximisation. The data are padded to the next power of 2 so that JIT compilation is triggered less often.
- Parameters:
key (jaxns.internals.types.PRNGKey) – The random number generator key.
params (haiku.MutableParams) – The parameters to use.
ns_results (jaxns.internals.types.NestedSamplerResults) – The nested sampling results to use.
p_bar (tqdm.tqdm) – The progress bar.
- Returns:
The updated parameters
- Return type:
Tuple[haiku.MutableParams, Any]
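The padding mentioned for the M-step can be sketched in plain Python. JAX recompiles a jitted function whenever its input shapes change, so rounding the data length up to a power of 2 limits the number of distinct shapes it sees. The helper names, the fill value, and the validity mask below are assumptions for illustration; how jaxns itself pads and masks the data is not described in this page.

```python
def next_power_of_two(n):
    """Smallest power of 2 that is greater than or equal to n (n >= 1)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    return 1 << (n - 1).bit_length()


def pad_to_power_of_two(samples, fill):
    """Pad a sequence to a power-of-2 length and return a validity mask.

    Hypothetical helper: the mask marks which entries are real data, so
    padded entries can be ignored downstream.
    """
    target = next_power_of_two(len(samples))
    num_pad = target - len(samples)
    padded = list(samples) + [fill] * num_pad
    mask = [True] * len(samples) + [False] * num_pad
    return padded, mask


# Five samples are padded up to length 8; the last three are masked out.
padded, mask = pad_to_power_of_two([0.1, 0.2, 0.3, 0.4, 0.5], fill=0.0)
```

Any sample count from 5 to 8 maps to the same padded length of 8, so a function compiled for that shape can be reused across those counts.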
- train(num_steps=10, params=None)
Train the model using EM for num_steps.
- Parameters:
num_steps (int) – The maximum number of EM steps to run. Training stops earlier if convergence is reached.
params (Optional[haiku.MutableParams]) – The initial parameters to use. If None, then the model’s params are used.
- Returns:
The nested sampling results and the trained parameters.
- Return type:
Tuple[jaxns.internals.types.NestedSamplerResults, haiku.MutableParams]
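The alternation between the E-step and the M-step can be illustrated with a toy loop. Everything here is a stand-in: `toy_e_step` replaces nested sampling with a closed-form log-evidence, `toy_m_step` is a single gradient-ascent update, and the stopping rule is an assumed reading of the log_Z tolerances. It is a sketch of the general EM pattern, not the jaxns implementation.

```python
def toy_e_step(theta):
    """Stand-in for nested sampling: return (log_Z, log_Z_uncert)."""
    # Toy log-evidence, maximised at theta = 3, with a fixed uncertainty.
    return -(theta - 3.0) ** 2, 1e-3


def toy_m_step(theta, learning_rate=0.25):
    """Stand-in for the M-step: one gradient-ascent step on the toy log_Z."""
    grad = -2.0 * (theta - 3.0)
    return theta + learning_rate * grad


def toy_train(theta, num_steps=10, log_Z_ftol=1.0, log_Z_atol=1e-4):
    """Alternate E and M steps for num_steps, or until log_Z stops changing."""
    prev_log_Z = None
    log_Z = None
    for _ in range(num_steps):
        log_Z, log_Z_uncert = toy_e_step(theta)
        if prev_log_Z is not None:
            # Assumed stopping rule, built from the documented tolerances.
            threshold = max(log_Z_ftol * log_Z_uncert, log_Z_atol)
            if abs(log_Z - prev_log_Z) < threshold:
                break
        theta = toy_m_step(theta)
        prev_log_Z = log_Z
    return log_Z, theta


final_log_Z, final_theta = toy_train(theta=0.0)
```

Starting from theta = 0, the loop moves theta towards 3 and stops once successive log_Z values differ by less than the threshold, before num_steps is exhausted.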