# Source code for jaxns.nested_samplers.common.initialisation

from functools import partial
from typing import Optional, Tuple

import jax
from jax import numpy as jnp, random
from jax._src.mesh import Mesh
from jax._src.partition_spec import PartitionSpec
from jax.experimental.shard_map import shard_map

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.maps import tree_device_put
from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.shrinkage_statistics import create_init_evidence_calc
from jaxns.internals.types import PRNGKey
from jaxns.nested_samplers.common.types import StaticStandardSampleCollection, \
    NestedSamplerState, TerminationRegister, LivePointCollection
from jaxns.nested_samplers.common.uniform_sample import draw_uniform_samples


def create_init_state(key: PRNGKey, num_live_points: int, max_samples: int, model: BaseAbstractModel,
                      mesh: Optional[Mesh] = None) -> Tuple[LivePointCollection, NestedSamplerState]:
    """
    Build the initial live points together with an empty sampler state.

    The sampler state holds a fixed-capacity sample collection that the sampler fills in later.
    The live points are drawn with `draw_uniform_samples`, one PRNG key per point.

    Args:
        key: PRNGKey
        num_live_points: the number of live points to draw.
        max_samples: the capacity (leading dimension) of the static sample collection.
        model: the model to use.
        mesh: optional device mesh. When given, the per-point keys are placed along the mesh axis
            named 'shard' and the drawing runs inside `shard_map` (using `method='scan'` per shard).
            When None, the drawing is done with `method='vmap'`.

    Returns:
        live_point_collection: the initial live point collection, ordered by ascending log_L
        state: the initial state (sample indices at zero, empty sample collection)
    """

    def _tile_to_capacity(template):
        # Stack `max_samples` copies of `template` along a new leading axis.
        return jnp.repeat(template[None], repeats=max_samples, axis=0, total_repeat_length=max_samples)

    index_dtype = mp_policy.index_dtype

    # Every slot starts out "unfilled": log_L = +inf, zero likelihood evaluations, not phantom.
    # NOTE(review): +inf presumably makes unfilled slots sort after real samples -- confirm with
    # the sampler code that consumes this collection.
    sample_collection = StaticStandardSampleCollection(
        sender_node_idx=jnp.zeros(max_samples, dtype=index_dtype),
        log_L=jnp.full((max_samples,), jnp.inf, dtype=mp_policy.measure_dtype),
        U_samples=_tile_to_capacity(model.U_placeholder),
        num_likelihood_evaluations=jnp.full((max_samples,), 0, dtype=mp_policy.count_dtype),
        phantom=jnp.full((max_samples,), False, dtype=jnp.bool_),
    )

    # One sub-key is carried forward in the returned state; the other seeds the live point draws.
    key, sample_key = random.split(key, 2)
    per_point_keys = random.split(sample_key, num_live_points)

    if mesh is None:
        drawn = draw_uniform_samples(keys=per_point_keys, model=model, method='vmap')
    else:
        shard_spec = PartitionSpec('shard')

        def _draw_on_shard(local_keys):
            # Each shard draws its own slice of the live points.
            return draw_uniform_samples(keys=local_keys, model=model, method='scan')

        sharded_draw = shard_map(_draw_on_shard, mesh=mesh, in_specs=shard_spec, out_specs=shard_spec,
                                 check_rep=False)
        drawn = sharded_draw(tree_device_put(per_point_keys, mesh, ('shard',)))

    live_points = LivePointCollection(
        sender_node_idx=jnp.full((num_live_points,), 0, dtype=index_dtype),
        U_sample=drawn.U_sample,
        log_L=drawn.log_L,
        log_L_constraint=drawn.log_L_constraint,
        num_likelihood_evaluations=drawn.num_likelihood_evaluations,
    )
    # Reorder every leaf of the collection so that log_L is ascending.
    order = jnp.argsort(drawn.log_L)
    live_points = jax.tree.map(lambda leaf: leaf[order], live_points)

    state = NestedSamplerState(
        key=key,
        next_sample_idx=jnp.asarray(0, index_dtype),
        num_samples=jnp.asarray(0, index_dtype),
        sample_collection=sample_collection,
    )
    return live_points, state
def create_init_termination_register() -> TerminationRegister:
    """
    Build the termination register as it stands before any samples are processed.

    Both counters start at zero. Both evidence calculations are the value returned by
    `create_init_evidence_calc()`. The likelihood contour and peak log(XL) start at -inf, the
    efficiency at 0, the relative/absolute spreads at +inf, and the plateau / no-seed-points flags
    at False.

    Returns:
        The initial termination register.
    """

    def _measure(value):
        # Scalar in the measure dtype of the mixed-precision policy.
        return jnp.asarray(value, mp_policy.measure_dtype)

    def _flag(value):
        # Scalar boolean flag.
        return jnp.asarray(value, jnp.bool_)

    count_zero = jnp.asarray(0, mp_policy.count_dtype)
    evidence_calc = create_init_evidence_calc()

    return TerminationRegister(
        num_samples_used=count_zero,
        num_likelihood_evaluations=count_zero,
        evidence_calc=evidence_calc,
        evidence_calc_with_remaining=evidence_calc,
        log_L_contour=_measure(-jnp.inf),
        peak_log_XL=_measure(-jnp.inf),
        efficiency=_measure(0.),
        relative_spread=_measure(jnp.inf),
        absolute_spread=_measure(jnp.inf),
        plateau=_flag(False),
        no_seed_points=_flag(False),
    )