from typing import Optional, TextIO, Union
from jaxns.experimental import SimpleGlobalOptimisation, GlobalOptimisationTerminationCondition, \
GlobalOptimisationResults
from jaxns.experimental.global_optimisation import summary
from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.types import PRNGKey
from jaxns.samplers import UniDimSliceSampler
# Names exported by ``from <module> import *``; only the default optimiser wrapper is public.
__all__ = ['DefaultGlobalOptimisation']
class DefaultGlobalOptimisation:
    """
    Default global optimisation class.

    Thin wrapper that builds a ``UniDimSliceSampler`` with fixed default settings and drives a
    ``SimpleGlobalOptimisation`` with it.
    """

    def __init__(self, model: BaseAbstractModel,
                 num_search_chains: Optional[int] = None,
                 num_parallel_workers: int = 1,
                 s: Optional[int] = None,
                 k: Optional[int] = None,
                 gradient_slice: bool = False
                 ):
        """
        A global optimisation class that uses a 1-dimensional slice sampler for the sampling step and sensible
        default values.

        Args:
            model: a model to perform global optimisation on
            num_search_chains: number of search chains to use. Defaults to 20 * D, where D is
                ``model.U_ndims``.
            num_parallel_workers: number of parallel workers to use. Defaults to 1. Experimental feature.
                If set creates a pool of identical workers and runs them in parallel.
            s: number of slices to use per dimension, i.e. the sampler performs ``model.U_ndims * s``
                slices. Defaults to 1.
            k: number of phantom samples to save (forwarded as ``num_phantom_save``). Defaults to 0.
            gradient_slice: if True, the slice sampler uses gradient information when slicing
                (forwarded as ``UniDimSliceSampler(gradient_slice=...)``).
        """
        # Resolve ``None`` placeholders to the documented defaults.
        if num_search_chains is None:
            num_search_chains = model.U_ndims * 20
        if s is None:
            s = 1
        if k is None:
            k = 0

        # Fixed sampler configuration for this "default" optimiser; only the slice count,
        # phantom count and gradient usage are user-tunable.
        sampler = UniDimSliceSampler(
            model=model,
            num_slices=model.U_ndims * int(s),
            num_phantom_save=int(k),
            midpoint_shrink=True,
            perfect=True,
            gradient_slice=gradient_slice
        )
        self._global_optimiser = SimpleGlobalOptimisation(
            sampler=sampler,
            num_search_chains=int(num_search_chains),
            model=model,
            num_parallel_workers=num_parallel_workers
        )

    def __call__(self, key: PRNGKey,
                 term_cond: Optional[GlobalOptimisationTerminationCondition] = None,
                 finetune: bool = False) -> GlobalOptimisationResults:
        """
        Runs the global optimisation.

        Args:
            key: PRNGKey
            term_cond: termination condition. Defaults to
                ``GlobalOptimisationTerminationCondition(min_efficiency=3e-2)``.
            finetune: whether to use gradient-based fine-tune. Default False because not all models have gradients.

        Returns:
            results of the global optimisation
        """
        if term_cond is None:
            term_cond = GlobalOptimisationTerminationCondition(
                min_efficiency=3e-2
            )
        termination_reason, state = self._global_optimiser._run(key, term_cond)
        results = self._global_optimiser._to_results(termination_reason, state)
        if finetune:
            # Optional gradient-based refinement of the search result.
            results = self._global_optimiser._gradient_descent(results=results)
        return results

    def summary(self, results: GlobalOptimisationResults, f_obj: Optional[Union[str, TextIO]] = None):
        """
        Produce a summary of ``results`` via ``jaxns.experimental.global_optimisation.summary``.

        Args:
            results: results returned by calling this optimiser.
            f_obj: forwarded unchanged as ``f_obj``; per the annotation, a path or text stream to
                write to. NOTE(review): behaviour for ``None`` is defined by the underlying
                ``summary`` function — confirm (presumably stdout).
        """
        # ``summary`` below is the module-level function imported from
        # jaxns.experimental.global_optimisation, not this method: class-body names are not
        # in scope inside methods, so there is no recursion here.
        summary(results, f_obj=f_obj)