# Source code for jaxns.experimental.global_optimisation

import io
import logging
from typing import NamedTuple, Optional, Union, TextIO, Tuple, List

import jax.nn
import jax.numpy as jnp
import numpy as np
from jax import lax, random, pmap, tree_map
from jax._src.lax import parallel
from jax._src.scipy.special import logit
from jaxopt import NonlinearCG

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.maps import remove_chunk_dim
from jaxns.internals.types import PRNGKey, StaticStandardNestedSamplerState, BoolArray, StaticStandardSampleCollection, \
    int_type, Sample, IntArray, UType, FloatArray, LikelihoodInputType, XType
from jaxns.nested_sampler.standard_static import draw_uniform_samples, _inter_sync_shrinkage_process, \
    create_init_termination_register
from jaxns.samplers.bases import BaseAbstractSampler
from jaxns.utils import _bit_mask

logger = logging.getLogger('jaxns')

# Public API of this module.
__all__ = [
    'GlobalOptimisationResults',
    'GlobalOptimisationTerminationCondition',
    'GlobalOptimisationState',
    'SimpleGlobalOptimisation'
]


class GlobalOptimisationState(NamedTuple):
    """
    Running state of a global optimisation: the current search-chain samples
    plus cumulative counters, threaded through the sampling loop.
    """
    key: PRNGKey  # PRNG key driving subsequent sampling
    samples: Sample  # current set of search-chain samples
    num_likelihood_evaluations: IntArray  # cumulative number of likelihood evaluations
    num_samples: IntArray  # cumulative number of samples drawn
class GlobalOptimisationResults(NamedTuple):
    """
    Final results of a global optimisation run: the best point found (in
    unit-cube, transformed, and likelihood-input form) with run statistics.
    """
    U_solution: UType  # best point in the unit cube (U-space)
    X_solution: XType  # best point transformed to parameter space (X-space)
    solution: LikelihoodInputType  # best point prepared as likelihood input
    log_L_solution: FloatArray  # log-likelihood at the best point
    num_likelihood_evaluations: IntArray  # total likelihood evaluations used
    num_samples: IntArray  # total samples drawn
    termination_reason: IntArray  # bit-field of termination flags (see determine_termination)
    relative_spread: FloatArray  # relative spread of log-likelihood over final samples
    absolute_spread: FloatArray  # absolute spread of log-likelihood over final samples
class GlobalOptimisationTerminationCondition(NamedTuple):
    """
    Termination conditions for global optimisation. Any field left as None is
    ignored; conditions can be combined with `&` (conjunction) and `|`
    (disjunction).
    """
    max_likelihood_evaluations: Optional[Union[IntArray, int]] = None  # jnp.asarray(jnp.iinfo(int_type).max, int_type)
    log_likelihood_contour: Optional[Union[FloatArray, float]] = None  # jnp.asarray(jnp.finfo(float_type).max, float_type)
    rtol: Optional[Union[FloatArray, float]] = None  # jnp.asarray(0., float_type)
    atol: Optional[Union[FloatArray, float]] = None  # jnp.asarray(0., float_type)
    min_efficiency: Optional[Union[FloatArray, float]] = None  # jnp.asarray(0., float_type)

    def __and__(self, other):
        # Combine with another condition: both must hold.
        return TerminationConditionConjunction(conds=[self, other])

    def __or__(self, other):
        # Combine with another condition: either may hold.
        return TerminationConditionDisjunction(conds=[self, other])
class TerminationConditionConjunction(NamedTuple):
    # Logical AND over a collection of termination conditions.
    conds: List[Union[
        'TerminationConditionDisjunction', 'TerminationConditionConjunction', GlobalOptimisationTerminationCondition]]


class TerminationConditionDisjunction(NamedTuple):
    # Logical OR over a collection of termination conditions.
    conds: List[Union[
        'TerminationConditionDisjunction', TerminationConditionConjunction, GlobalOptimisationTerminationCondition]]


def determine_termination(term_cond: GlobalOptimisationTerminationCondition,
                          state: GlobalOptimisationState) -> Tuple[BoolArray, IntArray]:
    """
    Determine if termination should happen. Termination Flags are bits:
        0-bit -> 1: used maximum allowed number of likelihood evaluations
        1-bit -> 2: reached goal log-likelihood contour
        2-bit -> 4: relative spread of log-likelihood values below threshold
        3-bit -> 8: absolute spread of log-likelihood values below threshold
        4-bit -> 16: efficiency below threshold
        5-bit -> 32: on a plateau (possibly local minimum, or due to numerical issues)

    Multiple flags are summed together

    Args:
        term_cond: termination condition
        state: global optimisation state

    Returns:
        boolean done signal, and termination reason
    """
    termination_reason = jnp.asarray(0, int_type)
    done = jnp.asarray(False, jnp.bool_)

    def _set_done_bit(bit_done, bit_reason, done, termination_reason):
        # OR `bit_done` into `done` and add 2**bit_reason to the reason bit-field.
        if bit_done.size > 1:
            raise RuntimeError("bit_done must be a scalar.")
        done = jnp.bitwise_or(bit_done, done)
        termination_reason += jnp.where(bit_done,
                                        jnp.asarray(2 ** bit_reason, int_type),
                                        jnp.asarray(0, int_type))
        return done, termination_reason

    # Composite conditions recurse into their children and fold the results.
    if isinstance(term_cond, TerminationConditionConjunction):
        for c in term_cond.conds:
            _done, _reason = determine_termination(term_cond=c, state=state)
            done = jnp.bitwise_and(_done, done)
            termination_reason = jnp.bitwise_and(_reason, termination_reason)
        return done, termination_reason

    if isinstance(term_cond, TerminationConditionDisjunction):
        for c in term_cond.conds:
            _done, _reason = determine_termination(term_cond=c, state=state)
            done = jnp.bitwise_or(_done, done)
            termination_reason = jnp.bitwise_or(_reason, termination_reason)
        return done, termination_reason

    if term_cond.max_likelihood_evaluations is not None:
        # used all points
        reached_max_samples = state.num_likelihood_evaluations >= term_cond.max_likelihood_evaluations
        done, termination_reason = _set_done_bit(reached_max_samples, 0,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.log_likelihood_contour is not None:
        # reached goal log-likelihood contour
        reached_log_L_contour = jnp.max(state.samples.log_L) >= term_cond.log_likelihood_contour
        done, termination_reason = _set_done_bit(reached_log_L_contour, 1,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.rtol is not None:
        # relative spread of log-likelihood values below threshold
        max_log_L = jnp.max(state.samples.log_L)
        min_log_L = jnp.min(state.samples.log_L)
        diff_log_L = jnp.abs(max_log_L - min_log_L)
        reached_rtol = diff_log_L <= 0.5 * term_cond.rtol * jnp.abs(max_log_L + min_log_L)
        done, termination_reason = _set_done_bit(reached_rtol, 2,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.atol is not None:
        # absolute spread of log-likelihood values below threshold
        max_log_L = jnp.max(state.samples.log_L)
        min_log_L = jnp.min(state.samples.log_L)
        diff_log_L = jnp.abs(max_log_L - min_log_L)
        reached_atol = diff_log_L <= term_cond.atol
        done, termination_reason = _set_done_bit(reached_atol, 3,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.min_efficiency is not None:
        # efficiency below threshold
        efficiency = state.samples.log_L.shape[0] / jnp.sum(state.samples.num_likelihood_evaluations)
        reached_min_efficiency = efficiency <= term_cond.min_efficiency
        done, termination_reason = _set_done_bit(reached_min_efficiency, 4,
                                                 done=done, termination_reason=termination_reason)

    # on plateau (always checked: all samples at identical log-likelihood)
    on_plateau = jnp.max(state.samples.log_L) == jnp.min(state.samples.log_L)
    done, termination_reason = _set_done_bit(on_plateau, 5,
                                             done=done, termination_reason=termination_reason)

    return done, termination_reason


def gradient_based_optimisation(model: BaseAbstractModel, init_U_point: UType) -> Tuple[UType, FloatArray, IntArray]:
    """
    Polish a unit-cube point with nonlinear conjugate-gradient minimisation of the
    negative log-likelihood, optimising in unconstrained (logit) space.

    Args:
        model: the model supplying the log-likelihood
        init_U_point: initial point in the unit cube (U-space)

    Returns:
        optimised U-space point, its log-likelihood, and the number of function evaluations
    """

    def loss(U_unconstrained: UType):
        # Map unconstrained space back to the unit cube before evaluating.
        U = jax.nn.sigmoid(U_unconstrained)
        return -model.log_prob_likelihood(U, allow_nan=False)

    solver = NonlinearCG(
        fun=loss,
        jit=True,
        unroll=False,
        verbose=False
    )

    results = solver.run(init_params=logit(init_U_point))
    # Negate back: solver minimised -log L.
    return jax.nn.sigmoid(results.params), -results.state.value, results.state.num_fun_eval


def _single_thread_global_optimisation(init_state: GlobalOptimisationState,
                                       termination_cond: GlobalOptimisationTerminationCondition,
                                       sampler: BaseAbstractSampler) -> Tuple[GlobalOptimisationState, IntArray]:
    """
    Runs a single thread of global optimisation. Sequentially samples until termination condition is met,
    replacing the worst sample with a new one in groups of num_samples.

    Args:
        init_state: initial state of the global optimisation
        termination_cond: termination condition
        sampler: sampler

    Returns:
        final state of the global optimisation and termination reason
    """

    class CarryType(NamedTuple):
        state: GlobalOptimisationState

    def cond(carry: CarryType) -> BoolArray:
        # Keep looping while the termination condition is NOT met.
        done, _ = determine_termination(term_cond=termination_cond, state=carry.state)
        return jnp.bitwise_not(done)

    def body(carry: CarryType) -> CarryType:
        # Devices are independent, i.e. expect no communication between them in sampler.
        key, sample_key = random.split(carry.state.key, 2)
        num_samples = carry.state.samples.log_L.shape[0]
        # Wrap the current samples in a nested-sampler state so the shrinkage
        # machinery from standard static nested sampling can be reused.
        fake_state = StaticStandardNestedSamplerState(
            key=sample_key,
            next_sample_idx=jnp.asarray(0, int_type),
            sample_collection=StaticStandardSampleCollection(
                sender_node_idx=jnp.zeros((num_samples,), int_type),
                log_L=carry.state.samples.log_L,
                U_samples=carry.state.samples.U_sample,
                num_likelihood_evaluations=carry.state.samples.num_likelihood_evaluations,
                phantom=jnp.zeros((num_samples,), jnp.bool_)
            ),
            front_idx=jnp.arange(num_samples, dtype=int_type)
        )
        k = sampler.num_phantom()
        if k > 0:
            # Make room for k phantom samples per live sample.
            def _repeat(x):
                return jnp.repeat(x, (k + 1), axis=0)

            fake_state = fake_state._replace(
                sample_collection=tree_map(_repeat, fake_state.sample_collection)
            )
        fake_state, fake_termination_register = _inter_sync_shrinkage_process(
            init_state=fake_state,
            sampler=sampler,
            num_samples=num_samples * (1 + k),
            init_termination_register=create_init_termination_register()
        )
        num_likelihood_evaluations = carry.state.num_likelihood_evaluations + jnp.sum(
            fake_state.sample_collection.num_likelihood_evaluations)
        if k > 0:
            # Choose the maximum likelihood sample from the k+1 samples (requires reshape first to unstack)
            choose_idx = jnp.argmax(
                jnp.reshape(fake_state.sample_collection.log_L, ((k + 1), num_samples)),
                axis=0
            )

            def _select(x):
                x = jnp.reshape(x, ((k + 1), num_samples) + x.shape[1:])  # [k+1, N, ...]
                return x[choose_idx, jnp.arange(num_samples)]  # [N, ...]

            fake_state = fake_state._replace(
                sample_collection=tree_map(
                    _select,
                    fake_state.sample_collection
                )
            )
        samples = Sample(
            U_sample=fake_state.sample_collection.U_samples,
            log_L_constraint=fake_state.sample_collection.log_L,
            log_L=fake_state.sample_collection.log_L,
            num_likelihood_evaluations=fake_state.sample_collection.num_likelihood_evaluations
        )
        state = GlobalOptimisationState(
            key=key,
            samples=samples,
            num_likelihood_evaluations=num_likelihood_evaluations,
            num_samples=carry.state.num_samples + jnp.asarray(num_samples, int_type)
        )
        return CarryType(state=state)

    init_carry_state = CarryType(
        state=init_state
    )

    carry_state: CarryType = lax.while_loop(
        cond_fun=cond,
        body_fun=body,
        init_val=init_carry_state
    )

    # Recompute the reason for the final state (while_loop only returns the state).
    _, termination_reason = determine_termination(term_cond=termination_cond, state=carry_state.state)

    return carry_state.state, termination_reason


def create_init_state(key: PRNGKey, num_search_chains: int,
                      model: BaseAbstractModel) -> GlobalOptimisationState:
    """
    Creates the initial state of the global optimisation.

    Args:
        key: PRNGKey
        num_search_chains: number of search chains
        model: model

    Returns:
        initial state of the global optimisation
    """
    key, sample_key = random.split(key, 2)
    init_samples = draw_uniform_samples(key=sample_key, num_live_points=num_search_chains, model=model)
    return GlobalOptimisationState(
        key=key,
        samples=init_samples,
        num_likelihood_evaluations=jnp.sum(init_samples.num_likelihood_evaluations),
        num_samples=jnp.asarray(num_search_chains, int_type)
    )
class SimpleGlobalOptimisation:
    """
    Simple global optimisation leveraging building blocks of nested sampling.
    """

    def __init__(self, sampler: BaseAbstractSampler, num_search_chains: int,
                 model: BaseAbstractModel, num_parallel_workers: int = 1):
        """
        Initialises the global optimiser.

        Args:
            sampler: sampler used to draw new samples above a likelihood contour
            num_search_chains: number of search chains, must be >= 1
            model: the model to optimise
            num_parallel_workers: number of pmap-parallel replicas (default 1)

        Raises:
            ValueError: if num_search_chains < 1
        """
        self.sampler = sampler
        if num_search_chains < 1:
            raise ValueError("num_search_chains must be >= 1.")
        # Coerce once; the original code later re-assigned the raw value,
        # silently undoing this int() coercion — that redundant assignment is removed.
        self.num_search_chains = int(num_search_chains)
        self.num_parallel_workers = int(num_parallel_workers)
        if self.num_parallel_workers > 1:
            logger.info(f"Using {self.num_parallel_workers} parallel workers, each running identical samplers.")
        self.model = model

    def _gradient_descent(self, results: GlobalOptimisationResults) -> GlobalOptimisationResults:
        """
        Polishes the best point with gradient-based optimisation and folds the
        extra likelihood evaluations into the results.

        Args:
            results: results of the sampling phase

        Returns:
            results updated with the gradient-polished solution
        """
        U_solution, log_L_solution, _num_likelihood_evals = gradient_based_optimisation(
            self.model, init_U_point=results.U_solution)
        X_solution = self.model.transform(U_solution)
        solution = self.model.prepare_input(U_solution)
        num_likelihood_evals = results.num_likelihood_evaluations + _num_likelihood_evals
        return results._replace(
            U_solution=U_solution,
            log_L_solution=log_L_solution,
            X_solution=X_solution,
            solution=solution,
            num_likelihood_evaluations=num_likelihood_evals
        )

    def _to_results(self, termination_reason: IntArray, state: GlobalOptimisationState) -> GlobalOptimisationResults:
        """
        Converts the final state of the global optimisation to results.

        Args:
            termination_reason: bit-field termination reason
            state: final state of the global optimisation

        Returns:
            results of the global optimisation
        """
        best_idx = jnp.argmax(state.samples.log_L)
        U_solution = state.samples.U_sample[best_idx]
        X_solution = self.model.transform(U_solution)
        solution = self.model.prepare_input(U_solution)
        max_log_L = state.samples.log_L[best_idx]
        min_log_L = jnp.min(state.samples.log_L)
        # NOTE(review): divides by |max+min|, which can be zero if the spread is
        # symmetric about 0 — yields inf/nan relative_spread in that case.
        relative_spread = 2. * jnp.abs(max_log_L - min_log_L) / jnp.abs(max_log_L + min_log_L)
        absolute_spread = jnp.abs(max_log_L - min_log_L)
        return GlobalOptimisationResults(
            U_solution=state.samples.U_sample[best_idx],
            X_solution=X_solution,
            solution=solution,
            log_L_solution=state.samples.log_L[best_idx],
            num_likelihood_evaluations=state.num_likelihood_evaluations,
            num_samples=state.num_samples,
            relative_spread=relative_spread,
            absolute_spread=absolute_spread,
            termination_reason=termination_reason
        )

    def _run(self, key: PRNGKey, term_cond: GlobalOptimisationTerminationCondition) -> Tuple[
        IntArray, GlobalOptimisationState]:
        """
        Runs the global optimisation.

        Args:
            key: PRNGKey
            term_cond: termination condition

        Returns:
            the termination reason and final state of the global optimisation
        """

        def replica(key: PRNGKey) -> Tuple[GlobalOptimisationState, IntArray]:
            state = create_init_state(
                key=key,
                num_search_chains=self.num_search_chains,
                model=self.model
            )
            # Continue sampling with provided sampler until user-defined termination condition is met.
            state, termination_reason = _single_thread_global_optimisation(
                init_state=state,
                termination_cond=term_cond,
                sampler=self.sampler
            )
            if self.num_parallel_workers > 1:
                # Share the best contour across replicas, then refine each replica up to it.
                target_log_L_contour = jnp.max(
                    parallel.all_gather(jnp.max(state.samples.log_L), 'i')
                )
                state, termination_reason = _single_thread_global_optimisation(
                    init_state=state,
                    termination_cond=term_cond._replace(log_likelihood_contour=target_log_L_contour),
                    sampler=self.sampler
                )
            return state, termination_reason

        if self.num_parallel_workers > 1:
            parallel_ns = pmap(replica, axis_name='i')
            keys = random.split(key, self.num_parallel_workers)
            batched_state, termination_reason = parallel_ns(keys)
            state = remove_chunk_dim(batched_state)
            # Keep a single (un-batched) PRNG key in the merged state.
            state = state._replace(key=state.key[0])
        else:
            state, termination_reason = replica(key)

        return termination_reason, state
def summary(results: 'GlobalOptimisationResults', f_obj: Optional[Union[str, TextIO]] = None):
    """
    Gives a summary of the results of a global optimisation.

    Args:
        results (GlobalOptimisationResults): results of the global optimisation
        f_obj: optional file name or text file-like object to also write the summary to
    """
    main_s = []

    def _print(s):
        # Echo to stdout and accumulate for optional file output.
        print(s)
        main_s.append(s)

    def _round(v, uncert_v):
        # Round `v` to the number of significant figures implied by `uncert_v`.
        v = float(v)
        uncert_v = float(uncert_v)
        try:
            sig_figs = -int("{:e}".format(uncert_v).split('e')[1]) + 1
            return round(float(v), sig_figs)
        except (ValueError, IndexError, OverflowError):
            # Non-finite uncertainty: fall back to the unrounded value.
            return float(v)

    def _print_termination_condition(_termination_reason: int):
        # Bit flags, least-significant first (mirrors determine_termination):
        # 0-bit -> 1: used maximum allowed number of likelihood evaluations
        # 1-bit -> 2: reached goal log-likelihood contour
        # 2-bit -> 4: relative spread of log-likelihood values below threshold
        # 3-bit -> 8: absolute spread of log-likelihood values below threshold
        # 4-bit -> 16: efficiency below threshold
        # 5-bit -> 32: on a plateau (possibly local minimum, or due to numerical issues)
        # Extract bits inline (equivalent to jaxns.utils._bit_mask with width=8,
        # avoiding the dependency on a private helper).
        termination_bit_mask = [(_termination_reason >> i) & 1 for i in range(8)]
        for bit, condition in zip(termination_bit_mask, [
            'Reached max num likelihood evaluations',
            'Reached goal log-likelihood contour',
            'Small relative spread of log-likelihood values',
            'Small absolute spread of log-likelihood values',
            'Sampler efficiency too low',
            'On plateau (possibly local minimum, or due to numerical issues)'
        ]):
            if bit == 1:
                _print(condition)

    _print("--------")
    _print("Termination Conditions:")
    if np.size(results.termination_reason) > 1:
        # One termination reason per parallel replica.
        for replica_idx in range(np.size(results.termination_reason)):
            _print(f"Replica {replica_idx}:")
            _print_termination_condition(int(results.termination_reason[replica_idx]))
    else:
        _print_termination_condition(int(results.termination_reason))
    _print("--------")
    _print(f"likelihood evals: {int(results.num_likelihood_evaluations):d}")
    _print(f"samples: {int(results.num_samples):d}")
    _print(
        f"likelihood evals / sample: {float(results.num_likelihood_evaluations / results.num_samples):.1f}"
    )
    _print("--------")
    _print(
        f"max(log_L)={_round(results.log_L_solution, results.log_L_solution)}"
    )
    _print(
        f"relative spread: {_round(results.relative_spread, results.relative_spread)}"
    )
    _print(
        f"absolute spread: {_round(results.absolute_spread, results.absolute_spread)}"
    )
    X_solution = results.X_solution
    for name in X_solution.keys():
        _sample = X_solution[name].reshape((-1,))
        ndims = _sample.size
        _print("--------")
        var_name = name if ndims == 1 else "{}[#]".format(name)
        _print(
            f"{var_name}: max(L) est."
        )
        for dim in range(ndims):
            _max_like_point = _sample[dim]
            # Two sig-figs based on 10% of the value itself; guard against
            # non-finite values where the exponent cannot be parsed.
            try:
                sig_figs = -int("{:e}".format(_max_like_point * 0.1).split('e')[1]) + 1
                rounded_point = round(float(_max_like_point), sig_figs)
            except (ValueError, IndexError, OverflowError):
                rounded_point = float(_max_like_point)
            _print("{}: {}".format(
                name if ndims == 1 else "{}[{}]".format(name, dim),
                rounded_point
            ))
    _print("--------")
    if f_obj is not None:
        out = "\n".join(main_s)
        if isinstance(f_obj, str):
            with open(f_obj, 'w') as f:
                f.write(out)
        elif isinstance(f_obj, io.TextIOBase):
            f_obj.write(out)
        else:
            raise TypeError(f"Invalid f_obj: {type(f_obj)}")