[1]:
import jax
import tensorflow_probability.substrates.jax as tfp
tfpd = tfp.distributions
from jaxns.framework.model import Model
from jaxns.framework.prior import Prior
def prior_model():
    """Generative prior for the introductory model.

    Written as a generator: each ``yield`` hands a ``Prior`` to jaxns and the
    corresponding value is sent back into the generator.

    Variables:
        mu     -- unnamed, so it is a hidden variable (not part of the X-space
                  dict returned by ``model.transform``).
        sigma  -- parametrised: it needs a name but is not collected as a
                  Bayesian variable; its parameter appears as ``sigma_param``
                  in ``model.params``.
        x      -- Cauchy(loc=mu, scale=sigma), a named Bayesian variable.
        uncert -- Exponential(rate=1), a named Bayesian variable.

    Returns:
        Tuple ``(x, uncert)``; this return value is the input to
        ``log_likelihood``.
    """
    mu = yield Prior(tfpd.Normal(loc=0., scale=1.))

    # Parametrised priors require a name, but are not Bayesian variables.
    sigma_prior = Prior(tfpd.Exponential(rate=1.), name='sigma')
    sigma = yield sigma_prior.parametrised()

    x = yield Prior(tfpd.Cauchy(loc=mu, scale=sigma), name='x')
    uncert = yield Prior(tfpd.Exponential(rate=1.), name='uncert')
    return x, uncert
def log_likelihood(x, uncert):
    """Log-density of ``x`` under a zero-mean Normal with scale ``uncert``.

    Args:
        x: value of the ``x`` prior variable.
        uncert: value of the ``uncert`` prior variable, used as the Normal scale.

    Returns:
        ``log N(x | loc=0, scale=uncert)``.
    """
    likelihood_dist = tfpd.Normal(loc=0., scale=uncert)
    return likelihood_dist.log_prob(x)
# Combine the generative prior and the likelihood into a single jaxns Model.
model = Model(
    prior_model=prior_model,
    log_likelihood=log_likelihood,
)
# Sanity-check the model before using it (always a good idea when exploring).
model.sanity_check(S=100, key=jax.random.PRNGKey(0))
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/internals/mixed_precision.py:14: UserWarning: JAX x64 is not enabled. Setting it now. Check for errors.
warnings.warn("JAX x64 is not enabled. Setting it now. Check for errors.")
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
INFO:jaxns:Sanity check...
INFO:jaxns:Sanity check passed
[2]:
# Draw a point from the base measure (U-space), then map it into X-space.
U = model.sample_U(key=jax.random.PRNGKey(0))
X = model.transform(U=U)

# X only holds the named Bayesian prior variables: the unnamed `mu` is hidden
# and the parametrised `sigma` is not a Bayesian variable.
assert set(X) == {'x', 'uncert'}

# The prior model's return value at this U, i.e. the inputs to the likelihood.
x_sample, uncert_sample = model.prepare_input(U=U)
[3]:
# Evaluate the individual model terms at the U-space point drawn above.
log_prob_prior = model.log_prob_prior(U)
# The likelihood-dependent terms take allow_nan explicitly.
log_prob_joint = model.log_prob_joint(U, allow_nan=False)
log_prob_likelihood = model.log_prob_likelihood(U, allow_nan=False)
[4]:
init_params = model.params

def log_prob_joint_fn(params, U):
    """Joint log-probability at ``U`` for the model with ``params`` substituted.

    Calling ``model(params)`` returns a new model with the params set, which
    lets ``jax.value_and_grad`` below differentiate with respect to ``params``.
    """
    parametrised_model = model(params)
    return parametrised_model.log_prob_joint(U, allow_nan=False)

# Value of the joint log-probability and its gradient w.r.t. the model params.
value, grad = jax.value_and_grad(log_prob_joint_fn)(init_params, U)
print(value, grad)
-9.372949645737593 {'sigma_param': Array(-5.82616862, dtype=float64)}
[5]:
from jaxns import NestedSampler
# max_samples is a sample count: pass an exact int rather than the float
# literal 1e5 (same value, no reliance on an internal float->int cast).
ns = NestedSampler(model=model, max_samples=100_000)
# Run the sampler
termination_reason, state = ns(jax.random.PRNGKey(42))
# Get the results
results = ns.to_results(termination_reason=termination_reason, state=state)
[6]:
from jaxns import summary, plot_diagnostics, plot_cornerplot
# Print a text summary of the run (termination conditions, logZ, and
# per-variable statistics, as shown in the cell output).
summary(results)
# Diagnostic and corner plots; save_name presumably writes each figure to
# this (relative) path -- confirm against the jaxns plotting docs.
plot_diagnostics(results, save_name='intro_diagnostics.png')
plot_cornerplot(results, save_name='intro_cornerplot.png')
--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 192008
samples: 2025
phantom samples: 0
likelihood evals / sample: 94.8
phantom fraction (%): 0.0%
--------
logZ=-1.41 +- 0.13
max(logL)=16.04
H=-1.09
ESS=170
--------
uncert: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
uncert: 0.74 +- 0.83 | 0.07 / 0.39 / 1.91 | 0.0 | 0.0
--------
x: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x: 0.02 +- 0.53 | -0.61 / 0.0 / 0.81 | -0.0 | -0.0
--------
[7]:
from jaxns import resample
# Convert the weighted nested-sampling output into S=1000 draws, sampling with
# replacement and using log_dp_mean as the log-weights.
samples = resample(
    key=jax.random.PRNGKey(0),
    samples=results.samples,
    log_weights=results.log_dp_mean,
    replace=True,
    S=1000,
)
[8]:
from jaxns.experimental import EvidenceMaximisation
# Let's train the sigma parameter to maximise the evidence
em = EvidenceMaximisation(model)
# num_steps=5 is an upper bound; the progress output shows training can stop
# early once its convergence criterion on log_Z is met.
# NOTE(review): this rebinds `results` (and defines `params`). The earlier
# nested-sampling results are no longer reachable by name, and re-running the
# resample cell after this one would silently resample the EM run instead.
# Consider distinct names (e.g. `em_results`) if no later cell relies on these.
results, params = em.train(num_steps=5)
# with_parametrised=True also reports the parametrised `sigma`.
summary(results, with_parametrised=True)
# NOTE(review): `sigma` is a single fixed value here (std 0.0 in the summary),
# which likely causes matplotlib's "identical low and high xlims" warnings.
plot_cornerplot(results, with_parametrised=True)
Convergence achieved at step 2, due to log_Z 0.027936573460494607 < log_Z_ftol * log_Z_uncert 0.11888251218887814.: 40%|███████████████████████████▏ | 2/5 [00:10<00:15, 5.30s/it]
--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 222600
samples: 2205
phantom samples: 0
likelihood evals / sample: 101.0
phantom fraction (%): 0.0%
--------
logZ=-1.23 +- 0.12
max(logL)=16.85
H=-1.02
ESS=169
--------
sigma: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
sigma: 0.7 +- 0.0 | 0.7 / 0.7 / 0.7 | 0.7 | 0.7
--------
uncert: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
uncert: 0.6 +- 0.65 | 0.04 / 0.37 / 1.62 | 0.0 | 0.0
--------
x: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x: 0.1 +- 1.8 | -1.6 / -0.0 / 1.8 | 0.5 | 0.6
--------
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/plotting.py:199: UserWarning: Attempting to set identical low and high xlims makes transformation singular; automatically expanding.
ax.set_xlim(per_1, per_99)
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/plotting.py:247: UserWarning: Attempting to set identical low and high xlims makes transformation singular; automatically expanding.
ax.set_xlim(param_limits[parameters[col]])