Source code for jaxns.experimental.evidence_maximisation

import dataclasses
import time
from functools import partial
from typing import Tuple, Dict, Any, Optional, NamedTuple

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp, random

from jaxns import NestedSampler, Model
from jaxns.experimental.solvers.gauss_newton_cg import newton_cg_solver
from jaxns.framework.context import MutableParams
from jaxns.internals.cumulative_ops import cumulative_op_static
from jaxns.internals.log_semiring import LogSpace
from jaxns.internals.logging import logger
from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.types import IntArray, PRNGKey
from jaxns.nested_samplers.common.types import TerminationCondition, NestedSamplerResults, \
    StaticStandardNestedSamplerState

__all__ = [
    'EvidenceMaximisation'
]

tfpd = tfp.distributions
tfpk = tfp.math.psd_kernels


class MStepData(NamedTuple):
    """
    Samples and per-sample log-weights consumed by the M-step.

    Built in `EvidenceMaximisation.m_step` from nested sampling results; both fields are padded along the leading
    axis to the same length before use.
    """
    # Taken from `ns_results.U_samples`; padding rows are filled with 0.5.
    U_samples: jax.Array
    # Computed as `log_dp_mean - log_L_samples + log_Z_mean` per sample; padding entries are -inf.
    log_weights: jax.Array


def next_power_2(x: int) -> int:
    """
    Smallest power of 2 that is greater than or equal to x.

    Uses exact integer arithmetic, so the result is correct for arbitrarily large Python ints (a float
    `log2` round-trip loses precision once x no longer fits in a float64 mantissa).

    Args:
        x: a non-negative integer. Non-`int` inputs (e.g. numpy integer scalars or real values) are rounded up
            to an integer first.

    Returns:
        the smallest 2**n with 2**n >= x, or 0 if x is 0.

    Raises:
        ValueError: if x is negative.
    """
    if not isinstance(x, int):
        # Accept numpy scalars / real values by rounding up to a Python int.
        x = int(np.ceil(x))
    if x < 0:
        raise ValueError(f"Expected x >= 0, got {x}.")
    if x == 0:
        # Kept for backward compatibility: zero samples pad to zero.
        return 0
    # (x - 1).bit_length() is the exponent of the smallest power of 2 that is >= x.
    return 1 << (x - 1).bit_length()


@dataclasses.dataclass(eq=False)
class EvidenceMaximisation:
    """
    Evidence Maximisation class, that implements the E and M steps. Iteratively runs nested sampling (E-step) and
    maximises the evidence over the resulting samples with respect to the model parameters (M-step).

    Args:
        model: The model to train.
        ns_kwargs: The keyword arguments to pass to the nested sampler. Needs at least `max_samples`.
        max_num_epochs: The maximum number of epochs to run M-step for.
        gtol: The parameter tolerance for the M-step. End when all parameters change by less than gtol.
        log_Z_ftol, log_Z_atol: The tolerances for the change in the evidence as function of log_Z_uncert.
            Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).
        batch_size: The batch size used by `_m_step_iterator`. `m_step` itself optimises over the full padded
            sample set and does not read this value.
        termination_cond: The termination condition to use for the nested sampler.
        verbose: Whether to log the E-step compilation time and debug-print log_Z during the M-step.
    """
    model: Model
    ns_kwargs: Optional[Dict[str, Any]] = None
    max_num_epochs: int = 50
    gtol: float = 1e-2
    log_Z_ftol: float = 1.
    log_Z_atol: float = 1e-4
    batch_size: Optional[int] = 128
    termination_cond: Optional[TerminationCondition] = None
    verbose: bool = False

    def __post_init__(self):
        """
        Default `ns_kwargs` to an empty dict and build the E-step and M-step callables.
        """
        if self.ns_kwargs is None:
            self.ns_kwargs = {}
        self._e_step = self._create_e_step()
        self._m_step = self._create_m_step()

    def _create_e_step(self):
        """
        Create a compiled function that runs nested sampling and returns trimmed results.

        Returns:
            A function `(key, params, desc) -> NestedSamplerResults` backed by an ahead-of-time compiled
            nested sampling run.
        """

        def _ns_solve(params: MutableParams, key: PRNGKey) -> Tuple[IntArray, StaticStandardNestedSamplerState]:
            model = self.model(params=params)
            ns = NestedSampler(model=model, **self.ns_kwargs)
            termination_reason, state = ns(key, self.termination_cond)
            return termination_reason, state

        # Ahead of time compile the function
        t0 = time.time()
        ns_solve_compiled = jax.jit(_ns_solve).lower(self.model.params, random.PRNGKey(42)).compile()
        if self.verbose:
            logger.info(f"E-step compilation time: {time.time() - t0:.2f}s")
        # Sampler built from the model's current params; only used to convert the final state into results.
        ns = NestedSampler(model=self.model(params=self.model.params), **self.ns_kwargs)

        def _e_step(key: PRNGKey, params: MutableParams, desc: str) -> NestedSamplerResults:
            print(f"Running E-step... {desc}")
            termination_reason, state = ns_solve_compiled(params, key)
            # Trim results
            return ns.to_results(termination_reason=termination_reason, state=state, trim=True)

        return _e_step

    def e_step(self, key: PRNGKey, params: MutableParams, desc: str) -> NestedSamplerResults:
        """
        The E-step is just nested sampling.

        Args:
            key: The random number generator key.
            params: The parameters to use.
            desc: description printed alongside the progress message

        Returns:
            The nested sampling results.
        """
        # The E-step is just nested sampling
        return self._e_step(key, params, desc)

    def _m_step_iterator(self, key: PRNGKey, data: MStepData):
        """
        Yield randomly permuted minibatches of `data`.

        Args:
            key: The random number generator key used to permute the samples.
            data: The full M-step data.

        Yields:
            `MStepData` batches of `batch_size` samples (all samples if `batch_size` is None). Samples that do
            not fill a complete batch are dropped.

        Raises:
            RuntimeError: if the batch size exceeds the number of samples.
        """
        num_samples = int(data.U_samples.shape[0])
        permutation = jax.random.permutation(key, num_samples)
        if self.batch_size is None:
            batch_size = num_samples
        else:
            batch_size = self.batch_size
        num_batches = num_samples // batch_size
        if num_batches == 0:
            raise RuntimeError(
                f"Batch size {batch_size} is too large for number of samples, number of samples {num_samples}.")
        for i in range(num_batches):
            perm = permutation[i * batch_size:(i + 1) * batch_size]
            batch = MStepData(
                U_samples=data.U_samples[perm],
                log_weights=data.log_weights[perm]
            )
            yield batch

    def _create_m_step(self):
        """
        Create the jitted function that performs one M-step solve.

        Returns:
            A function `(key, params, data) -> (params, value)` that minimises the negative log evidence with
            `newton_cg_solver`.
        """

        def log_evidence(params: MutableParams, data: MStepData):
            # Compute the log evidence
            model = self.model(params=params)

            def op(log_Z, data):
                log_dZ = model.forward(data.U_samples) + data.log_weights
                return (LogSpace(log_Z) + LogSpace(log_dZ)).log_abs_val

            # Accumulate the per-sample contributions in log-space, starting from log(0).
            log_Z, _ = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, mp_policy.measure_dtype), xs=data)
            return log_Z

        def loss(params: MutableParams, data: MStepData):
            log_Z = log_evidence(params, data)
            obj = -log_Z
            if self.verbose:
                jax.debug.print("log_Z={log_Z}", log_Z=log_Z)
            return obj

        @jax.jit
        def _m_step(key: PRNGKey, params: MutableParams, data: MStepData) -> Tuple[MutableParams, Any]:
            """
            The M-step is just evidence maximisation.

            Args:
                key: The random number generator key (not read by the solve).
                params: The parameters to use.
                data: The data to use.

            Returns:
                The updated parameters and `-diagnostics.f[i]`, the negated solver objective at index
                `i = max(diagnostics.iteration)`. Since the objective is `-log_Z`, this is presumably the log
                evidence at the last iteration -- confirm against `newton_cg_solver`.
            """
            params, diagnostics = newton_cg_solver(loss, params, args=(data,))
            i = jnp.max(diagnostics.iteration)
            return params, -diagnostics.f[i]

        return _m_step

    def m_step(self, key: PRNGKey, params: MutableParams, ns_results: NestedSamplerResults, desc: str) -> Tuple[
        MutableParams, Any]:
        """
        The M-step is just evidence maximisation. We pad the data to the next power of 2, to make JIT compilation
        happen less frequently.

        Args:
            key: The random number generator key.
            params: The parameters to use.
            ns_results: The nested sampling results to use.
            desc: description printed alongside the progress messages

        Returns:
            The updated parameters and the value returned by the last M-step solve (None if `max_num_epochs`
            is 0).
        """
        # next_power_2 pad
        num_samples = int(ns_results.total_num_samples)
        n = next_power_2(num_samples)
        print(f"Running M-step ({num_samples} samples padded to {n})... {desc}")

        def _pad_to_n(x, fill_value, dtype):
            if x.shape[0] == n:
                return x
            return jnp.concatenate([x, jnp.full((n - x.shape[0],) + x.shape[1:], fill_value, dtype)], axis=0)

        # NOTE(review): a sample with both log_dp_mean and log_L_samples equal to -inf would give nan here --
        # confirm such samples cannot occur in the results.
        log_weights = ns_results.log_dp_mean - ns_results.log_L_samples + ns_results.log_Z_mean
        # Padding entries get log-weight -inf so they contribute nothing to the log-space sum.
        data = MStepData(
            U_samples=_pad_to_n(ns_results.U_samples, 0.5, mp_policy.measure_dtype),
            log_weights=_pad_to_n(log_weights, -jnp.inf, mp_policy.measure_dtype)
        )
        last_params = params
        epoch = 0
        log_Z = None
        while epoch < self.max_num_epochs:
            params, log_Z = self._m_step(key=key, params=params, data=data)
            # Largest absolute change per parameter leaf (0 for empty leaves).
            l_oo = jax.tree.map(lambda x, y: float(jnp.max(jnp.abs(x - y))) if np.size(x) > 0 else 0.,
                                last_params, params)
            last_params = params
            print(f"{desc}: Epoch {epoch}: log_Z={log_Z}, l_oo={l_oo}")
            if all(_l_oo < self.gtol for _l_oo in jax.tree.leaves(l_oo)):
                break
            epoch += 1

        return params, log_Z

    def train(self, num_steps: int = 10, params: Optional[MutableParams] = None) -> \
            Tuple[NestedSamplerResults, MutableParams]:
        """
        Train the model using EM for num_steps.

        Args:
            num_steps: The number of steps to train for, or until convergence.
            params: The initial parameters to use. If None, then the model's params are used.

        Returns:
            The nested sampling results of the last E-step, and the trained parameters.

        Raises:
            RuntimeError: if no E-step was run (e.g. `num_steps` is 0).
        """
        if params is None:
            params = self.model.params
        # -inf makes the first change infinite, so convergence cannot trigger on the first step.
        log_Z = -jnp.inf

        ns_results = None
        for step in range(num_steps):
            key_e_step, key_m_step = random.split(random.PRNGKey(step), 2)

            # Execute the e_step
            if ns_results is None:
                desc = f"Step {step}: Initial run"
            else:
                desc = f"Step {step}: log Z = {ns_results.log_Z_mean:.4f} +- {ns_results.log_Z_uncert:.4f}"
            ns_results = self.e_step(key=key_e_step, params=params, desc=desc)

            # Check termination condition
            log_Z_change = jnp.abs(ns_results.log_Z_mean - log_Z)
            if log_Z_change < self.log_Z_atol:
                print(
                    f"Convergence achieved at step {step}, "
                    f"due to delta log_Z {log_Z_change} < log_Z_atol {self.log_Z_atol}."
                )
                break

            relative_atol = float(self.log_Z_ftol * ns_results.log_Z_uncert)
            if log_Z_change < relative_atol:
                print(
                    f"Convergence achieved at step {step}, "
                    f"due to delta log_Z {log_Z_change} < log_Z_ftol * log_Z_uncert {relative_atol}."
                )
                break

            # Remember this step's evidence for the next convergence check
            log_Z = ns_results.log_Z_mean

            # Execute the m_step
            desc = (
                f"Step {step}: log Z = {ns_results.log_Z_mean:.4f} +- {ns_results.log_Z_uncert:.4f}"
            )
            params, _ = self.m_step(key=key_m_step, params=params, ns_results=ns_results, desc=desc)

        if ns_results is None:
            raise RuntimeError("No results were computed.")
        return ns_results, params