jaxify

jaxns.framework.jaxify

Module Contents

jaxify_likelihood(log_likelihood, vectorised=False)

Wraps a non-JAX log-likelihood function so that it can be used as a JAX-compatible likelihood.

Parameters:
  • log_likelihood (Callable[Ellipsis, numpy.ndarray]) – a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar log-likelihood.

  • vectorised (bool) – if True, log_likelihood must accept batched inputs: every input is given the same set of batch dimensions, and the function must handle them, returning one log-likelihood per batch element.
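One way to write a log_likelihood that works with vectorised=True is to use NumPy broadcasting so that the function reduces only over its own data axis and leaves any batch dimensions untouched. The sketch below assumes the batch dimensions are leading axes of each input; DATA and the unit-variance Gaussian model are illustrative choices, not part of jaxns.

```python
import numpy as np

# Hypothetical observed data (illustration only).
DATA = np.array([0.9, 1.1, 1.3])


def log_likelihood(mu):
    # mu has shape batch_shape (empty for an unbatched call).
    mu = np.asarray(mu)
    # Compare every mu with every data point: shape batch_shape + (3,).
    residuals = DATA - mu[..., None]
    # Sum over the data axis only, so the batch shape is preserved.
    return -0.5 * np.sum(residuals ** 2, axis=-1)


single = log_likelihood(1.0)                          # shape ()
batched = log_likelihood(np.array([0.0, 1.0, 2.0]))   # shape (3,)
```

The same function serves both cases: an unbatched call yields a scalar, and a batch of three parameter values yields three log-likelihoods, the middle one equal to the unbatched result.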

Returns:

A JAX-compatible log-likelihood function.

Return type:

jaxns.internals.types.LikelihoodType
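This page does not say how the wrapping is performed. A common way to call host-side NumPy code from traced JAX code is jax.pure_callback, and the sketch below shows that idea for the non-vectorised case under that assumption. wrap_with_pure_callback is a hypothetical helper, not the jaxns implementation; the float32 output dtype is also an assumption made so the declared result matches JAX's default precision.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_log_likelihood(x, mu):
    # Plain NumPy: unit-variance Gaussian log-density of x around mu.
    return -0.5 * (x - mu) ** 2 - 0.5 * np.log(2.0 * np.pi)


def wrap_with_pure_callback(log_likelihood):
    """Hypothetical helper: expose a NumPy log-likelihood to JAX."""

    def _casted(*args):
        # Runs on the host with NumPy arrays; cast to the declared dtype.
        return np.asarray(log_likelihood(*args), dtype=np.float32)

    def _jax_log_likelihood(*args):
        # Declare the expected result: a float32 scalar.
        out_spec = jax.ShapeDtypeStruct((), jnp.float32)
        return jax.pure_callback(_casted, out_spec, *args)

    return _jax_log_likelihood


jax_log_likelihood = wrap_with_pure_callback(numpy_log_likelihood)
# The wrapped function can now be traced, e.g. under jax.jit.
value = jax.jit(jax_log_likelihood)(jnp.float32(1.0), jnp.float32(0.0))
print(float(value))
```

With jaxns itself, the equivalent step is simply jaxify_likelihood(numpy_log_likelihood), passing vectorised=True only if the function handles batched inputs.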