Dual Moons likelihood

The dual-moons likelihood is a tricky two-mode test problem that mimics physically challenging models.

[1]:

import os

# Set before JAX is imported so that XLA exposes the host CPU as 6 devices
# (the sampler later reports "Running over 6 devices.").
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=6"

from jax import config

# Use 64-bit floating point throughout.
config.update("jax_enable_x64", True)

# `matplotlib.pyplot` replaces the discouraged `pylab` bulk-import module;
# every `plt.*` call in this notebook exists in both.
import matplotlib.pyplot as plt
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp, random, vmap
from jaxns import Model, NestedSampler, Prior, bruteforce_evidence

tfpd = tfp.distributions
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[2]:
from jax.scipy.special import logsumexp


def log_likelihood(x):
    """
    Dual moon log-likelihood.

    Both "moons" are arcs on the circle of radius 2 centred on the origin:
    one is centred on polar angle +pi/2, the other on polar angle -pi/3.
    Each moon has an energy made of a shared radial term,
    ``((rho - 2) / 0.05) ** 2``, plus an asymmetric angular term, and the
    result is ``log(exp(-E_1) + exp(-E_2))``.

    Args:
        x: array whose last axis holds the two coordinates; ``x[..., 0]`` and
            ``x[..., 1]`` are read.

    Returns:
        The log-likelihood with all singleton axes squeezed out (a scalar
        array for a single point of shape ``(2,)``).
    """

    def fasym(t):
        # Asymmetric angular penalty: broad quadratic for negative offsets,
        # steep quartic for non-negative offsets.
        return jnp.where(t < 0, (t / 2) ** 2, (t / 0.5) ** 4)

    xx = x[..., :1]
    xy = x[..., 1:2]

    # Polar coordinates are shared by both moons, so compute them once.
    rho = jnp.sqrt(xx ** 2 + xy ** 2)
    theta = jnp.arctan2(xy, xx)

    radial = ((rho - 2) / 0.05) ** 2
    moon1 = radial + fasym(theta - jnp.pi / 2) / 0.2
    moon2 = radial + fasym(theta + jnp.pi / 3) / 0.2

    pe = logsumexp(-jnp.stack([moon1, moon2]), axis=0)

    return pe.squeeze()


def prior_model():
    """Prior generator: `x` is two-dimensional and uniform between -4 and 4 in each coordinate."""
    lower = -4 * jnp.ones(2)
    upper = 4. * jnp.ones(2)
    x = yield Prior(tfpd.Uniform(low=lower, high=upper), name='x')
    return x


# Bundle the prior and the likelihood into the model used by every later cell.
model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

# Reference evidence from jaxns' brute-force routine (S=250), kept for
# comparison with the nested-sampling estimate reported further down.
log_Z_true = bruteforce_evidence(model=model, S=250)
print(f"True log(Z)={log_Z_true}")

True log(Z)=-5.104862238439268
[3]:

# Regular grid over the unit interval in each U-space dimension.
u_vec = jnp.linspace(0., 1., 250)
args = jnp.stack([x.flatten() for x in jnp.meshgrid(*[u_vec] * model.U_ndims, indexing='ij')], axis=-1)
# Map `model.forward` over every grid point. Its output is exponentiated for
# display below, i.e. it is treated as a log-likelihood.
# NOTE(review): the reshape assumes model.U_ndims == 2 - confirm if the prior changes.
lik = vmap(model.forward)(args).reshape((u_vec.size, u_vec.size))
# With indexing='ij' the first array axis follows the first coordinate, but
# imshow draws the first axis vertically. Transpose so that x[0] runs
# horizontally and x[1] vertically, matching the axis labels.
plt.imshow(jnp.exp(lik).T, origin='lower', extent=(-4, 4, -4, 4), cmap='jet')
plt.xlabel('x[0]')
plt.ylabel('x[1]')
plt.colorbar()
plt.show()
../_images/examples_dual_moon_3_0.png
[4]:


# Create the nested sampler class. In this case without any tuning.
ns = NestedSampler(model=model, max_samples=1e5)

# Run with a fixed PRNG key, then convert the final state into results.
sampler_key = random.PRNGKey(42)
termination_reason, state = ns(sampler_key)
results = ns.to_results(termination_reason=termination_reason, state=state)
Running over 6 devices.
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/internals/mixed_precision.py:61: UserWarning: Expected integer type, got bool, at public.py:173 in to_results -> sharded_static.py:475 in _to_results -> mixed_precision.py:105 in cast_to_count -> mixed_precision.py:67 in _cast_integer_to -> tree.py:61 in map -> tree_util.py:321 in tree_map -> tree_util.py:321 in <genexpr> -> mixed_precision.py:61 in conditional_cast.
  warnings.warn(f"Expected integer type, got {x.dtype}, {get_grandparent_info()}.")
[5]:
# Print a text summary of the run: termination condition, likelihood-evaluation
# and sample counts, the logZ estimate with its uncertainty, and per-parameter
# statistics for `x`.
ns.summary(results)
--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 49180
samples: 690
phantom samples: 0
likelihood evals / sample: 71.3
phantom fraction (%): 0.0%
--------
logZ=-4.66 +- 0.3
max(logL)=0.0
H=-3.91
ESS=69
--------
x[#]: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x[0]: 0.23 +- 0.7 | -0.54 / 0.22 / 1.13 | 1.03 | 1.03
x[1]: -0.0 +- 1.9 | -2.0 / -0.3 / 2.0 | -1.7 | -1.7
--------
[6]:
# Plot the run diagnostics first, then a corner plot of the sampled
# distribution; each call produces one figure.
ns.plot_diagnostics(results)
ns.plot_cornerplot(results)
../_images/examples_dual_moon_6_0.png
../_images/examples_dual_moon_6_1.png