experimental
======================

.. py:module:: jaxns.experimental


.. rubric:: :code:`jaxns.experimental`


.. rubric:: Subpackages

.. toctree::
   :titlesonly:
   :maxdepth: 1

   solvers/index.rst


.. rubric:: Submodules

.. toctree::
   :titlesonly:
   :maxdepth: 1

   evidence_maximisation/index.rst
   global_optimisation/index.rst
   public/index.rst


.. rubric:: Package Contents

.. py:class:: EvidenceMaximisation

   Evidence maximisation class that implements the E and M steps. It iteratively computes the evidence and maximises it using stochastic minibatching over samples from the E-step.

   :param model: The model to train.
   :param ns_kwargs: The keyword arguments to pass to the nested sampler. Needs at least `max_samples`.
   :param max_num_epochs: The maximum number of epochs to run the M-step for.
   :param gtol: The parameter tolerance for the M-step. The M-step ends when all parameters change by less than gtol.
   :param log_Z_ftol: The tolerance for the change in the evidence, as a multiple of log_Z_uncert. Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).
   :param log_Z_atol: The absolute tolerance for the change in the evidence. Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).
   :param batch_size: The batch size to use for the M-step.
   :param termination_cond: The termination condition to use for the nested sampler.
   :param verbose: Whether to print progress verbosely.

   .. py:attribute:: model
      :type: jaxns.Model

   .. py:attribute:: ns_kwargs
      :type: Optional[Dict[str, Any]]
      :value: None

   .. py:attribute:: max_num_epochs
      :type: int
      :value: 50

   .. py:attribute:: gtol
      :type: float
      :value: 0.01

   .. py:attribute:: log_Z_ftol
      :type: float
      :value: 1.0

   .. py:attribute:: log_Z_atol
      :type: float
      :value: 0.0001

   .. py:attribute:: batch_size
      :type: Optional[int]
      :value: 128

   .. py:attribute:: termination_cond
      :type: Optional[jaxns.nested_samplers.common.types.TerminationCondition]
      :value: None

   .. py:attribute:: verbose
      :type: bool
      :value: False

   .. py:method:: __post_init__()

   .. py:method:: e_step(key, params, desc)

      The E-step is nested sampling.

      :param key: The random number generator key.
      :param params: The parameters to use.
      :param desc: The progress bar description.

      :returns: The nested sampling results.

   .. py:method:: m_step(key, params, ns_results, desc)

      The M-step is evidence maximisation. The data are padded to the next power of 2 so that JIT compilation happens less frequently.

      :param key: The random number generator key.
      :param params: The parameters to use.
      :param ns_results: The nested sampling results to use.
      :param desc: The progress bar description.

      :returns: The updated parameters.

   .. py:method:: train(num_steps = 10, params = None)

      Train the model using EM for at most num_steps steps, stopping earlier on convergence.

      :param num_steps: The maximum number of EM steps to train for; training stops earlier if it converges.
      :param params: The initial parameters to use. If None, the model's params are used.

      :returns: The trained parameters.

.. py:class:: GlobalOptimisationResults

   Bases: :py:obj:`NamedTuple`

   .. py:attribute:: U_solution
      :type: jaxns.internals.types.UType

   .. py:attribute:: X_solution
      :type: jaxns.internals.types.XType

   .. py:attribute:: solution
      :type: jaxns.internals.types.LikelihoodInputType

   .. py:attribute:: log_L_solution
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: log_L_progress
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: num_likelihood_evaluations
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: num_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: termination_reason
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: relative_spread
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: absolute_spread
      :type: jaxns.internals.types.FloatArray

.. py:class:: GlobalOptimisationTerminationCondition

   Bases: :py:obj:`NamedTuple`

   .. py:attribute:: max_likelihood_evaluations
      :type: Optional[jaxns.internals.types.IntArray]
      :value: None

   .. py:attribute:: log_likelihood_contour
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: rtol
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: atol
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

   .. py:attribute:: min_efficiency
      :type: Optional[jaxns.internals.types.FloatArray]
      :value: None

.. py:class:: GlobalOptimisationState

   Bases: :py:obj:`NamedTuple`

   .. py:attribute:: key
      :type: jaxns.internals.types.PRNGKey

   .. py:attribute:: samples
      :type: jaxns.nested_samplers.common.types.SampleCollection

   .. py:attribute:: num_samples
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: relative_spread
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: absolute_spread
      :type: jaxns.internals.types.FloatArray

   .. py:attribute:: num_likelihood_evaluations
      :type: jaxns.internals.types.IntArray

.. py:class:: SimpleGlobalOptimisation

   Simple global optimisation leveraging the building blocks of nested sampling.

   .. py:attribute:: sampler
      :type: jaxns.samplers.abc.AbstractSampler

   .. py:attribute:: num_search_chains
      :type: int

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: shell_frac
      :type: float
      :value: 0.5

   .. py:attribute:: devices
      :type: Optional[jaxlib.xla_client.Device]
      :value: None

   .. py:attribute:: verbose
      :type: bool
      :value: False

   .. py:method:: __post_init__()

.. py:class:: GlobalOptimisation

   A global optimiser that uses nested sampling as its core algorithm. It is designed to globally optimise complex models with curving degeneracies and multimodal structure, and it is highly parallelisable. Using gradient information (``gradient_slice=True``, the default) is recommended.

   Note that the log-likelihood of the model is maximised, NOT the posterior. The prior acts as a search-space prior: it constrains the search space and gives search preference to regions of high prior probability. The prior should therefore encode your belief about where the global maximum is located.
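
   The distinction in the note above can be seen in a toy example. The sketch below uses plain random search, which is a stand-in and not the nested-sampling algorithm this class implements, with a hypothetical one-dimensional likelihood peaked at x = 3 and a hypothetical N(0, 2^2) prior. The prior only decides where candidates are drawn; the quantity maximised is log L alone. Maximising log L + log prior instead would return the posterior mode.

   .. code-block:: python

      import numpy as np


      def log_likelihood(x):
          # Hypothetical likelihood, peaked at x = 3.
          return -0.5 * (x - 3.0) ** 2


      def log_prior(x):
          # Hypothetical N(0, 2^2) prior (up to an additive constant), peaked at x = 0.
          return -0.5 * (x / 2.0) ** 2


      def toy_global_search(seed=0, num_candidates=10_000):
          rng = np.random.default_rng(seed)
          # The prior shapes the search: candidates are drawn from N(0, 2^2).
          candidates = rng.normal(0.0, 2.0, size=num_candidates)
          # The objective is log L alone, NOT the log posterior.
          best = candidates[np.argmax(log_likelihood(candidates))]
          # For contrast only: maximising log L + log prior gives the MAP point.
          map_point = candidates[np.argmax(log_likelihood(candidates) + log_prior(candidates))]
          return best, map_point

   Here ``best`` lands near x = 3, the likelihood peak, while ``map_point`` lands near x = 2.4, because the prior pulls the posterior mode towards 0.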

   :param model: The model to globally optimise over its sample space.
   :param num_search_chains: The number of search chains to use.
   :param s: The number of slices to use per dimension.
   :param k: The number of phantom samples to use.
   :param gradient_slice: If True, use gradient information to improve the search. Default True.
   :param shell_frac: The fraction of the shell to discard in parallel.
   :param devices: The devices to use for parallel sharded computation. Default: all available devices.
   :param verbose: Whether to print verbose output. Default False.

   .. py:attribute:: model
      :type: jaxns.framework.bases.BaseAbstractModel

   .. py:attribute:: num_search_chains
      :type: Optional[int]
      :value: None

   .. py:attribute:: s
      :type: Optional[int]
      :value: None

   .. py:attribute:: k
      :type: Optional[int]
      :value: None

   .. py:attribute:: gradient_slice
      :type: bool
      :value: True

   .. py:attribute:: shell_frac
      :type: Optional[float]
      :value: None

   .. py:attribute:: devices
      :type: Optional[jaxlib.xla_client.Device]
      :value: None

   .. py:attribute:: verbose
      :type: bool
      :value: False

   .. py:method:: __post_init__()

   .. py:method:: __call__(key, term_cond = None, finetune = False)

      Runs the global optimisation.

      :param key: The PRNG key.
      :param term_cond: The termination condition.
      :param finetune: Whether to use gradient-based fine-tuning. Default False, because not all models have gradients.

      :returns: The results of the global optimisation.

.. py:data:: DefaultGlobalOptimisation
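
The stopping rules stated for :py:class:`EvidenceMaximisation`, together with the power-of-2 padding described in its ``m_step``, can be sketched in plain Python. This is a minimal illustration written from those docstrings, not the library's implementation. The helper names ``evidence_converged``, ``params_converged``, ``next_power_of_2`` and ``pad_to_power_of_2``, the dictionary-of-floats parameter format, and the use of absolute (rather than relative) parameter changes are all hypothetical.

.. code-block:: python

   def evidence_converged(log_Z_prev, log_Z, log_Z_uncert, log_Z_ftol=1.0, log_Z_atol=1e-4):
       # Hypothetical helper: converged when the change in log_Z is below
       # max(log_Z_ftol * log_Z_uncert, log_Z_atol), as the docstring states.
       threshold = max(log_Z_ftol * log_Z_uncert, log_Z_atol)
       return abs(log_Z - log_Z_prev) < threshold


   def params_converged(old_params, new_params, gtol=0.01):
       # Hypothetical helper: the M-step ends when *every* parameter
       # changes by less than gtol (absolute change assumed here).
       return all(abs(new_params[k] - old_params[k]) < gtol for k in old_params)


   def next_power_of_2(n):
       # Smallest power of 2 that is >= n, for n >= 1.
       return 1 << (n - 1).bit_length()


   def pad_to_power_of_2(values, fill=0.0):
       # Hypothetical padding of a list of samples up to the next power of 2,
       # so only a logarithmic number of distinct lengths ever occurs.
       size = next_power_of_2(len(values))
       return values + [fill] * (size - len(values))

For example, with the defaults ``log_Z_ftol=1.0`` and ``log_Z_atol=1e-4``, a change of 0.3 in log_Z with ``log_Z_uncert=0.5`` counts as converged, because 0.3 < max(1.0 * 0.5, 1e-4) = 0.5. Padding to a power of 2 wastes at most half of each padded batch, but it limits recompilation, since ``jax.jit`` retraces for every new input shape.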