Constant Likelihood

Managing plateaus in nested sampling is tricky because the enclosed prior volume is assumed to be monotonically decreasing. In this simple example, the likelihood is constant,

\(L(x) = P(y | x) = 1\)

and

\(P(x) = \mathcal{U}[x \mid 0, 1]\).

The analytic evidence for this model is,

\(Z = P(y) = \int_\mathcal{X} L(x) p(x) \,\mathrm{d} x = 1\)
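As a quick sanity check (not part of the jaxns example itself), a plain Monte Carlo estimate of this integral recovers \(Z = 1\) exactly, since the likelihood is constant over the prior:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=10_000)  # samples from the uniform prior P(x)
L = np.ones_like(x)                     # constant likelihood L(x) = 1
Z_mc = L.mean()                         # Monte Carlo estimate of Z = E_p[L(x)]
print(Z_mc)  # → 1.0
```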

[1]:

import tensorflow_probability.substrates.jax as tfp
from jax import random

from jaxns import Model, Prior, TerminationCondition

tfpd = tfp.distributions
[2]:

def log_likelihood(theta):
    # Constant likelihood: log L(x) = 0 everywhere
    return 0.


def prior_model():
    x = yield Prior(tfpd.Uniform(0., 1.), name='x')
    return x


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

log_Z_true = 0.
print(f"True log(Z)={log_Z_true}")
True log(Z)=0.0
[3]:
from jaxns import DefaultNestedSampler

# Create the nested sampler class. In this case without any tuning.
exact_ns = DefaultNestedSampler(model=model, max_samples=1e4)

termination_reason, state = exact_ns(random.PRNGKey(42))
results = exact_ns.to_results(termination_reason=termination_reason, state=state)

[4]:
# We can use the summary utility to display results
exact_ns.summary(results)
--------
Termination Conditions:
All live-points are on a single plateau (potential numerical errors, consider 64-bit)
--------
likelihood evals: 30
samples: 30
phantom samples: 0
likelihood evals / sample: 1.0
phantom fraction (%): 0.0%
--------
logZ=-0.05 +- 0.036
H=-0.05
ESS=15
--------
x: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x: 0.48 +- 0.26 | 0.16 / 0.52 / 0.79 | 0.71 | 0.71
--------
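The termination message above suggests switching to 64-bit precision to rule out numerical issues on the plateau. As a sketch (using JAX's standard configuration flag, set before any arrays are created), this is done with:

```python
import jax

# Enable double precision globally; JAX defaults to float32 otherwise.
jax.config.update("jax_enable_x64", True)

# Arrays created after this point default to float64.
print(jax.numpy.zeros(1).dtype)  # float64
```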
[5]:
# We plot useful diagnostics and a distribution cornerplot
exact_ns.plot_diagnostics(results)
exact_ns.plot_cornerplot(results)
[Figure: nested sampling diagnostics]
[Figure: posterior cornerplot]