Poisson likelihood and Gamma prior

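This is a simple model in which our discrete data, \(y\), are modelled as Poisson random variables whose rate, \(x\), has a Gamma prior. The Gamma distribution is the conjugate prior of the Poisson likelihood, so both the evidence and the posterior are available in closed form.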

\(L(x) = p(y \mid x) = \prod_i \mathcal{P}[y_i \mid x]\)

and

\(p(x) = \Gamma[x \mid k, \theta]\).

The analytic evidence for this model is,

\(Z = p(y) = \int_\mathcal{X} L(x) p(x) \,\mathrm{d} x = L(1) \frac{\Gamma[1 \mid k, \theta]}{\Gamma[1 \mid k', \theta']}\), which is Bayes' theorem, \(Z = L(x) p(x) / p(x \mid y)\), evaluated at \(x = 1\).

The posterior is also a Gamma distribution,

\(p(x \mid y) = \Gamma[x \mid k', \theta']\)

where

\(k' = k + \sum_i y_i\)

and

\(\theta' = \frac{\theta}{n \theta + 1}\), where \(n\) is the number of observations \(y_i\).

[1]:

import numpy as np import pylab as plt import tensorflow_probability.substrates.jax as tfp from jax import random, numpy as jnp from jaxns import resample tfpd = tfp.distributions
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/internals/mixed_precision.py:14: UserWarning: JAX x64 is not enabled. Setting it now. Check for errors.
  warnings.warn("JAX x64 is not enabled. Setting it now. Check for errors.")
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[2]:
# Simulate the observations: Gamma-distributed rates, then one Poisson count per rate.
np.random.seed(42)

num_samples = 10

# Shape and scale of the Gamma distribution the latent rates are drawn from.
true_k = 0.5
true_theta = 1.

# The two draws share NumPy's global RNG stream, so their order must not change.
_gamma = np.random.gamma(shape=true_k, scale=true_theta, size=num_samples)
print(f"Taking {num_samples} samples from a Poisson distribution as data.")
data = jnp.asarray(np.random.poisson(lam=_gamma, size=num_samples))
Taking 10 samples from a Poisson distribution as data.
[3]:

from jaxns.internals.mixed_precision import mp_policy
from jaxns import Prior, Model

# Build model
prior_k = 100.
# Note if prior_theta is chosen too large 32-bit will be insufficient
prior_theta = 0.1


def prior_model():
    """
    Gamma prior over the Poisson rate, registered under the name 'lamda'.

    The concentration is `prior_k` and the rate is `1 / prior_theta`, both cast
    to `mp_policy.measure_dtype`.
    """
    measure_dtype = mp_policy.measure_dtype
    concentration = jnp.asarray(prior_k, measure_dtype)
    scale = jnp.asarray(prior_theta, measure_dtype)
    lamda = yield Prior(tfpd.Gamma(concentration=concentration, rate=1. / scale), name='lamda')
    return lamda


def log_likelihood(lamda):
    """
    Poisson likelihood.

    Sums the Poisson log-probability of every entry of `data` at rate `lamda`.
    """
    return jnp.sum(tfpd.Poisson(rate=lamda).log_prob(data))


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

model.sanity_check(random.PRNGKey(0), S=100)
INFO:jaxns:Sanity check...
INFO:jaxns:Sanity check passed
[4]:
# Evidence and posterior are analytic
def log_gamma_prob(lamda, k, theta):
    """
    Log-density at `lamda` of a Gamma distribution with concentration `k` and scale `theta`.

    The distribution is built with rate `1 / theta`. Closed form, for reference:
    (k - 1) * log(lamda) - lamda / theta - gammaln(k) - k * log(theta)
    """
    gamma_dist = tfpd.Gamma(concentration=k, rate=1. / theta)
    return gamma_dist.log_prob(lamda)


# Conjugate update of the Gamma prior given the Poisson counts in `data`.
true_post_k = prior_k + jnp.sum(data)
true_post_theta = prior_theta / (num_samples * prior_theta + 1.)

# Mean of a Gamma distribution is concentration times scale.
true_post_mean_gamma = true_post_theta * true_post_k

# Bayes' theorem at lamda = 1: log Z = log L(1) + log prior(1) - log posterior(1).
true_logZ = (
    log_likelihood(1.)
    + log_gamma_prob(1., prior_k, prior_theta)
    - log_gamma_prob(1., true_post_k, true_post_theta)
)

print(f"True Evidence = {true_logZ}")
print(f"True posterior concentration (k) = {true_post_k}")
print(f"True posterior rate (1/theta) = {1. / true_post_theta}")
print(f"True posterior lamda = {true_post_mean_gamma}")

True Evidence = -69.31472389012055
True posterior concentration (k) = 100.0
True posterior rate (1/theta) = 20.0
True posterior lamda = 5.0
[5]:
from jaxns import NestedSampler

# Set up the nested sampler for `model`; it is executed in the following cell.
ns = NestedSampler(
    model=model,
    verbose=False,
    init_efficiency_threshold=0.,
)
[6]:
# Run the sampler with a fixed PRNG key, then turn the final state into results.
term_reason, state = ns(random.PRNGKey(3452345))
results = ns.to_results(
    termination_reason=term_reason,
    state=state,
)

# Text summary of the run; diagnostics are plotted in the next cell.
ns.summary(results)
# ns.plot_cornerplot(results)
--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 135128
samples: 1260
phantom samples: 0
likelihood evals / sample: 107.2
phantom fraction (%): 0.0%
--------
logZ=-68.26 +- 0.89
max(logL)=-27.77
H=-20.08
ESS=72
--------
lamda: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
lamda: 4.74 +- 0.5 | 4.05 / 4.79 / 5.33 | 4.95 | 2.78
--------
[7]:
# Diagnostic plots for the nested sampling run stored in `results`.
ns.plot_diagnostics(results)

/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/plotting.py:45: UserWarning: Found samples with zero likelihood evaluations.
  warnings.warn("Found samples with zero likelihood evaluations.")
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/plotting.py:49: RuntimeWarning: divide by zero encountered in divide
  1. / num_likelihood_evaluations_per_sample
../_images/examples_gamma_poission_7_1.png
[8]:
# Comparing samples to true posterior

# Turn the weighted nested-sampling output into equally weighted posterior draws.
# S=1000 is close to the 1260 samples available and far above the reported ESS,
# so the draw is made with replacement; otherwise almost every sample would be
# returned and the weights would barely matter.
# NOTE(review): assumes `resample` draws without replacement by default and
# accepts `replace` — confirm against the installed jaxns version.
samples = resample(random.PRNGKey(43083245), results.samples, results.log_dp_mean, S=1000, replace=True)

# Reference draws from the analytic Gamma posterior (concentration, scale).
_gamma = np.random.gamma(true_post_k, true_post_theta, size=100000)

fig, ax = plt.subplots()
ax.hist(samples['lamda'], bins='auto', ec='blue', alpha=0.5, density=True, fc='none',
        label='nested sampling')
ax.hist(_gamma, bins='auto', ec='orange', alpha=0.5, density=True, fc='none',
        label='analytic posterior')
# Labels and legend so the figure can be read on its own.
ax.set_xlabel('lamda')
ax.set_ylabel('density')
ax.set_title('Posterior of lamda: nested sampling vs analytic')
ax.legend()
plt.show()
../_images/examples_gamma_poission_8_0.png
[12]: