import dataclasses
import logging
from functools import partial
from typing import Tuple, Dict, Any, Optional, NamedTuple
import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp, random
from jax.scipy.special import logsumexp
from jaxopt import NonlinearCG, ArmijoSGD
from tqdm import tqdm
from jaxns.internals.cumulative_ops import cumulative_op_static
from jaxns.internals.log_semiring import LogSpace
try:
import haiku as hk
except ImportError:
print("You must `pip install dm-haiku` first.")
raise
try:
import optax
except ImportError:
print("You must `pip install optax` first.")
raise
from jaxns import DefaultNestedSampler, Model
from jaxns.internals.types import TerminationCondition, NestedSamplerResults, StaticStandardNestedSamplerState, \
IntArray, PRNGKey, float_type
__all__ = [
'EvidenceMaximisation'
]
tfpd = tfp.distributions
tfpk = tfp.math.psd_kernels
logger = logging.getLogger('jaxns')
class MStepData(NamedTuple):
U_samples: jnp.ndarray
log_weights: jnp.ndarray
# log_dp_mean: jnp.ndarray
# log_L_samples: jnp.ndarray
# log_Z_mean: jnp.ndarray
def next_power_2(x: int) -> int:
"""
Next largest power of 2.
Args:
x: int
Returns:
next largest n**2
"""
return int(2 ** np.ceil(np.log2(x)))
@dataclasses.dataclass(eq=False)
class EvidenceMaximisation:
"""
    Evidence Maximisation class that implements the E and M steps: iteratively computes the evidence and maximises it with respect to the model parameters.
Args:
model: The model to train.
ns_kwargs: The keyword arguments to pass to the nested sampler. Needs at least `max_samples`.
max_num_epochs: The maximum number of epochs to run M-step for.
gtol: The parameter tolerance for the M-step. End when all parameters change by less than gtol.
log_Z_ftol, log_Z_atol: The tolerances for the change in the evidence as function of log_Z_uncert.
Terminate if the change in log_Z is less than max(log_Z_ftol * log_Z_uncert, log_Z_atol).
batch_size: The batch size to use for the M-step.
momentum: The momentum to use for the M-step.
termination_cond: The termination condition to use for the nested sampler.
verbose: Whether to print progress verbosely.
"""
    model: Model
    ns_kwargs: Dict[str, Any]
    max_num_epochs: int = 50
    gtol: float = 1e-2  # default value assumed
    log_Z_ftol: float = 1.  # default value assumed
    log_Z_atol: float = 1e-4
    batch_size: int = 128  # default value assumed
    momentum: float = 0.9  # default value assumed
    termination_cond: Optional[TerminationCondition] = None
    verbose: bool = False
def __post_init__(self):
self._e_step = self._create_e_step()
self._m_step = self._create_m_step_stochastic()
def _create_e_step(self):
"""
Create a compiled function that runs nested sampling and returns trimmed results.
Returns:
A compiled function that runs nested sampling and returns trimmed results.
"""
        def _ns_solve(params: hk.MutableParams, key: PRNGKey) -> Tuple[
                IntArray, StaticStandardNestedSamplerState]:
model = self.model(params=params)
ns = DefaultNestedSampler(model=model, **self.ns_kwargs)
termination_reason, state = ns(key, self.termination_cond)
return termination_reason, state
# Ahead of time compile the function
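        # Compiling once here, at construction time, avoids re-tracing on every E-step call;
        # later calls may pass differently-valued params, as long as their structure and shapes match.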
ns_compiled = jax.jit(_ns_solve).lower(self.model.params, random.PRNGKey(42)).compile()
def _e_step(key: PRNGKey, params: hk.MutableParams, p_bar: tqdm) -> NestedSamplerResults:
p_bar.set_description(f"Running E-step... {p_bar.desc}")
termination_reason, state = ns_compiled(params, key)
ns = DefaultNestedSampler(model=self.model(params=params), **self.ns_kwargs)
# Trim now
return ns.to_results(termination_reason=termination_reason, state=state, trim=True)
return _e_step
def e_step(self, key: PRNGKey, params: hk.MutableParams, p_bar: tqdm) -> NestedSamplerResults:
"""
The E-step is just nested sampling.
Args:
key: The random number generator key.
params: The parameters to use.
p_bar: progress bar
Returns:
The nested sampling results.
"""
# The E-step is just nested sampling
return self._e_step(key, params, p_bar)
def _m_step_iterator(self, key: PRNGKey, data: MStepData):
num_samples = int(data.U_samples.shape[0])
permutation = jax.random.permutation(key, num_samples)
num_batches = num_samples // self.batch_size
if num_batches == 0:
raise RuntimeError("Batch size is too large for number of samples.")
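        # Yield only full batches; any remainder (num_samples % batch_size) is dropped for this pass.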
for i in range(num_batches):
perm = permutation[i * self.batch_size:(i + 1) * self.batch_size]
batch = MStepData(
U_samples=data.U_samples[perm],
log_weights=data.log_weights[perm]
)
yield batch
def _create_m_step_stochastic(self):
def log_evidence(params: hk.MutableParams, data: MStepData):
# Compute the log evidence
model = self.model(params=params)
# To make manageable, we could do chunked_pmap
log_dZ = jax.vmap(
lambda U, log_weight: model.forward(U) + log_weight
)(data.U_samples, data.log_weights)
            # log_weights already include log_Z_mean (added in m_step), since log_dp_mean is normalised
log_Z = logsumexp(log_dZ)
return log_Z
def loss(params: hk.MutableParams, data: MStepData):
log_Z, grad = jax.value_and_grad(log_evidence, argnums=0)(params, data)
obj = -log_Z
            grad = jax.tree_util.tree_map(jnp.negative, grad)
aux = (log_Z,)
if self.verbose:
jax.debug.print("log_Z={log_Z}", log_Z=log_Z)
return (obj, aux), grad
solver = ArmijoSGD(
fun=loss,
has_aux=True,
value_and_grad=True,
jit=True,
unroll=False,
verbose=self.verbose,
momentum=self.momentum
)
def _m_step_stochastic(key: PRNGKey, params: hk.MutableParams, data: MStepData) -> Tuple[hk.MutableParams, Any]:
"""
The M-step is just evidence maximisation.
Args:
key: The random number generator key.
params: The parameters to use.
data: The data to use.
Returns:
                The updated parameters and the solver's auxiliary output.
"""
# The M-step is just evidence maximisation
iterator = self._m_step_iterator(key, data)
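            # run_iterator consumes the mini-batch iterator, so one call is a single pass over the
            # shuffled samples (or fewer steps, if the solver hits its own maxiter first).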
opt_results = solver.run_iterator(init_params=params, iterator=iterator)
return opt_results.params, opt_results.state.aux
return _m_step_stochastic
def _create_m_step(self):
def log_evidence(params: hk.MutableParams, data: MStepData):
# Compute the log evidence
model = self.model(params=params)
def op(log_Z, data):
log_dZ = model.forward(data.U_samples) + data.log_weights
return (LogSpace(log_Z) + LogSpace(log_dZ)).log_abs_val
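            # Fold `op` over the samples, accumulating log Z = log(sum_i exp(log L_i(params) + log w_i))
            # with LogSpace arithmetic for numerical stability.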
log_Z, _ = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data)
return log_Z
def loss(params: hk.MutableParams, data: MStepData):
log_Z, grad = jax.value_and_grad(log_evidence, argnums=0)(params, data)
obj = -log_Z
            grad = jax.tree_util.tree_map(jnp.negative, grad)
aux = (log_Z,)
if self.verbose:
jax.debug.print("log_Z={log_Z}", log_Z=log_Z)
return (obj, aux), grad
solver = NonlinearCG(
fun=loss,
has_aux=True,
value_and_grad=True,
jit=True,
unroll=False,
verbose=self.verbose
)
        @jax.jit  # `key` is a traced PRNGKey array, so it cannot be a static argument
def _m_step(key: PRNGKey, params: hk.MutableParams, data: MStepData) -> Tuple[hk.MutableParams, Any]:
"""
The M-step is just evidence maximisation.
            Args:
                key: The random number generator key (unused here; kept for signature parity with the stochastic M-step).
                params: The parameters to use.
                data: The data to use.
Returns:
                The updated parameters and the auxiliary output (log_Z,).
"""
opt_results = solver.run(init_params=params, data=data)
return opt_results.params, opt_results.state.aux
return _m_step
def m_step(self, key: PRNGKey, params: hk.MutableParams, ns_results: NestedSamplerResults, p_bar: tqdm) -> Tuple[
hk.MutableParams, Any]:
"""
The M-step is just evidence maximisation. We pad the data to the next power of 2, to make JIT compilation
happen less frequently.
Args:
key: The random number generator key.
params: The parameters to use.
ns_results: The nested sampling results to use.
p_bar: progress bar
Returns:
            The updated parameters and the log evidence from the final M-step epoch.
"""
# next_power_2 pad
num_samples = int(ns_results.total_num_samples)
n = next_power_2(num_samples)
p_bar.set_description(f"Running M-step ({num_samples} samples padded to {n})... {p_bar.desc}")
def _pad_to_n(x, fill_value, dtype):
return jnp.concatenate([x, jnp.full((n - x.shape[0],) + x.shape[1:], fill_value, dtype)], axis=0)
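        # Importance weights from the E-step: log(dp_i) - log(L_i) + log(Z); adding the log-likelihood
        # under updated params and log-summing then gives a reweighted estimate of the evidence.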
log_weights = ns_results.log_dp_mean - ns_results.log_L_samples + ns_results.log_Z_mean
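        # Pad U with 0.5 (an interior point of the unit cube, safe to evaluate) and the weights with -inf,
        # so padded entries contribute exp(-inf) = 0 to the evidence sum.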
data = MStepData(
U_samples=_pad_to_n(ns_results.U_samples, 0.5, float_type),
log_weights=_pad_to_n(log_weights, -jnp.inf, float_type)
)
desc = p_bar.desc
last_params = params
epoch = 0
log_Z = None
while epoch < self.max_num_epochs:
params, (log_Z,) = self._m_step(key=key, params=params, data=data)
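            # l_oo holds, per parameter leaf, the maximum absolute change since the previous epoch.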
            l_oo = jax.tree_util.tree_map(lambda x, y: jnp.max(jnp.abs(x - y)), last_params, params)
last_params = params
p_bar.set_description(f"{desc}: Epoch {epoch}: log_Z={log_Z}, l_oo={l_oo}")
            if all(_l_oo < self.gtol for _l_oo in jax.tree_util.tree_leaves(l_oo)):
break
epoch += 1
return params, log_Z
    def train(self, num_steps: int = 10, params: Optional[hk.MutableParams] = None) -> Tuple[
            NestedSamplerResults, hk.MutableParams]:
"""
Train the model using EM for num_steps.
Args:
num_steps: The number of steps to train for, or until convergence.
params: The initial parameters to use. If None, then the model's params are used.
Returns:
            The final nested sampling results and the trained parameters.
"""
if params is None:
params = self.model.params
log_Z = -jnp.inf
# Initialize the progress bar with a description
p_bar = tqdm(range(num_steps), desc="Processing Steps", dynamic_ncols=True)
ns_results = None
for step in p_bar:
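            # Fresh, deterministic keys for this EM step, derived from the step index.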
            key_e_step, key_m_step = random.split(random.PRNGKey(step), 2)
# Execute the e_step
if ns_results is None:
p_bar.set_description(f"Step {step}: Initial run")
else:
p_bar.set_description(
f"Step {step}: log Z = {ns_results.log_Z_mean:.4f} +- {ns_results.log_Z_uncert:.4f}"
)
            ns_results = self.e_step(key=key_e_step, params=params, p_bar=p_bar)
            # Check termination condition
log_Z_change = jnp.abs(ns_results.log_Z_mean - log_Z)
if log_Z_change < max(self.log_Z_ftol * ns_results.log_Z_uncert, self.log_Z_atol):
p_bar.set_description(f"Convergence achieved at step {step}.")
break
# Update log_Z and log_Z_uncert values
log_Z = ns_results.log_Z_mean
# Execute the m_step
p_bar.set_description(
f"Step {step}: log Z = {ns_results.log_Z_mean:.4f} +- {ns_results.log_Z_uncert:.4f}"
)
params, log_Z_opt = self.m_step(key=key_m_step, params=params, ns_results=ns_results, p_bar=p_bar)
if ns_results is None:
raise RuntimeError("No results were computed.")
return ns_results, params