log_semiring

jaxns.internals.log_semiring

Module Contents

logaddexp(x1, x2)[source]

Equivalent to np.logaddexp, but also supports complex arguments.

See np.logaddexp.
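As an illustration of why complex support matters, a negative number can be carried as a complex logarithm, since log(-1) = iπ. The following is a minimal pure-Python sketch of that idea, not the jaxns implementation; the name `logaddexp_sketch` is hypothetical.

```python
import cmath
import math


def logaddexp_sketch(x1, x2):
    """Illustrative log(exp(x1) + exp(x2)) for real or complex scalars."""
    # Shift by the input with the larger real part so neither exp overflows.
    m = x1 if complex(x1).real >= complex(x2).real else x2
    if complex(m).real == -math.inf:
        # Both terms are exp(-inf) = 0, so the sum is 0 and its log is -inf.
        return complex(-math.inf, 0.0)
    return m + cmath.log(cmath.exp(x1 - m) + cmath.exp(x2 - m))


# log(3) is real, log(-1) is i*pi; adding them in log space gives log(3 - 1).
result = logaddexp_sketch(cmath.log(3.0), cmath.log(-1.0))
```

The real part of `result` is log(2), and the imaginary part is zero up to rounding because the sum is positive.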

signed_logaddexp(log_abs_val1, sign1, log_abs_val2, sign2)[source]

Equivalent to logaddexp, but for signed quantities. Broadcasting is supported.

Parameters:
  • log_abs_val1 – Logarithm of absolute value of val1, \(\log(|x_1|)\)

  • sign1 – Sign of val1, \(\mathrm{sign}(x_1)\)

  • log_abs_val2 – Logarithm of absolute value of val2, \(\log(|x_2|)\)

  • sign2 – Sign of val2, \(\mathrm{sign}(x_2)\)

Returns:

(\(\log(|x_1+x_2|)\), \(\mathrm{sign}(x_1+x_2)\))
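The arithmetic behind this return value can be sketched for scalars as follows. This is an assumption-laden illustration in pure Python, not the library's code: the name `signed_logaddexp_sketch` is hypothetical, and returning a sign of 0 for an exactly zero sum is an assumed convention.

```python
import math


def signed_logaddexp_sketch(log_abs1, sign1, log_abs2, sign2):
    """Return (log|x1 + x2|, sign(x1 + x2)) for scalar inputs."""
    # Put the term with the larger magnitude first; it decides the sign.
    if log_abs1 < log_abs2:
        log_abs1, sign1, log_abs2, sign2 = log_abs2, sign2, log_abs1, sign1
    if log_abs1 == -math.inf:
        # Both magnitudes are zero (assumed convention: sign 0 for a zero sum).
        return -math.inf, 0.0
    # |x2| / |x1| lies in [0, 1]; it is negated when the signs differ.
    ratio = sign1 * sign2 * math.exp(log_abs2 - log_abs1)
    if ratio == -1.0:
        # Exact cancellation: x1 + x2 = 0.
        return -math.inf, 0.0
    return log_abs1 + math.log1p(ratio), float(sign1)


# 3 + (-5) = -2, so the result is (log(2), -1).
log_abs, sign = signed_logaddexp_sketch(math.log(3.0), 1.0, math.log(5.0), -1.0)
```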

cumulative_logsumexp(u, sign=None, reverse=False, axis=0)[source]
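The page gives only the signature here. As a rough illustration of what a cumulative log-sum-exp computes, here is a pure-Python sketch for a 1-D sequence of unsigned log-values. The name `cumulative_logsumexp_sketch` is hypothetical, the `sign` and `axis` arguments are not modelled, and the meaning given to `reverse` (accumulating from the last element backwards) is an assumption.

```python
import math
from itertools import accumulate


def _logaddexp(a, b):
    # Stable log(exp(a) + exp(b)) for real scalars.
    hi, lo = (a, b) if a >= b else (b, a)
    if hi == -math.inf:
        return -math.inf
    return hi + math.log1p(math.exp(lo - hi))


def cumulative_logsumexp_sketch(u, reverse=False):
    """Running log-sum-exp over a 1-D sequence of log-values."""
    values = list(u)
    if reverse:
        # Assumed meaning of reverse: entry i holds the log-sum of values[i:].
        return list(accumulate(reversed(values), _logaddexp))[::-1]
    return list(accumulate(values, _logaddexp))


# Linear values 1, 2, 3 have running sums 1, 3, 6.
forward = cumulative_logsumexp_sketch([math.log(1.0), math.log(2.0), math.log(3.0)])
```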
class LogSpace(log_abs_val, sign=None)[source]

Bases: object

Parameters:
  • log_abs_val (Union[jax.numpy.ndarray, float]) – Logarithm of the absolute value

  • sign (Union[jax.numpy.ndarray, float]) – Sign of the value; optional, defaults to None

property dtype[source]
property log_abs_val[source]
property sign[source]
property value[source]
property signed_log[source]
property size[source]
__neg__()[source]
__add__(other)[source]

Implements addition in log space

log(exp(log_A) + exp(log_B))

Parameters:

other – ndarray or LogSpace; if an ndarray, it is assumed to be log(B)

Returns:

LogSpace
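The reason for adding in log space is that very small quantities underflow in linear space. The following pure-Python sketch shows the effect with hypothetical values for log_A and log_B; it demonstrates the identity above and does not call jaxns.

```python
import math

log_a, log_b = -1000.0, -1001.0

# In linear space both terms underflow to 0.0, so the sum is lost.
naive_sum = math.exp(log_a) + math.exp(log_b)

# In log space, factor out the larger term:
# log(A + B) = log_A + log1p(exp(log_B - log_A)).
log_sum = log_a + math.log1p(math.exp(log_b - log_a))
```

Here `naive_sum` is exactly 0.0, while `log_sum` stays finite at roughly -999.69.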

__sub__(other)[source]

Implements subtraction in log space

log(exp(log_A) - exp(log_B))

Parameters:

other – ndarray or LogSpace; if an ndarray, it is assumed to be log(B)

Returns:

LogSpace
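A difference can be negative, which a plain logarithm cannot represent; this is presumably why the class carries a sign alongside log_abs_val. Below is a minimal scalar sketch under that assumption. The name `log_sub_sketch` is hypothetical and both inputs are taken to be logs of positive numbers.

```python
import math


def log_sub_sketch(log_a, log_b):
    """Return (log|A - B|, sign(A - B)) for A = exp(log_a), B = exp(log_b)."""
    if log_a == log_b:
        # A == B, so the difference is exactly zero.
        return -math.inf, 0.0
    sign = 1.0 if log_a > log_b else -1.0
    hi, lo = max(log_a, log_b), min(log_a, log_b)
    # log(exp(hi) - exp(lo)) = hi + log(1 - exp(lo - hi)); expm1 keeps
    # precision when lo - hi is close to zero.
    return hi + math.log(-math.expm1(lo - hi)), sign


# 2 - 5 = -3, so the result is (log(3), -1).
log_abs_diff, diff_sign = log_sub_sketch(math.log(2.0), math.log(5.0))
```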

__mul__(other)[source]

Implements multiplication in log space

log(exp(log_A) * exp(log_B))

Parameters:

other – ndarray or LogSpace; if an ndarray, it is assumed to be log(B)

Returns:

LogSpace
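In log space a product becomes a sum of logarithms, and for signed quantities the signs multiply. The scalar sketch below illustrates this; `log_mul_sketch` is a hypothetical helper, not part of jaxns.

```python
import math


def log_mul_sketch(log_abs_a, sign_a, log_abs_b, sign_b):
    # |A * B| = |A| * |B| becomes a sum of logs; the signs multiply.
    return log_abs_a + log_abs_b, sign_a * sign_b


# (-2e-200) * (3e-200) underflows to zero in linear space...
linear_product = (-2e-200) * (3e-200)

# ...but its magnitude and sign survive in log space.
log_abs_prod, prod_sign = log_mul_sketch(math.log(2e-200), -1.0, math.log(3e-200), 1.0)
```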

__repr__()[source]

Return repr(self).

sum(axis=-1, keepdims=False)[source]
nansum(axis=-1, keepdims=False)[source]
cumsum(axis=0, reverse=False)[source]
cumprod(axis=0)[source]
mean(axis=-1, keepdims=False)[source]
var(axis=-1, keepdims=False)[source]
log()[source]
exp()[source]
sqrt()[source]
abs()[source]
diff()[source]
square()[source]
argmax()[source]
maximum(other)[source]
Parameters:

other (LogSpace) –

minimum(other)[source]
Parameters:

other (LogSpace) –

max()[source]
min()[source]
concatenate(other, axis=0)[source]
Parameters:

other (LogSpace) –

__getitem__(item)[source]
__gt__(other)[source]

Return self>value.

__lt__(other)[source]

Return self<value.

__ge__(other)[source]

Return self>=value.

__le__(other)[source]

Return self<=value.

__pow__(n)[source]

Implements power in log space

log(exp(log_A)**n)

Parameters:

n – int or float

Returns:

LogSpace
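The identity above reduces to a single multiplication, log(A**n) = n * log_A. The short sketch below uses hypothetical values and a positive base only; how the sign is handled for negative bases is not covered here.

```python
import math

log_a = -400.0   # A = exp(-400), still representable as a float
n = 2.5

# Raising to a power scales the logarithm.
log_a_pow_n = n * log_a

# Converting back shows why the result is kept as a log: exp(-1000) underflows.
linear_value = math.exp(log_a_pow_n)
```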

__truediv__(other)[source]

Implements division in log space

log(exp(log_A) / exp(log_B))

Parameters:

other – ndarray or LogSpace; if an ndarray, it is assumed to be log(B)

Returns:

LogSpace
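Division is where log space helps most visibly: a ratio of two values that both underflow is undefined in linear space but well defined as a difference of logs. The sketch below uses hypothetical values for log_A and log_B and plain Python floats.

```python
import math

log_a, log_b = -1200.0, -1203.0

# exp(-1200) and exp(-1203) are both 0.0 as floats, so A / B would be 0.0 / 0.0.
# In log space the ratio is a subtraction: log(A / B) = log_A - log_B.
log_ratio = log_a - log_b

# The ratio itself is an ordinary number, e**3.
ratio = math.exp(log_ratio)
```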

is_complex(a)[source]
normalise_log_space(x, norm_type='sum')[source]

Safely normalise a LogSpace, accounting for zero-sum.

Parameters:
  • x (LogSpace) – LogSpace to normalise

  • norm_type (Literal[sum, max]) – ‘sum’ or ‘max’ normalisation

Returns:

normalised LogSpace

Return type:

LogSpace
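The two normalisation types and the zero-sum guard can be sketched on a plain list of log-values as follows. This is an illustration only: `normalise_log_values_sketch` is a hypothetical name, and returning the input unchanged in the zero-sum case is an assumption about what "safely" means, not confirmed behaviour of normalise_log_space.

```python
import math


def normalise_log_values_sketch(log_values, norm_type="sum"):
    """Normalise log-values so they sum (or peak) at 1 in linear space."""
    peak = max(log_values)
    if peak == -math.inf:
        # Zero-sum case: every entry is exp(-inf) = 0, so dividing gives 0/0.
        # Assumed handling for this sketch: return the input unchanged.
        return list(log_values)
    if norm_type == "max":
        log_norm = peak
    elif norm_type == "sum":
        # Stable log-sum-exp, shifted by the peak.
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in log_values))
    else:
        raise ValueError(f"unknown norm_type: {norm_type!r}")
    # Division in linear space is subtraction in log space.
    return [v - log_norm for v in log_values]


by_sum = normalise_log_values_sketch([-1000.0, -1001.0, -1002.0], "sum")
by_max = normalise_log_values_sketch([-1000.0, -1001.0, -1002.0], "max")
```

After 'sum' normalisation the linear values add up to 1; after 'max' normalisation the largest log-value is 0, so the largest linear value is 1.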