jaxify
jaxns.framework.jaxify
Module Contents
- jaxify_likelihood(log_likelihood, vectorised=False)
Wraps a non-JAX log-likelihood function into a JAX-compatible one.
- Parameters:
log_likelihood (Callable[..., numpy.ndarray]) – a non-JAX log-likelihood function that accepts a number of arguments and returns a scalar log-likelihood.
vectorised (bool) – if True, log_likelihood must handle batched inputs, i.e. every input arrives with a common set of batch dimensions that the function must process. Defaults to False.
- Returns:
A JAX-compatible log-likelihood function.
- Return type:
jaxns.internals.types.LikelihoodType