evidence_maximisation
=====================

.. py:module:: jaxns.experimental.evidence_maximisation


.. rubric:: :code:`jaxns.experimental.evidence_maximisation`


.. rubric:: Module Contents

.. py:class:: EvidenceMaximisation

   Evidence maximisation class that implements the E and M steps. It iteratively estimates the evidence with nested sampling (E-step) and then maximises it with respect to the model parameters, using stochastic minibatching over the samples from the E-step (M-step).

   :param model: The model to train.
   :param ns_kwargs: The keyword arguments to pass to the nested sampler. Must include at least ``max_samples``.
   :param max_num_epochs: The maximum number of epochs to run the M-step for.
   :param gtol: The parameter tolerance for the M-step. The M-step ends when every parameter changes by less than ``gtol``.
   :param log_Z_ftol: The relative tolerance on the change in the log-evidence, in units of ``log_Z_uncert``. Terminate if the change in ``log_Z`` is less than ``max(log_Z_ftol * log_Z_uncert, log_Z_atol)``.
   :param log_Z_atol: The absolute tolerance on the change in the log-evidence. Terminate if the change in ``log_Z`` is less than ``max(log_Z_ftol * log_Z_uncert, log_Z_atol)``.
   :param batch_size: The batch size to use for the M-step.
   :param termination_cond: The termination condition to use for the nested sampler.
   :param verbose: Whether to print verbose progress output.

   .. py:attribute:: model
      :type: jaxns.Model

   .. py:attribute:: ns_kwargs
      :type: Optional[Dict[str, Any]]
      :value: None

   .. py:attribute:: max_num_epochs
      :type: int
      :value: 50

   .. py:attribute:: gtol
      :type: float
      :value: 0.01

   .. py:attribute:: log_Z_ftol
      :type: float
      :value: 1.0

   .. py:attribute:: log_Z_atol
      :type: float
      :value: 0.0001

   .. py:attribute:: batch_size
      :type: Optional[int]
      :value: 128

   .. py:attribute:: termination_cond
      :type: Optional[jaxns.nested_samplers.common.types.TerminationCondition]
      :value: None

   .. py:attribute:: verbose
      :type: bool
      :value: False

   .. py:method:: __post_init__()

   .. py:method:: e_step(key, params, desc)

      The E-step is just nested sampling.

      :param key: The random number generator key.
      :param params: The parameters to use.
      :param desc: The progress bar description.

      :returns: The nested sampling results.

   .. py:method:: m_step(key, params, ns_results, desc)

      The M-step is evidence maximisation: it updates the parameters to increase the evidence. The data are padded to the next power of 2 so that JIT recompilation happens less frequently.

      :param key: The random number generator key.
      :param params: The parameters to use.
      :param ns_results: The nested sampling results to use.
      :param desc: The progress bar description.

      :returns: The updated parameters.

   .. py:method:: train(num_steps = 10, params = None)

      Train the model using EM for up to ``num_steps`` steps.

      :param num_steps: The maximum number of EM steps to train for. Training may stop earlier on convergence.
      :param params: The initial parameters to use. If ``None``, the model's params are used.

      :returns: The trained parameters.
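
The following is a minimal sketch of the E/M alternation and of the stopping rule stated for ``log_Z_ftol`` and ``log_Z_atol``. It is not the jaxns implementation. The ``e_step`` and ``m_step`` callables are hypothetical stand-ins with simplified signatures: ``e_step(params)`` returns ``(log_Z, log_Z_uncert)`` and ``m_step(params)`` returns updated parameters. Where the convergence check sits within the loop is also an assumption. The tolerance defaults mirror the documented attribute defaults.

.. code-block:: python

   from typing import Callable, Optional, Tuple


   def evidence_converged(log_Z_prev: float, log_Z: float, log_Z_uncert: float,
                          log_Z_ftol: float = 1.0, log_Z_atol: float = 1e-4) -> bool:
       # Converged when the change in log Z is below max(ftol * uncert, atol).
       return abs(log_Z - log_Z_prev) < max(log_Z_ftol * log_Z_uncert, log_Z_atol)


   def em_loop(e_step: Callable[[float], Tuple[float, float]],
               m_step: Callable[[float], float],
               params: float, num_steps: int = 10,
               log_Z_ftol: float = 1.0,
               log_Z_atol: float = 1e-4) -> Tuple[float, Optional[float]]:
       # Hypothetical EM driver; e_step/m_step are illustrative, not the jaxns API.
       log_Z_prev: Optional[float] = None
       log_Z: Optional[float] = None
       for _ in range(num_steps):
           log_Z, log_Z_uncert = e_step(params)  # E-step: estimate the evidence
           if log_Z_prev is not None and evidence_converged(
                   log_Z_prev, log_Z, log_Z_uncert, log_Z_ftol, log_Z_atol):
               break  # evidence has stopped changing appreciably
           params = m_step(params)  # M-step: move params to increase log Z
           log_Z_prev = log_Z
       return params, log_Z

With a toy objective ``log_Z = -(params - 3) ** 2`` and an ``m_step`` that moves ``params`` halfway towards 3, the loop stops once successive ``log_Z`` values differ by less than ``max(log_Z_ftol * log_Z_uncert, log_Z_atol)``.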
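
``m_step`` pads the data to the next power of 2 to reduce JIT recompilation. A jitted function is retraced for every new input shape, so rounding the sample count up to a power of 2 keeps the number of distinct shapes small. The sketch below rests on several assumptions. The samples are taken to be an array with per-sample log-weights. Padded entries get log-weight ``-inf`` and a ``False`` mask, so they drop out of weighted reductions. The helper names ``next_power_of_2``, ``pad_to_power_of_2`` and ``weighted_mean`` are hypothetical and not part of jaxns.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   import numpy as np


   def next_power_of_2(n: int) -> int:
       # Smallest power of 2 that is >= n (for n >= 1).
       return 1 << (n - 1).bit_length()


   def pad_to_power_of_2(x: np.ndarray, log_w: np.ndarray):
       # Pad samples with zeros and log-weights with -inf up to a power-of-2 length.
       n = x.shape[0]
       m = next_power_of_2(n)
       pad = m - n
       x_p = np.concatenate([x, np.zeros((pad,) + x.shape[1:], x.dtype)])
       log_w_p = np.concatenate([log_w, np.full(pad, -np.inf, log_w.dtype)])
       mask = np.arange(m) < n  # True for real samples, False for padding
       return jnp.asarray(x_p), jnp.asarray(log_w_p), jnp.asarray(mask)


   @jax.jit
   def weighted_mean(x, log_w, mask):
       # Traced once per padded length rather than once per raw sample count.
       log_w = jnp.where(mask, log_w, -jnp.inf)
       w = jax.nn.softmax(log_w)  # padded entries receive zero weight
       return jnp.sum(w * jnp.where(mask, x, 0.0))

For sample counts up to ``N``, this yields only about ``log2(N) + 1`` distinct padded lengths. The cost is that each array can be up to twice its original length.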
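
The class description mentions stochastic minibatching over the E-step samples. The sketch below illustrates minibatch gradient ascent on an evidence estimate; it is not the objective or optimiser jaxns uses. It makes three assumptions that are ours, not the library's:

- each E-step sample ``x_i`` carries a log prior-mass weight ``log_w_i`` (a hypothetical name);
- the evidence is estimated as ``log Z(theta) ≈ logsumexp_i(log_w_i + log L(x_i; theta))``;
- the M-step takes gradient-ascent steps on that estimate, evaluated on random minibatches.

The toy Gaussian ``log_likelihood`` and all helper names are hypothetical.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   from jax.scipy.special import logsumexp


   def log_likelihood(x, theta):
       # Hypothetical toy likelihood: unit-variance Gaussian centred on theta.
       return -0.5 * (x - theta) ** 2


   def minibatch_log_Z(theta, x, log_w, n_total):
       # Minibatch estimate of logsumexp_i(log_w_i + log L(x_i; theta)),
       # rescaled by n_total / batch_size to approximate the full-sample sum.
       batch_size = x.shape[0]
       return logsumexp(log_w + log_likelihood(x, theta)) + jnp.log(n_total / batch_size)


   @jax.jit
   def m_step_update(theta, x, log_w, n_total, lr):
       # One stochastic gradient-ascent step on the minibatch evidence estimate.
       g = jax.grad(minibatch_log_Z)(theta, x, log_w, n_total)
       return theta + lr * g


   def minibatch_m_step(key, theta, x_all, log_w_all, batch_size=32,
                        num_iters=100, lr=0.5):
       # Fixed iteration count for simplicity; a real M-step could stop on gtol.
       n_total = x_all.shape[0]
       for _ in range(num_iters):
           key, sub = jax.random.split(key)
           idx = jax.random.choice(sub, n_total, (batch_size,), replace=False)
           theta = m_step_update(theta, x_all[idx], log_w_all[idx], n_total, lr)
       return theta

With ``x`` spread symmetrically around 2 and equal weights, the full-sample estimate is maximised at ``theta = 2``. The minibatch updates settle near that value, with some stochastic jitter from the random batches.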