Source code for jaxns.internals.random

from typing import Optional

import jax
from jax import random, numpy as jnp
from jax.scipy import special

from jaxns.internals.log_semiring import cumulative_logsumexp
from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.types import FloatArray, IntArray, PRNGKey

# Names exported when this module is star-imported.
__all__ = ['random_ortho_matrix',
           'resample_indicies']


def random_ortho_matrix(key, n, special_orthogonal: bool = False):
    """
    Samples a random orthonormal n by n matrix from Stiefels manifold.

    The matrix is obtained as the Q factor of the QR decomposition of an n by n matrix of
    standard normal draws.

    From https://stackoverflow.com/a/38430739

    Args:
        key: PRNG seed
        n: Size of matrix, draws from the O(n) group.
        special_orthogonal: if True, the result is constrained to the SO(n) group, i.e. its
            determinant is +1. If False, the raw Q factor is returned.

    Returns:
        random [n, n] orthonormal matrix, with determinant +1 if `special_orthogonal` is True
        and +-1 otherwise.
    """
    H = random.normal(key, shape=(n, n), dtype=mp_policy.measure_dtype)
    Q, R = jnp.linalg.qr(H)
    if special_orthogonal:
        # Remove the sign ambiguity of the QR factorisation by flipping the columns of Q whose
        # matching diagonal entry of R is negative. The signs must be read from the untouched
        # R; taking them after R has been made non-negative would turn this into a no-op.
        # A zero diagonal entry maps to +1 so that no column is ever zeroed out.
        col_signs = jnp.where(jnp.diag(R) < 0, -1., 1.)
        Q = Q * col_signs[None, :]
        # Negating a single column negates the determinant, so flipping the first column when
        # det(Q) = -1 lands the result in SO(n) while keeping it orthonormal.
        det_sign, _ = jnp.linalg.slogdet(Q)
        Q = Q.at[:, 0].multiply(jnp.where(det_sign < 0, -1., 1.))
    return Q
def resample_indicies(key: PRNGKey, log_weights: Optional[FloatArray] = None, S: Optional[int] = None,
                      replace: bool = True, num_total: Optional[int] = None) -> IntArray:
    """
    Get resample indicies according to a given weighting, with or without replacement.

    Args:
        key: PRNGKey
        log_weights: Optional log weights. If not given, all `num_total` indices are equally likely.
        S: Optional number of samples. Computes effective sample size from log weights if not given.
        replace: whether to use replacement or not.
        num_total: Optional total sample size to use, must be given if `log_weights=None`
            (with or without replacement). Ignored when `log_weights` is given.

    Returns:
        index array given the take indicies to resample at.

    Raises:
        ValueError: if both `S` and `log_weights` are None, or if both `log_weights` and
            `num_total` are None.
    """
    if S is None:
        if log_weights is None:
            raise ValueError("Need log_weights if S is not given.")
        # ESS = (sum w)^2 / sum w^2
        S = int(jnp.exp(2. * special.logsumexp(log_weights) - special.logsumexp(2. * log_weights)))

    if log_weights is None and num_total is None:
        raise ValueError("Need num_total if log_weights is None.")

    if replace:
        if log_weights is not None:
            # Inverse-CDF sampling in log space.
            # use cumulative_logsumexp because some log_weights could be really small
            log_p_cuml = cumulative_logsumexp(log_weights)
            # 1 - U lies in (0, 1], so log_r never exceeds the total log mass and the
            # search below always returns a valid index.
            log_r = log_p_cuml[-1] + jnp.log(1. - random.uniform(key, (S,)))
            idx = jnp.searchsorted(log_p_cuml, log_r)
        else:
            # Equal weights: draw integers uniformly from [0, num_total). A cumulative table
            # built from arange(num_total) starts at log(0) and therefore never selects index 0.
            idx = random.randint(key, shape=(S,), minval=0, maxval=num_total)
    else:
        # Gumbel top-k trick: perturb the log weights with Gumbel noise and keep the S largest,
        # which is done here by sorting the negated values in ascending order.
        if log_weights is not None:
            g = -random.gumbel(key, shape=log_weights.shape) - log_weights
        else:
            g = -random.gumbel(key, shape=(num_total,))
        idx = jnp.argsort(g)[:S]
    return idx
def sample_uniformly_masked(key, v, select_mask, num_samples: int, squeeze: bool = False):
    """
    Draw entries of a pytree with replacement, uniformly over the positions selected by a mask.

    Args:
        key: PRNG key
        v: pytree whose leaves are indexed along their leading axis with the drawn indices
        select_mask: boolean mask; positions where it is True get log weight 0 and all others
            get log weight -inf
        num_samples: number of indices to draw
        squeeze: if True, singleton dimensions are squeezed out of the index array before it
            is used to index the leaves

    Returns:
        pytree with the same structure as `v`, each leaf indexed by the drawn indices.
    """

    def _take(leaf):
        return leaf[drawn]

    # NOTE(review): an all-False mask is not expected to occur; the outcome in that case is
    # whatever `resample_indicies` yields for all -inf weights.
    mask_log_weights = jnp.where(select_mask, 0., -jnp.inf)
    drawn = resample_indicies(key, log_weights=mask_log_weights, S=num_samples, replace=True)
    drawn = jnp.squeeze(drawn) if squeeze else drawn
    return jax.tree.map(_take, v)