Source code for jaxns.internals.log_semiring

from typing import Union, Literal

from jax import numpy as jnp, lax
from jax.scipy.special import logsumexp

from jaxns.internals.types import SignedLog, float_type


[docs] def logaddexp(x1, x2): """ Equivalent to logaddexp but supporting complex arguments. see np.logaddexp """ if is_complex(x1) or is_complex(x2): select1 = x1.real > x2.real amax = jnp.where(select1, x1, x2) delta = jnp.where(select1, x2 - x1, x1 - x2) return jnp.where(jnp.isnan(delta), x1 + x2, # NaNs or infinities of the same sign. amax + jnp.log1p(jnp.exp(delta))) else: return jnp.logaddexp(x1, x2)
[docs] def signed_logaddexp(log_abs_val1, sign1, log_abs_val2, sign2): r""" Equivalent of logaddexp but for signed quantities too. Broadcasting supported. Args: log_abs_val1: Logarithm of absolute value of val1, :math:`\log(|x_1|)` sign1: Sign of val1, :math:`\mathrm{sign}(x_1)` log_abs_val2: Logarithm of absolute value of val2, :math:`\log(|x_2|)` sign2: Sign of val2, :math:`\mathrm{sign}(x_2)` Returns: (:math:`\log(|x_1+x_2|)`, :math:`\mathrm{sign}(x_1+x_2)`) """ amax = jnp.maximum(log_abs_val1, log_abs_val2) signmax = jnp.where(log_abs_val1 > log_abs_val2, sign1, sign2) delta = -jnp.abs(log_abs_val2 - log_abs_val1) # nan iff inf - inf sign = sign1 * sign2 return jnp.where(jnp.isnan(delta), log_abs_val1 + log_abs_val2, # NaNs or infinities of the same sign. amax + jnp.log1p(sign * jnp.exp(delta))), signmax
[docs] def cumulative_logsumexp(u, sign=None, reverse=False, axis=0): if sign is not None: u, sign = jnp.broadcast_arrays(u, sign) def body(state, X): if sign is not None: (u, u_sign) = X (accumulant, accumulant_sign) = state new_accumulant, new_accumulant_sign = signed_logaddexp(accumulant, accumulant_sign, u, u_sign) return (new_accumulant, accumulant_sign), (new_accumulant, accumulant_sign) else: u = X accumulant = state new_accumulant = jnp.logaddexp(accumulant, u) return new_accumulant, new_accumulant if sign is not None: if axis != 0: sign = jnp.swapaxes(sign, axis, 0) u = jnp.swapaxes(u, axis, 0) state = (-jnp.inf * jnp.ones(u.shape[1:], dtype=u.dtype), jnp.ones(u.shape[1:], dtype=u.dtype)) X = (u, sign) else: if axis != 0: u = jnp.swapaxes(u, axis, 0) state = -jnp.inf * jnp.ones(u.shape[1:], dtype=u.dtype) X = u _, result = lax.scan(body, state, X, reverse=reverse) if sign is not None: v, v_sign = result if axis != 0: v = jnp.swapaxes(v, axis, 0) v_sign = jnp.swapaxes(v_sign, axis, 0) return v, v_sign else: v = result if axis != 0: v = jnp.swapaxes(v, axis, 0) return v
[docs] class LogSpace(object): def __init__(self, log_abs_val: Union[jnp.ndarray, float], sign: Union[jnp.ndarray, float] = None): self._log_abs_val = jnp.asarray(log_abs_val, float_type) if sign is None: self._sign = jnp.asarray(1., float_type) self._naked = True else: self._sign = jnp.asarray(sign, float_type) self._naked = False @property
[docs] def dtype(self): return self.log_abs_val.dtype
@property
[docs] def log_abs_val(self): return self._log_abs_val
@property
[docs] def sign(self): return self._sign
@property
[docs] def value(self): if self._naked: return jnp.exp(self.log_abs_val) return self.sign * jnp.exp(self.log_abs_val)
[docs] def __neg__(self): if self._naked: return LogSpace(self.log_abs_val, -jnp.ones_like(self.log_abs_val)) return LogSpace(self.log_abs_val, -self.sign)
[docs] def __add__(self, other): """ Implements addition in log space log(exp(log_A) + exp(log_B)) Args: other: ndarray or LogSpace, if ndarray assumed to be log(B) Returns: LogSpace """ if not isinstance(other, LogSpace): raise TypeError(f"Expected type {type(self)} got {type(other)}") if self._naked and other._naked: # no coefficients return LogSpace(jnp.logaddexp(self._log_abs_val, other._log_abs_val)) return LogSpace(*signed_logaddexp(self._log_abs_val, self._sign, other._log_abs_val, other._sign))
[docs] def __sub__(self, other): """ Implements addition in log space log(exp(log_A) - exp(log_B)) Args: other: ndarray or LogSpace, if ndarray assumed to be log(B) Returns: LogSpace """ if not isinstance(other, LogSpace): raise TypeError(f"Expected type {type(self)} got {type(other)}") return LogSpace(*signed_logaddexp(self._log_abs_val, self._sign, other._log_abs_val, -other._sign))
[docs] def __mul__(self, other): """ Implements addition in log space log(exp(log_A) * exp(log_B)) Args: other: ndarray or LogSpace, if ndarray assumed to be log(B) Returns: LogSpace """ if not isinstance(other, LogSpace): raise TypeError(f"Expected type {type(self)} got {type(other)}") if self._naked and other._naked: # no coefficients return LogSpace(self._log_abs_val + other._log_abs_val) return LogSpace(self._log_abs_val + other._log_abs_val, self._sign * other._sign)
[docs] def __repr__(self): if self._naked: return f"LogSpace({self.log_abs_val})" return f"LogSpace({self.log_abs_val}, {self.sign})"
[docs] def sum(self, axis=-1, keepdims=False): if not self._naked: # no coefficients return LogSpace(*logsumexp(self.log_abs_val, b=self.sign, axis=axis, keepdims=keepdims, return_sign=True)) return LogSpace(logsumexp(self._log_abs_val, axis=axis, keepdims=keepdims))
[docs] def nansum(self, axis=-1, keepdims=False): log_abs_val = jnp.where(jnp.isnan(self.log_abs_val), -jnp.inf, self.log_abs_val) if not self._naked: # no coefficients return LogSpace(*logsumexp(log_abs_val, b=self.sign, axis=axis, keepdims=keepdims, return_sign=True)) return LogSpace(logsumexp(log_abs_val, axis=axis, keepdims=keepdims))
[docs] def cumsum(self, axis=0, reverse=False): if not self._naked: # no coefficients return LogSpace(*cumulative_logsumexp(self.log_abs_val, sign=self.sign, axis=axis, reverse=reverse)) return LogSpace(cumulative_logsumexp(self._log_abs_val, axis=axis, reverse=reverse))
[docs] def cumprod(self, axis=0): if not self._naked: # no coefficients log_abs_val, sign = jnp.broadcast_arrays(self.log_abs_val, self.sign) return LogSpace(jnp.cumsum(log_abs_val, axis=axis), jnp.cumprod(sign, axis=axis)) return LogSpace(jnp.cumsum(self._log_abs_val, axis=axis))
[docs] def mean(self, axis=-1, keepdims=False): N = self._log_abs_val.shape[axis] return self.sum(axis=axis, keepdims=keepdims) / LogSpace(jnp.log(N))
[docs] def var(self, axis=-1, keepdims=False): return (self - self.mean(axis=axis, keepdims=True)).mean(axis=axis, keepdims=keepdims)
[docs] def log(self): assert self._naked return LogSpace(jnp.log(jnp.abs(self.log_abs_val)), jnp.sign(self.log_abs_val))
[docs] def exp(self): return LogSpace(self.value)
[docs] def sqrt(self): return self ** 0.5
[docs] def abs(self): return LogSpace(self.log_abs_val)
[docs] def diff(self): if self._naked: log_abs_val, sign = jnp.broadcast_arrays(self.log_abs_val, self.sign) return LogSpace(log_abs_val[1:], sign[1:]) - LogSpace(log_abs_val[:-1], sign[:-1]) else: return LogSpace(self.log_abs_val[1:]) - LogSpace(self.log_abs_val[:-1])
[docs] def square(self): return self * self
[docs] def argmax(self): return jnp.argmax(self.log_abs_val)
[docs] def maximum(self, other: "LogSpace"): assert self._naked and other._naked return LogSpace(jnp.maximum(self.log_abs_val, other.log_abs_val))
[docs] def minimum(self, other: "LogSpace"): assert self._naked and other._naked return LogSpace(jnp.minimum(self.log_abs_val, other.log_abs_val))
[docs] def max(self): assert self._naked return LogSpace(jnp.max(self.log_abs_val))
[docs] def min(self): assert self._naked return LogSpace(jnp.min(self.log_abs_val))
[docs] def concatenate(self, other: "LogSpace", axis=0): if self._naked and other._naked: return LogSpace(jnp.concatenate([self.log_abs_val, other.log_abs_val], axis=axis)) log_abs_val, sign = jnp.broadcast_arrays(self.log_abs_val, self.sign) _log_abs_val, _sign = jnp.broadcast_arrays(other.log_abs_val, other.sign) return LogSpace(jnp.concatenate([log_abs_val, _log_abs_val], axis=axis), jnp.concatenate([sign, _sign], axis=axis))
[docs] def __getitem__(self, item): if self._naked: return LogSpace(self.log_abs_val[item]) log_abs_val, sign = jnp.broadcast_arrays(self.log_abs_val, self.sign) return LogSpace(log_abs_val[item], sign[item])
@property
[docs] def signed_log(self): return SignedLog(self.log_abs_val, self.sign)
[docs] def __gt__(self, other): if not isinstance(other, LogSpace): raise TypeError(f"Expected type {type(self)} got {type(other)}") if self._naked and other._naked: return self.log_abs_val > other.log_abs_val return (self / other).value > 1.
[docs] def __lt__(self, other): if not isinstance(other, LogSpace): raise TypeError(f"Expected type {type(self)} got {type(other)}") if self._naked and other._naked: return self.log_abs_val < other.log_abs_val return (self / other).value < 1.
[docs] def __ge__(self, other): if not isinstance(other, LogSpace): raise TypeError(f"Expected type {type(self)} got {type(other)}") if self._naked and other._naked: return self.log_abs_val >= other.log_abs_val return (self / other).value >= 1.
[docs] def __le__(self, other): if not isinstance(other, LogSpace): raise TypeError(f"Expected type {type(self)} got {type(other)}") if self._naked and other._naked: return self.log_abs_val <= other.log_abs_val return (self / other).value <= 1.
@property
[docs] def size(self): if self._naked: return self.log_abs_val.size log_abs_val, sign = jnp.broadcast_arrays(self.log_abs_val, self.sign) return log_abs_val.size
[docs] def __pow__(self, n): """ Implements power in log space log(exp(log_A)**n) Args: n: int or float Returns: LogSpace """ if not isinstance(n, (int, float, jnp.ndarray)): raise NotImplementedError("Not implemented for non-int powers.") n = jnp.asarray(n, float_type) if self._naked: return LogSpace(n * self.log_abs_val) # complex values can occur if n is not even return LogSpace(n * self.log_abs_val, sign=self.sign ** n)
# for _ in range(n-1): # output = output * self # return output
[docs] def __truediv__(self, other): """ Implements addition in log space log(exp(log_A) / exp(log_B)) Args: other: ndarray or LogSpace, if ndarray assumed to be log(B) Returns: LogSpace """ if not isinstance(other, LogSpace): raise TypeError(f"Expected type {type(self)} got {type(other)}") if self._naked and other._naked: # no coefficients return LogSpace(self._log_abs_val - other._log_abs_val) return LogSpace(self._log_abs_val - other._log_abs_val, self._sign * other._sign)
[docs] def is_complex(a): return a.dtype in [jnp.complex64, jnp.complex128]
[docs] def normalise_log_space(x: LogSpace, norm_type: Literal['sum', 'max'] = 'sum') -> LogSpace: """ Safely normalise a LogSpace, accounting for zero-sum. Args: x: LogSpace to normalise norm_type: 'sum' or 'max' normalisation Returns: normalised LogSpace """ if norm_type == 'sum': norm = x.sum() elif norm_type == 'max': norm = x.max() else: raise ValueError(f"Unknown norm_type {norm_type}") x /= norm x = LogSpace(jnp.where(jnp.isneginf(norm.log_abs_val), -jnp.inf, x.log_abs_val)) return x