Introduction to JAXNS

In this introduction we show how to use JAXNS's probabilistic programming framework: defining a model, evaluating its components, running nested sampling, and post-processing the results.

[1]:
import jax
import tensorflow_probability.substrates.jax as tfp

tfpd = tfp.distributions

from jaxns.framework.model import Model
from jaxns.framework.prior import Prior


def prior_model():
    mu = yield Prior(tfpd.Normal(loc=0., scale=1.))
    # Let's make sigma a parametrised variable.
    # It requires a name, but will not be collected as a Bayesian variable.
    sigma = yield Prior(tfpd.Exponential(rate=1.), name='sigma').parametrised()
    x = yield Prior(tfpd.Cauchy(loc=mu, scale=sigma), name='x')
    uncert = yield Prior(tfpd.Exponential(rate=1.), name='uncert')
    return x, uncert


def log_likelihood(x, uncert):
    return tfpd.Normal(loc=0., scale=uncert).log_prob(x)


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

# You can sanity-check the model (always a good idea when exploring)
model.sanity_check(key=jax.random.PRNGKey(0), S=100)
INFO[2024-01-08 17:01:46,975]: Sanity check...
INFO[2024-01-08 17:01:47,527]: Sanity check passed
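After the sanity check, it can be useful to inspect the model's dimensionality. A minimal sketch, assuming the U_ndims attribute of the JAXNS Model API:

[ ]:
# Number of Bayesian (sampled) dimensions in U-space: mu, x and uncert -> 3.
# sigma is parametrised, so it is not counted here.
print(model.U_ndims)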
[2]:
# Sample the prior in U-space (base measure)
U = model.sample_U(key=jax.random.PRNGKey(0))
# Transform to X-space
X = model.transform(U=U)
# Only named Bayesian prior variables are returned; the rest are treated as hidden variables.
assert set(X.keys()) == {'x', 'uncert'}

# Get the return value of the prior model, i.e. the input to the likelihood
x_sample, uncert_sample = model.prepare_input(U=U)
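Since sample_U and transform are pure functions of their inputs, batched prior draws can be obtained with jax.vmap. A minimal sketch (the batch size of 100 is arbitrary):

[ ]:
# Draw a batch of prior samples by mapping over independent keys.
keys = jax.random.split(jax.random.PRNGKey(1), 100)
U_batch = jax.vmap(lambda key: model.sample_U(key=key))(keys)
X_batch = jax.vmap(lambda u: model.transform(U=u))(U_batch)
# Each entry of X_batch now has a leading batch dimension of 100.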
[3]:
# Evaluate different parts of the model
log_prob_prior = model.log_prob_prior(U)
log_prob_likelihood = model.log_prob_likelihood(U, allow_nan=False)
log_prob_joint = model.log_prob_joint(U, allow_nan=False)
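The log-joint should decompose as log-prior plus log-likelihood at the same point, which makes for a cheap consistency check:

[ ]:
import jax.numpy as jnp

# The joint should equal prior + likelihood evaluated at the same U.
assert jnp.isclose(log_prob_joint, log_prob_prior + log_prob_likelihood)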
[4]:
init_params = model.params


def log_prob_joint_fn(params, U):
    # Calling model with params returns a new model with the params set
    return model(params).log_prob_joint(U, allow_nan=False)


value, grad = jax.value_and_grad(log_prob_joint_fn)(init_params, U)
print(value, grad)
-8.101052 {'~': {'sigma_param': Array(-11.244353, dtype=float32)}}
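Since the log-joint is differentiable in the parametrised variables, they can be fitted by simple gradient ascent. A minimal sketch using raw JAX; a real workflow would more likely use an optimiser library such as optax, and the step size here is purely illustrative:

[ ]:
# Gradient-ascent sketch on the parametrised variables (here just sigma).
params = init_params
for _ in range(100):
    value, grads = jax.value_and_grad(log_prob_joint_fn)(params, U)
    params = jax.tree_util.tree_map(lambda p, g: p + 1e-2 * g, params, grads)
print(value, params)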
[5]:
from jaxns import DefaultNestedSampler

ns = DefaultNestedSampler(model=model, max_samples=1e5)

# Run the sampler
termination_reason, state = ns(jax.random.PRNGKey(42))
# Get the results
results = ns.to_results(termination_reason=termination_reason, state=state)
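The sampler call is a pure function of the random key, so for repeated runs it can be jit-compiled. A sketch (the first call pays the compilation cost):

[ ]:
# Compile the nested sampling run once; subsequent calls with new keys are fast.
run = jax.jit(ns)
termination_reason, state = run(jax.random.PRNGKey(43))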
[6]:
from jaxns import summary, plot_diagnostics, plot_cornerplot

summary(results)
plot_diagnostics(results, save_name='intro_diagnostics.png')
plot_cornerplot(results, save_name='intro_cornerplot.png')
--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 149918
samples: 3780
phantom samples: 1710
likelihood evals / sample: 39.7
phantom fraction (%): 45.2%
--------
logZ=-1.65 +- 0.15
H=-1.13
ESS=132
--------
uncert: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
uncert: 0.68 +- 0.58 | 0.13 / 0.48 / 1.37 | 0.0 | 0.0
--------
x: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x: 0.07 +- 0.62 | -0.57 / 0.06 / 0.73 | 0.0 | 0.0
--------
[Figure: nested sampling diagnostic plots, saved to intro_diagnostics.png]
[Figure: posterior corner plot, saved to intro_cornerplot.png]
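The quantities printed by summary are also available programmatically; the field names below (log_Z_mean, log_Z_uncert, ESS) are assumed from the JAXNS results structure:

[ ]:
# Programmatic access to the evidence estimate and effective sample size.
print(results.log_Z_mean, results.log_Z_uncert)  # logZ estimate and its uncertainty
print(results.ESS)  # effective sample size of the weighted posterior samples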
[7]:
from jaxns import resample

samples = resample(
    key=jax.random.PRNGKey(0),
    samples=results.samples,
    log_weights=results.log_dp_mean,
    S=1000,
    replace=True
)
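The resampled draws are equally weighted, so posterior expectations reduce to plain sample averages, e.g.:

[ ]:
import jax.numpy as jnp

# Posterior mean and standard deviation from the equally weighted draws.
print(jnp.mean(samples['x']), jnp.std(samples['x']))
print(jnp.mean(samples['uncert']), jnp.std(samples['uncert']))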