# Source code for jaxns.internals.constraint_bijections

import jax
import jax.numpy as jnp
from jax import lax

# Names exported by ``from <module> import *``: the forward map and its inverse.
__all__ = ["quick_unit", "quick_unit_inverse"]


def quick_unit(x: jax.Array) -> jax.Array:
    """
    Quick approximation to the sigmoid.

    Mathematically this is ``0.5 * (x / (1 + |x|) + 1)``. It is evaluated in the
    equivalent piecewise form

        x <  0:  0.5 / (1 + |x|)
        x >= 0:  1 - 0.5 / (1 + |x|)

    The naive form cancels ``-1 + 1`` for large negative ``x`` and rounds to
    exactly 0 (e.g. ``x = -1e8`` in float32). The piecewise form avoids that
    cancellation. It also yields the limits 0 and 1 at ``-inf``/``+inf``
    instead of NaN.

    Args:
        x: jax.Array value in (-inf, inf) open interval

    Returns:
        value in (0, 1) in open interval (for very large positive ``x`` the
        result may still round to 1 in finite precision)
    """
    # |x| keeps this finite for every x. The branch that jnp.where does not
    # select therefore never produces inf/NaN that could leak into gradients.
    tail = 0.5 / (1 + lax.abs(x))
    return jnp.where(x < 0, tail, 1 - tail)
def quick_unit_inverse(y: jax.Array) -> jax.Array:
    """
    Inverse of quick_unit.

    Args:
        y: jax.Array value in (0, 1) open interval

    Returns:
        value in (-inf, inf) in open interval. The endpoints map to the
        matching limits: ``y == 0`` gives ``-inf`` and ``y == 1`` gives ``+inf``.
    """
    twoy = y + y
    return jnp.where(
        y >= 0.5,
        # Upper half: y = 1 - 0.5 / (1 + x)  =>  x = (2y - 1) / (2 - 2y).
        # The denominator is kept non-negative so that y == 1 (e.g. a forward
        # value that rounded to 1.0 for large x) gives +inf, not -inf.
        (twoy - 1) / (2 - twoy),
        # Lower half: y = 0.5 / (1 - x)  =>  x = 1 - 1 / (2y).
        1 - lax.reciprocal(twoy),
    )