jaxns.experimental.public
Module Contents
- class DefaultGlobalOptimisation(model, num_search_chains=None, num_parallel_workers=1, s=None, k=None, gradient_slice=False)
Default global optimisation class.
A global optimisation class that uses a 1-dimensional slice sampler for the sampling step, with sensible default values.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – a model to perform global optimisation on
num_search_chains (Optional[int]) – number of search chains to use. Defaults to 20 * D, where D is the number of dimensions of the model.
num_parallel_workers (int) – number of parallel workers to use. Defaults to 1. Experimental feature: if set to a value greater than 1, creates a pool of identical workers and runs them in parallel.
s (Optional[int]) – number of slices to use per dimension. Defaults to 1.
k (Optional[int]) – number of phantom samples to use. Defaults to 0.
gradient_slice (bool) – if True, use gradient information to improve the slice sampling step. Defaults to False.
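To make the sampling step more concrete, here is a minimal sketch of a single likelihood-constrained 1-dimensional slice move (stepping out followed by shrinkage, in the style of Neal's slice sampler) in plain Python. It is an illustration under stated assumptions, not the jaxns implementation: the function `slice_step` and its `width` and `max_steps` arguments are hypothetical, and the move assumes the starting point already satisfies the likelihood constraint.

```python
import math
import random
from typing import Callable, List, Sequence


def slice_step(
    x: Sequence[float],
    log_l: Callable[[Sequence[float]], float],
    log_l_min: float,
    rng: random.Random,
    width: float = 1.0,
    max_steps: int = 50,
) -> List[float]:
    """Hypothetical one-move slice sampler along a random direction.

    Assumes log_l(x) > log_l_min; returns y with log_l(y) > log_l_min.
    """
    # Draw a random unit direction in D dimensions.
    n = [rng.gauss(0.0, 1.0) for _ in x]
    norm = math.sqrt(sum(v * v for v in n))
    n = [v / norm for v in n]

    def along(t: float) -> List[float]:
        # Point at signed distance t from x along direction n.
        return [xi + t * ni for xi, ni in zip(x, n)]

    # Stepping out: randomly place an interval of length `width` around t = 0,
    # then widen each end until it leaves the constrained region
    # (or max_steps expansions have been made).
    left = -width * rng.random()
    right = left + width
    for _ in range(max_steps):
        if log_l(along(left)) <= log_l_min:
            break
        left -= width
    for _ in range(max_steps):
        if log_l(along(right)) <= log_l_min:
            break
        right += width

    # Shrinkage: sample t uniformly in [left, right]; on rejection, move the
    # offending end to t, so the interval shrinks toward t = 0 (inside the region).
    while True:
        t = rng.uniform(left, right)
        y = along(t)
        if log_l(y) > log_l_min:
            return y
        if t < 0.0:
            left = t
        else:
            right = t
```

Each call returns a new point whose log-likelihood stays above the given threshold, which is the guarantee a constrained search step of this kind is meant to provide. Reading `s` as "take s × D such moves per update" is one plausible interpretation of "slices per dimension", not a confirmed description of the library.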
- __call__(key, term_cond=None, finetune=False)
Runs the global optimisation.
- Parameters:
key (jaxns.internals.types.PRNGKey) – JAX PRNG key used for the random sampling
term_cond (Optional[jaxns.experimental.GlobalOptimisationTerminationCondition]) – termination condition
finetune (bool) – whether to apply gradient-based fine-tuning. Defaults to False because not all models have gradients.
- Returns:
results of the global optimisation
- Return type:
jaxns.experimental.GlobalOptimisationResults
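The optional `term_cond` argument decides when the run stops. The sketch below shows one way such a condition could be checked; `ToyTermCond`, `should_stop`, and both criteria (an evaluation budget, and the search chains' log-likelihoods agreeing to within `atol`) are hypothetical and are not the fields of `jaxns.experimental.GlobalOptimisationTerminationCondition`.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class ToyTermCond:
    """Hypothetical termination criteria (not the jaxns class)."""

    max_evaluations: Optional[int] = None  # stop after this many likelihood calls
    atol: Optional[float] = None  # stop once chain log-likelihoods agree within atol


def should_stop(cond: ToyTermCond, num_evaluations: int, chain_log_l: Sequence[float]) -> bool:
    # Budget criterion: too many likelihood evaluations.
    if cond.max_evaluations is not None and num_evaluations >= cond.max_evaluations:
        return True
    # Convergence criterion: all search chains sit at nearly the same log-likelihood.
    if cond.atol is not None and max(chain_log_l) - min(chain_log_l) <= cond.atol:
        return True
    return False
```

With every field left as None the check never fires, so a loop built on this sketch would need at least one criterion or an external cap.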
- summary(results, f_obj=None)
Summarises the results of a global optimisation run.
- Parameters:
results (jaxns.experimental.GlobalOptimisationResults) – results returned by __call__
f_obj (Optional[Union[str, TextIO]]) – file name or text stream to write the summary to
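Because `f_obj` accepts either a string or a text stream, an implementation has to dispatch on its type. The sketch below shows a common way to do that; `write_summary` is a hypothetical helper, and treating a string as a file path and None as stdout is an assumed convention, not confirmed behaviour of `summary`.

```python
import io
import sys
from typing import Iterable, Optional, TextIO, Union


def _emit(lines: Iterable[str], stream: TextIO) -> None:
    # Write each summary line followed by a newline.
    for line in lines:
        stream.write(line + "\n")


def write_summary(lines: Iterable[str], f_obj: Optional[Union[str, TextIO]] = None) -> None:
    """Hypothetical helper: write to stdout (None), a file path (str), or a text stream."""
    if isinstance(f_obj, str):
        # Treat a string as a path: open, write, and close the file.
        with open(f_obj, "w", encoding="utf-8") as f:
            _emit(lines, f)
    else:
        # Assumed convention: no destination means stdout.
        _emit(lines, sys.stdout if f_obj is None else f_obj)


# Example: capture the summary in memory instead of printing it.
buffer = io.StringIO()
write_summary(["line 1", "line 2"], buffer)
```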