# Source code for jaxns.experimental.global_optimisation

import dataclasses
import io
from typing import NamedTuple, Optional, Union, TextIO, Tuple

import jax.numpy as jnp
import numpy as np
import pylab as plt
from jaxlib import xla_client

from jaxns.experimental.solvers.gauss_newton_cg import newton_cg_solver
from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.constraint_bijections import quick_unit, quick_unit_inverse
from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.types import PRNGKey, IntArray, UType, FloatArray, LikelihoodInputType, XType
from jaxns.nested_samplers import ShardedStaticNestedSampler
from jaxns.nested_samplers.common.types import TerminationCondition, SampleCollection
from jaxns.samplers.abc import AbstractSampler
from jaxns.utils import _bit_mask

# Names exported by a star-import of this module.
# The module-level functions visible in this file (`gradient_based_optimisation`,
# `plot_progress`, `go_summary`) are not listed, so a star-import skips them.
__all__ = [
    'GlobalOptimisationResults',
    'GlobalOptimisationTerminationCondition',
    'GlobalOptimisationState',
    'SimpleGlobalOptimisation'
]


class GlobalOptimisationState(NamedTuple):
    """
    Final state of a global optimisation run.

    Built by `SimpleGlobalOptimisation._run` from the nested sampler state and
    its termination register, and consumed by `SimpleGlobalOptimisation._to_results`.
    """
    # PRNG key copied from the nested sampler state.
    key: PRNGKey
    # Sample buffer; `_to_results` reads its `log_L` and `U_samples` entries.
    samples: SampleCollection
    # Sample counter; buffer slots with index below it are treated as valid.
    num_samples: IntArray
    # Relative spread reported by the termination register.
    relative_spread: FloatArray
    # Absolute spread reported by the termination register.
    absolute_spread: FloatArray
    # Likelihood evaluation count reported by the termination register.
    num_likelihood_evaluations: IntArray
class GlobalOptimisationResults(NamedTuple):
    """
    Outcome of a global optimisation run.

    Produced by `SimpleGlobalOptimisation._to_results` and read by
    `plot_progress` and `go_summary`.
    """
    # Best point found, taken from the `U_samples` buffer.
    U_solution: UType
    # `model.transform(U_solution)`; `go_summary` iterates it by variable name.
    X_solution: XType
    # `model.prepare_input(U_solution)`.
    solution: LikelihoodInputType
    # Log-likelihood value at the best point.
    log_L_solution: FloatArray
    # Buffer log-likelihoods sorted low to high; unused slots are NaN at the end.
    log_L_progress: FloatArray
    # Number of likelihood evaluations spent.
    num_likelihood_evaluations: IntArray
    # Sample counter at termination.
    num_samples: IntArray
    # Bit mask of termination conditions, decoded by `go_summary`.
    termination_reason: IntArray
    # Relative spread at termination.
    relative_spread: FloatArray
    # Absolute spread at termination.
    absolute_spread: FloatArray
class GlobalOptimisationTerminationCondition(NamedTuple):
    """
    Stopping criteria for a global optimisation run.

    Every field defaults to ``None``. `SimpleGlobalOptimisation._run` forwards
    the fields to the nested sampler's `TerminationCondition` as follows:

    - ``max_likelihood_evaluations`` -> ``max_num_likelihood_evaluations``
    - ``log_likelihood_contour`` -> ``log_L_contour``
    - ``min_efficiency`` -> ``efficiency_threshold``
    - ``rtol`` / ``atol`` -> ``rtol`` / ``atol``
    """
    # Budget of likelihood evaluations.
    max_likelihood_evaluations: Optional[IntArray] = None  # jnp.asarray(jnp.iinfo(int_type).max, int_type)
    # Log-likelihood contour at which to stop.
    log_likelihood_contour: Optional[FloatArray] = None  # jnp.asarray(jnp.finfo(float_type).max, float_type)
    # Threshold on the relative spread of live points.
    rtol: Optional[FloatArray] = None  # jnp.asarray(0., float_type)
    # Threshold on the absolute spread of live points.
    atol: Optional[FloatArray] = None  # jnp.asarray(0., float_type)
    # Sampler efficiency threshold.
    min_efficiency: Optional[FloatArray] = None  # jnp.asarray(0., float_type)
def gradient_based_optimisation(model: BaseAbstractModel, init_U_point: UType) -> Tuple[UType, FloatArray, IntArray]:
    """
    Refine a point in U-space with the Newton-CG solver.

    The solver minimises the negated `model.log_prob_likelihood` over the
    unconstrained image of U-space given by `quick_unit_inverse`.

    Args:
        model: model whose `log_prob_likelihood` is maximised
        init_U_point: starting point in U-space

    Returns:
        the refined U-space point, the objective value reported for the last
        solver iteration (sign restored), and an estimate of the number of
        likelihood evaluations used
    """

    def negative_log_likelihood(U_unconstrained: UType):
        # Map back to U-space before evaluating the model.
        return -model.log_prob_likelihood(quick_unit(U_unconstrained), allow_nan=False)

    start_point = quick_unit_inverse(init_U_point)
    solution, diagnostics = newton_cg_solver(negative_log_likelihood, start_point)

    last_iteration = jnp.max(diagnostics.iteration)
    log_L_final = -diagnostics.f[last_iteration]
    # Evaluation count is an estimate, assuming ~4 function evaluations per CG
    # iteration: 1 for the initial function evaluation, 3 for the gradient and hvp.
    estimated_evals = 4 * jnp.sum(diagnostics.cg_iters)
    return quick_unit(solution), log_L_final, estimated_evals


@dataclasses.dataclass(eq=False)
class SimpleGlobalOptimisation:
    """
    Simple global optimisation leveraging building blocks of nested sampling.

    The search itself is delegated to a `ShardedStaticNestedSampler` that is
    constructed in `__post_init__` with `num_search_chains` live points.
    """
    # Sampler handed to the nested sampler.
    sampler: AbstractSampler
    # Number of live points; must be >= 1.
    num_search_chains: int
    # Model being optimised.
    model: BaseAbstractModel
    # Forwarded to the nested sampler as `shell_fraction`.
    shell_frac: float = 0.5
    # Forwarded to the nested sampler as `devices`.
    devices: Optional[xla_client.Device] = None
    # Forwarded to the nested sampler as `verbose`.
    verbose: bool = False

    def __post_init__(self):
        """
        Validate the chain count and build the underlying nested sampler.

        Raises:
            ValueError: if `num_search_chains` is less than 1
        """
        if self.num_search_chains < 1:
            raise ValueError("num_search_chains must be >= 1.")
        self.num_search_chains = int(self.num_search_chains)

        # The sample buffer holds ten entries per search chain.
        buffer_size = self.num_search_chains * 10
        self._nested_sampler = ShardedStaticNestedSampler(
            model=self.model,
            max_samples=buffer_size,
            init_efficiency_threshold=0.1,
            sampler=self.sampler,
            num_live_points=self.num_search_chains,
            shell_fraction=self.shell_frac,
            devices=self.devices,
            verbose=self.verbose
        )

    def _gradient_descent(self, results: GlobalOptimisationResults) -> GlobalOptimisationResults:
        """
        Polish the current solution with `gradient_based_optimisation`.

        Args:
            results: results whose `U_solution` seeds the refinement

        Returns:
            a copy of `results` with the solution fields replaced and the
            likelihood evaluation count increased by the refinement's estimate
        """
        refined_U, refined_log_L, extra_evals = gradient_based_optimisation(
            self.model,
            init_U_point=results.U_solution
        )
        refined_X = self.model.transform(refined_U)
        refined_input = self.model.prepare_input(refined_U)
        total_evals = results.num_likelihood_evaluations + extra_evals
        return results._replace(
            U_solution=refined_U,
            log_L_solution=refined_log_L,
            X_solution=refined_X,
            solution=refined_input,
            num_likelihood_evaluations=total_evals
        )

    def _to_results(self, termination_reason: IntArray, state: GlobalOptimisationState) -> GlobalOptimisationResults:
        """
        Converts the final state of the global optimisation to results.

        Args:
            termination_reason: termination bit mask, passed through unchanged
            state: final state of the global optimisation

        Returns:
            results of the global optimisation
        """
        buffer_size = np.shape(state.samples.log_L)[0]
        slot_indices = jnp.arange(buffer_size, dtype=mp_policy.index_dtype)
        slot_is_valid = slot_indices < state.num_samples

        # Unused buffer slots become NaN so they are ignored by nanargmax.
        nan_fill = jnp.asarray(jnp.nan, mp_policy.measure_dtype)
        valid_log_L = jnp.where(slot_is_valid, state.samples.log_L, nan_fill)

        best_slot = jnp.nanargmax(valid_log_L)
        best_U = state.samples.U_samples[best_slot]

        return GlobalOptimisationResults(
            U_solution=best_U,
            X_solution=self.model.transform(best_U),
            solution=self.model.prepare_input(best_U),
            log_L_solution=state.samples.log_L[best_slot],
            # Sorted low to high, NaNs at the end.
            log_L_progress=jnp.sort(valid_log_L),
            num_likelihood_evaluations=state.num_likelihood_evaluations,
            num_samples=state.num_samples,
            relative_spread=state.relative_spread,
            absolute_spread=state.absolute_spread,
            termination_reason=termination_reason
        )

    def _run(self, key: PRNGKey, term_cond: GlobalOptimisationTerminationCondition) -> Tuple[
        IntArray, GlobalOptimisationState]:
        """
        Runs the global optimisation.

        Args:
            key: PRNGKey
            term_cond: termination condition

        Returns:
            the termination reason and final state of the global optimisation
        """
        sampler_term_cond = TerminationCondition(
            max_num_likelihood_evaluations=term_cond.max_likelihood_evaluations,
            log_L_contour=term_cond.log_likelihood_contour,
            efficiency_threshold=term_cond.min_efficiency,
            atol=term_cond.atol,
            rtol=term_cond.rtol,
            max_samples=None  # Turn off max samples for global optimisation, wraps index
        )
        termination_reason, register, sampler_state = self._nested_sampler._run(
            key=key,
            term_cond=sampler_term_cond
        )
        final_state = GlobalOptimisationState(
            key=sampler_state.key,
            samples=sampler_state.sample_collection,
            num_samples=sampler_state.num_samples,
            absolute_spread=register.absolute_spread,
            relative_spread=register.relative_spread,
            num_likelihood_evaluations=register.num_likelihood_evaluations
        )
        return termination_reason, final_state
def plot_progress(results: GlobalOptimisationResults, save_file: Optional[str] = None):
    """
    Plot the sorted log-likelihood values against sample index.

    Args:
        results: results of a global optimisation
        save_file: if given, the figure is also saved to this path
    """
    num_samples = int(results.num_samples)
    log_L_progress = np.asarray(results.log_L_progress)
    buffer_size = int(np.shape(log_L_progress)[0])

    start_sample_idx = max(0, num_samples - buffer_size)
    end_sample_idx = num_samples
    # `log_L_progress` is buffer-sized with NaN padding at the end, so when fewer
    # samples than buffer slots exist only the leading entries are real. Slice to
    # that prefix so x and y always have the same length.
    num_valid = end_sample_idx - start_sample_idx

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.plot(
        np.arange(start_sample_idx, end_sample_idx),
        log_L_progress[:num_valid]
    )
    ax.set_title(f"Optimisation progress (samples {start_sample_idx} to {end_sample_idx})")
    ax.set_xlabel("Sample index")
    ax.set_ylabel(r"Objective ($\log L$)")
    if save_file is not None:
        plt.savefig(save_file)
    plt.show()


def go_summary(results: GlobalOptimisationResults, f_obj: Optional[Union[str, TextIO]] = None):
    """
    Gives a summary of the results of a global optimisation.

    The summary is always printed to stdout; it is additionally written to
    `f_obj` when one is given.

    Args:
        results (GlobalOptimisationResults): results of a global optimisation
        f_obj: optional file path (opened for writing) or text stream to
            receive the summary

    Raises:
        TypeError: if `f_obj` is neither a string nor a text stream
    """
    main_s = []

    def _print(s):
        print(s)
        main_s.append(s)

    def _round(v, uncert_v):
        # Round `v` to the decimal place implied by the exponent of `uncert_v`.
        v = float(v)
        uncert_v = float(uncert_v)
        try:
            sig_figs = -int("{:e}".format(uncert_v).split('e')[1]) + 1
        except (IndexError, ValueError):
            # Non-finite uncertainty ('nan' / 'inf') has no exponent to parse.
            return v
        return round(v, sig_figs)

    def _print_termination_condition(_termination_reason: int):
        termination_bit_mask = _bit_mask(int(_termination_reason), width=11)
        for bit, condition in zip(termination_bit_mask, [
            'Reached max samples',
            'Evidence uncertainty low enough',
            'Small remaining evidence',
            'Reached ESS',
            "Used max num likelihood evaluations",
            'Likelihood contour reached',
            'Sampler efficiency too low',
            'All live-points are on a single plateau (sign of possible precision error)',
            'relative spread of live points < rtol',
            'absolute spread of live points < atol',
            'no seed points left (consider decreasing shell_fraction)'
        ]):
            if bit == 1:
                _print(condition)

    _print("--------")
    _print("Termination Conditions:")
    if np.size(results.termination_reason) > 1:
        # One termination bit mask per replica.
        for replica_idx, reason in enumerate(np.ravel(results.termination_reason)):
            _print(f"Replica {replica_idx}:")
            _print_termination_condition(int(reason))
    else:
        _print_termination_condition(int(results.termination_reason))

    _print("--------")
    _print(f"likelihood evals: {int(results.num_likelihood_evaluations):d}")
    _print(f"samples: {int(results.num_samples):d}")
    _print(
        f"likelihood evals / sample: {float(results.num_likelihood_evaluations / results.num_samples):.1f}"
    )
    _print("--------")
    _print(
        f"max(log_L)={_round(results.log_L_solution, results.log_L_solution)}"
    )
    _print(
        f"relative spread: {_round(results.relative_spread, results.relative_spread)}"
    )
    _print(
        f"absolute spread: {_round(results.absolute_spread, results.absolute_spread)}"
    )

    for name, value in results.X_solution.items():
        # For shaped arrays, we want to print "{name}[{i0, i1, ..., in}]" for all valid indices.
        if np.size(value) == 0:
            continue
        _print("--------")
        shape = np.shape(value)
        num_dims = len(shape)
        is_shaped = num_dims > 0
        var_name = f"{name}[{','.join(['#'] * num_dims)}]" if is_shaped else name
        _print(
            f"{var_name}: max(L) est."
        )
        if is_shaped:
            indices = np.indices(shape).reshape((num_dims, -1)).T
            for inds in indices:
                _max_like_point = value[tuple(inds)]
                _print(f"{name}[{','.join(str(i) for i in inds)}]: {_round(_max_like_point, 0.1 * _max_like_point)}")
        else:
            _max_like_point = value
            _print(f"{name}: {_round(_max_like_point, 0.1 * _max_like_point)}")
    _print("--------")

    if f_obj is not None:
        out = "\n".join(main_s)
        if isinstance(f_obj, str):
            with open(f_obj, 'w') as f:
                f.write(out)
        elif isinstance(f_obj, io.TextIOBase):
            f_obj.write(out)
        else:
            raise TypeError(f"Invalid f_obj: {type(f_obj)}")