# Source code for jaxns.framework.jaxify

import warnings
from typing import Callable

import jax
import numpy as np

from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.types import LikelihoodType

# Public API of this module: only the wrapper factory is exported.
__all__ = ["jaxify_likelihood"]


def jaxify_likelihood(log_likelihood: Callable[..., np.ndarray], vectorised: bool = False) -> LikelihoodType:
    """
    Wraps a non-JAX log likelihood function.

    The returned function evaluates `log_likelihood` on the host via `jax.pure_callback`, so it can be used inside
    jitted / vmapped JAX code.

    Args:
        log_likelihood: a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar
            log-likelihood.
        vectorised: if True then the `log_likelihood` must handle batched inputs, i.e. each input will receive a
            common set of batched dimensions which the function must handle. If False, the function is called once
            per batch element when the returned wrapper is vmapped.

    Returns:
        A JAX-compatible log-likelihood function.
    """
    warnings.warn(
        "You're using a non-JAX log-likelihood function. This may be slower than a JAX log-likelihood function. "
        "Also, you are responsible for ensuring that the function is deterministic. "
        "Also, you cannot use learnable parameters in the likelihood call.",
        # Attribute the warning to the code that called jaxify_likelihood, not to this module.
        stacklevel=2,
    )

    def _cond_cast(x):
        # Array leaves (JAX or NumPy) become plain NumPy arrays so the user's function never receives a jax.Array;
        # any other leaf is passed through unchanged.
        if isinstance(x, (jax.Array, np.ndarray)):
            return np.asarray(x)
        return x

    def _casted_log_likelihood(*args) -> np.ndarray:
        # Host-side callback: convert every array in the argument pytree to NumPy, evaluate the user's function,
        # then cast the result via the mixed-precision policy's measure cast.
        args = jax.tree.map(_cond_cast, args)
        return mp_policy.cast_to_measure(log_likelihood(*args))

    def _log_likelihood(*args) -> jax.Array:
        # Each un-batched call yields a scalar in the policy's measure dtype; under vmap, JAX adds the batch dims.
        result_shape_dtype = jax.ShapeDtypeStruct(
            shape=(),
            dtype=mp_policy.measure_dtype
        )
        # 'broadcast_all': one callback invocation receives all inputs broadcast to a common batch shape.
        # 'sequential': the callback is invoked once per batch element. vmap_method=None is not supported under
        # vmap in current JAX releases, so a non-vectorised likelihood must explicitly opt in to looping.
        vmap_method = 'broadcast_all' if vectorised else 'sequential'
        return jax.pure_callback(_casted_log_likelihood, result_shape_dtype, *args, vmap_method=vmap_method)

    return _log_likelihood