Multivariate Normal Likelihood with Multivariate Normal Prior

This is a simple model where our data, \(y\), is modelled as a multivariate normal RV with uncorrelated noise.

\(L(x) = p(y | x) = \mathcal{N}[y \mid x,\Sigma]\)

and

\(p(x) = \mathcal{N}[x \mid \mu, \sigma^2 \mathbf{I}]\).

The analytic evidence for this model is,

\(Z = p(y) = \mathcal{N}[y \mid \mu, \Sigma + \sigma^2 \mathbf{I}]\)

The posterior is also a multivariate normal distribution,

\(p(x \mid y) = \mathcal{N}[\mu', \Sigma']\)

where

\(\mu' = \sigma^2 \mathbf{I} (\sigma^2 \mathbf{I} + \Sigma)^{-1} y + \Sigma ( \sigma^2 \mathbf{I} + \Sigma)^{-1} \mu\)

and

\(\Sigma' = \sigma^2 \mathbf{I} (\sigma^2 \mathbf{I} + \Sigma)^{-1} \Sigma\)

[6]:

import tensorflow_probability.substrates.jax as tfp from jax import random, numpy as jnp from jaxns import DefaultNestedSampler from jaxns import Model from jaxns import Prior tfpd = tfp.distributions
[7]:
from jax._src.scipy.linalg import solve_triangular


def log_normal(x, mean, cov):
    L = jnp.linalg.cholesky(cov)
    dx = x - mean
    dx = solve_triangular(L, dx, lower=True)
    return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) - 0.5 * dx @ dx


# define our data and prior
ndims = 16
prior_mu = 15 * jnp.ones(ndims)
prior_cov = jnp.diag(jnp.ones(ndims)) ** 2

data_mu = jnp.zeros(ndims)
data_cov = jnp.diag(jnp.ones(ndims)) ** 2
data_cov = jnp.where(data_cov == 0., 0.99, data_cov)

true_logZ = log_normal(data_mu, prior_mu, prior_cov + data_cov)

J = jnp.linalg.solve(data_cov + prior_cov, prior_cov)
post_mu = prior_mu + J.T @ (data_mu - prior_mu)
post_cov = prior_cov - J.T @ (prior_cov + data_cov) @ J

print("True logZ={}".format(true_logZ))
print("True post_mu={}".format(post_mu))
print("True post_cov={}".format(post_cov))

True logZ=-123.01473999023438
True post_mu=[14.109793 14.109791 14.109794 14.109795 14.109792 14.109793 14.109793
 14.109791 14.109793 14.109794 14.109793 14.109793 14.109792 14.109793
 14.109794 14.109793]
True post_cov=[[0.0680728  0.05817217 0.05817196 0.05817205 0.0581721  0.05817203
  0.05817202 0.05817205 0.05817201 0.05817205 0.05817202 0.05817197
  0.05817199 0.05817199 0.05817198 0.05817195]
 [0.05817217 0.0680719  0.05817202 0.05817205 0.0581721  0.05817205
  0.05817205 0.05817203 0.05817202 0.05817201 0.05817203 0.05817198
  0.05817199 0.05817199 0.05817192 0.05817193]
 [0.05817197 0.058172   0.0680728  0.05817198 0.05817212 0.05817208
  0.05817208 0.05817204 0.05817202 0.05817202 0.05817204 0.05817203
  0.05817201 0.05817202 0.05817195 0.05817201]
 [0.05817204 0.05817205 0.05817197 0.06807292 0.05817207 0.05817203
  0.05817204 0.05817204 0.058172   0.05817201 0.05817201 0.05817201
  0.05817199 0.058172   0.05817195 0.05817197]
 [0.05817207 0.05817202 0.05817207 0.05817201 0.0680728  0.05817198
  0.05817204 0.05817199 0.05817201 0.05817199 0.058172   0.05817199
  0.05817199 0.05817197 0.05817192 0.05817197]
 [0.05817201 0.05817201 0.05817201 0.05817201 0.05817199 0.06807286
  0.05817198 0.05817198 0.05817201 0.05817196 0.05817197 0.05817199
  0.05817199 0.05817204 0.05817201 0.05817203]
 [0.05817202 0.05817204 0.05817203 0.05817202 0.05817201 0.05817199
  0.06807268 0.058172   0.05817202 0.05817201 0.05817202 0.05817202
  0.05817201 0.05817199 0.05817202 0.05817203]
 [0.05817199 0.05817197 0.058172   0.05817199 0.05817197 0.05817196
  0.05817199 0.06807274 0.05817201 0.05817199 0.058172   0.05817201
  0.05817202 0.05817201 0.05817195 0.05817209]
 [0.05817201 0.05817202 0.05817202 0.05817201 0.05817201 0.05817199
  0.05817201 0.05817201 0.06807262 0.05817202 0.058172   0.05817202
  0.05817202 0.05817198 0.05817199 0.05817198]
 [0.05817204 0.05817198 0.05817198 0.05817198 0.05817195 0.05817194
  0.05817201 0.05817202 0.05817204 0.06807262 0.05817206 0.05817202
  0.05817204 0.05817206 0.05817207 0.05817204]
 [0.05817202 0.05817202 0.05817202 0.05817202 0.05817199 0.05817198
  0.05817199 0.05817205 0.05817201 0.05817201 0.06807238 0.05817203
  0.05817202 0.05817197 0.05817206 0.05817203]
 [0.05817198 0.05817198 0.05817204 0.05817204 0.05817195 0.058172
  0.05817201 0.05817201 0.05817197 0.05817202 0.05817204 0.06807262
  0.05817205 0.05817199 0.05817201 0.05817202]
 [0.05817201 0.05817201 0.05817201 0.05817201 0.05817199 0.05817198
  0.05817199 0.05817205 0.058172   0.058172   0.05817201 0.05817201
  0.06807274 0.05817195 0.05817196 0.05817193]
 [0.058172   0.05817201 0.05817201 0.058172   0.05817199 0.05817198
  0.05817198 0.05817198 0.05817197 0.05817202 0.05817196 0.05817197
  0.05817196 0.0680728  0.05817205 0.058172  ]
 [0.05817193 0.05817192 0.05817194 0.05817193 0.0581719  0.05817197
  0.05817203 0.05817195 0.05817194 0.05817207 0.05817208 0.058172
  0.058172   0.05817205 0.06807327 0.05817194]
 [0.05817197 0.05817197 0.05817205 0.05817202 0.05817202 0.05817202
  0.05817202 0.05817211 0.05817195 0.05817201 0.05817202 0.05817199
  0.05817194 0.05817197 0.05817189 0.06807297]]
[8]:



def prior_model(): x = yield Prior(tfpd.MultivariateNormalTriL(loc=prior_mu, scale_tril=jnp.linalg.cholesky(prior_cov)), name='x') return x # The likelihood is a callable that will take def log_likelihood(x): return log_normal(x, data_mu, data_cov) model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
[9]:
import jax

# Create the nested sampler class. In this case without any tuning.
ns = DefaultNestedSampler(
    model=model,
    max_samples=1e6,
    parameter_estimation=True,
    verbose=True)

termination_reason, state = jax.jit(ns)(random.PRNGKey(42654))
results = ns.to_results(termination_reason=termination_reason, state=state)

# We can always save results to play with later
ns.save_results(results, 'save.npz')
# loads previous results by uncommenting below
# results = load_results('save.npz')


-------Num samples: 8640
Num likelihood evals: 5539
Efficiency: 1.4731900691986084log(L) contour: -728.3695068359375
log(Z) est.: -300.04827880859375 +- 0.8311780691146851
-------Num samples: 16800
Num likelihood evals: 20408
Efficiency: 0.3998431861400604log(L) contour: -544.81689453125
log(Z) est.: -237.0644073486328 +- 0.8337075710296631
-------Num samples: 24960
Num likelihood evals: 42750
Efficiency: 0.19087719917297363log(L) contour: -449.32080078125
log(Z) est.: -212.8740997314453 +- 0.8348050117492676
-------Num samples: 33120
Num likelihood evals: 71617
Efficiency: 0.11393942683935165log(L) contour: -386.14788818359375
log(Z) est.: -209.5166778564453 +- 0.8285486698150635
-------Num samples: 41280
Num likelihood evals: 106851
Efficiency: 0.07636802643537521log(L) contour: -334.6083984375
log(Z) est.: -192.8569793701172 +- 0.8216135501861572
-------Num samples: 49440
Num likelihood evals: 146926
Efficiency: 0.05553816258907318log(L) contour: -294.7420349121094
log(Z) est.: -189.3365478515625 +- 0.6395522356033325
-------Num samples: 57600
Num likelihood evals: 193304
Efficiency: 0.04221330210566521log(L) contour: -262.7485656738281
log(Z) est.: -174.17044067382812 +- 0.8403432369232178
-------Num samples: 65760
Num likelihood evals: 244230
Efficiency: 0.03341113030910492log(L) contour: -237.11953735351562
log(Z) est.: -166.09866333007812 +- 0.8416314721107483
-------Num samples: 73920
Num likelihood evals: 300774
Efficiency: 0.02713000401854515log(L) contour: -217.0238037109375
log(Z) est.: -161.7726287841797 +- 0.8388165831565857
-------Num samples: 82080
Num likelihood evals: 361248
Efficiency: 0.02258836105465889log(L) contour: -200.4063262939453
log(Z) est.: -154.78562927246094 +- 0.8359375
-------Num samples: 90240
Num likelihood evals: 426303
Efficiency: 0.019141314551234245log(L) contour: -185.8770294189453
log(Z) est.: -155.72647094726562 +- 0.8175548911094666
-------Num samples: 98400
Num likelihood evals: 496992
Efficiency: 0.01641877554357052log(L) contour: -172.7545166015625
log(Z) est.: -143.5293731689453 +- 0.8041374087333679
-------Num samples: 106560
Num likelihood evals: 571878
Efficiency: 0.014268777333199978log(L) contour: -163.04940795898438
log(Z) est.: -144.02284240722656 +- 0.6915496587753296
-------Num samples: 114720
Num likelihood evals: 649659
Efficiency: 0.012560435570776463log(L) contour: -153.49940490722656
log(Z) est.: -138.56869506835938 +- 0.8307924270629883
-------Num samples: 122880
Num likelihood evals: 732239
Efficiency: 0.011143902316689491log(L) contour: -145.36865234375
log(Z) est.: -128.83413696289062 +- 0.8489782810211182
-------Num samples: 131040
Num likelihood evals: 818576
Efficiency: 0.009968530386686325log(L) contour: -138.55641174316406
log(Z) est.: -129.83192443847656 +- 0.8493376970291138
-------Num samples: 139200
Num likelihood evals: 907776
Efficiency: 0.008989001624286175log(L) contour: -132.45956420898438
log(Z) est.: -130.25645446777344 +- 0.6876775622367859
-------Num samples: 147360
Num likelihood evals: 1002369
Efficiency: 0.008140714839100838log(L) contour: -127.15589904785156
log(Z) est.: -130.20388793945312 +- 0.6417434811592102
-------Num samples: 155520
Num likelihood evals: 1098900
Efficiency: 0.007425607647746801log(L) contour: -122.57423400878906
log(Z) est.: -130.49533081054688 +- 0.45410051941871643
-------Num samples: 163680
Num likelihood evals: 1199549
Efficiency: 0.006802556570619345log(L) contour: -118.76657104492188
log(Z) est.: -130.69024658203125 +- 0.34960392117500305
-------Num samples: 171840
Num likelihood evals: 1303902
Efficiency: 0.006258138921111822log(L) contour: -115.35493469238281
log(Z) est.: -127.0222396850586 +- 0.6918805837631226
-------Num samples: 180000
Num likelihood evals: 1410543
Efficiency: 0.00578500609844923log(L) contour: -112.29899597167969
log(Z) est.: -125.68201446533203 +- 0.7894684672355652
-------Num samples: 188160
Num likelihood evals: 1520896
Efficiency: 0.005365258548408747log(L) contour: -109.60888671875
log(Z) est.: -126.09242248535156 +- 0.5824767351150513
-------Num samples: 196320
Num likelihood evals: 1632886
Efficiency: 0.004997286945581436log(L) contour: -107.22428131103516
log(Z) est.: -125.55574035644531 +- 0.5379148721694946
-------Num samples: 204480
Num likelihood evals: 1746480
Efficiency: 0.004672254901379347log(L) contour: -104.97039031982422
log(Z) est.: -125.09891510009766 +- 0.5154030323028564
-------Num samples: 212640
Num likelihood evals: 1862762
Efficiency: 0.004380591679364443log(L) contour: -103.10183715820312
log(Z) est.: -124.87657928466797 +- 0.41681718826293945
-------Num samples: 220800
Num likelihood evals: 1979359
Efficiency: 0.004122546873986721log(L) contour: -101.38316345214844
log(Z) est.: -123.85633087158203 +- 0.6234355568885803
-------Num samples: 228960
Num likelihood evals: 2096536
Efficiency: 0.0038921344093978405log(L) contour: -99.70223236083984
log(Z) est.: -124.14187622070312 +- 0.462375283241272
-------Num samples: 237120
Num likelihood evals: 2214044
Efficiency: 0.0036855635698884726log(L) contour: -98.15310668945312
log(Z) est.: -124.15444946289062 +- 0.31476232409477234
-------Num samples: 245280
Num likelihood evals: 2334275
Efficiency: 0.0034957320895045996log(L) contour: -96.7013168334961
log(Z) est.: -123.77859497070312 +- 0.3388094902038574
-------Num samples: 253440
Num likelihood evals: 2452902
Efficiency: 0.0033266718965023756log(L) contour: -95.28227233886719
log(Z) est.: -123.36224365234375 +- 0.339214563369751
-------Num samples: 261600
Num likelihood evals: 2573192
Efficiency: 0.0031711587216705084log(L) contour: -93.95650482177734
log(Z) est.: -123.49649810791016 +- 0.2931770086288452
-------Num samples: 269760
Num likelihood evals: 2691911
Efficiency: 0.003031303873285651log(L) contour: -92.67212677001953
log(Z) est.: -122.19329071044922 +- 0.5952386856079102
-------Num samples: 277920
Num likelihood evals: 2811629
Efficiency: 0.0029022321105003357log(L) contour: -91.57221221923828
log(Z) est.: -122.40362548828125 +- 0.4291722774505615
-------Num samples: 286080
Num likelihood evals: 2930116
Efficiency: 0.0027848726604133844log(L) contour: -90.44287109375
log(Z) est.: -122.88108825683594 +- 0.3409867584705353
-------Num samples: 294240
Num likelihood evals: 3050131
Efficiency: 0.002675294876098633log(L) contour: -89.39148712158203
log(Z) est.: -123.00808715820312 +- 0.28829607367515564
-------Num samples: 302400
Num likelihood evals: 3172103
Efficiency: 0.002572425873950124log(L) contour: -88.40502166748047
log(Z) est.: -123.05693817138672 +- 0.279508501291275
-------Num samples: 310560
Num likelihood evals: 3295072
Efficiency: 0.002476425375789404log(L) contour: -87.44629669189453
log(Z) est.: -123.07344055175781 +- 0.27934467792510986
-------Num samples: 318720
Num likelihood evals: 3415567
Efficiency: 0.0023890614975243807log(L) contour: -86.577392578125
log(Z) est.: -123.08255767822266 +- 0.28211671113967896
-------Num samples: 326880
Num likelihood evals: 3537034
Efficiency: 0.0023070177994668484log(L) contour: -85.78118896484375
log(Z) est.: -123.05679321289062 +- 0.28561073541641235
-------Num samples: 335040
Num likelihood evals: 3660368
Efficiency: 0.002229284029453993log(L) contour: -85.01358032226562
log(Z) est.: -123.058837890625 +- 0.28643763065338135
-------Num samples: 343200
Num likelihood evals: 3781957
Efficiency: 0.002157613169401884log(L) contour: -84.23200988769531
log(Z) est.: -123.0244369506836 +- 0.29009002447128296
-------Num samples: 351360
Num likelihood evals: 3903237
Efficiency: 0.00209057261236012log(L) contour: -83.52565002441406
log(Z) est.: -123.03254699707031 +- 0.2899058759212494
-------Num samples: 359520
Num likelihood evals: 4023162
Efficiency: 0.0020282554905861616log(L) contour: -82.84107971191406
log(Z) est.: -123.03301239013672 +- 0.29095664620399475
-------Num samples: 367680
Num likelihood evals: 4143905
Efficiency: 0.0019691570196300745log(L) contour: -82.16433715820312
log(Z) est.: -123.03369903564453 +- 0.2903791666030884
-------Num samples: 375840
Num likelihood evals: 4265102
Efficiency: 0.0019132016459479928log(L) contour: -81.48123931884766
log(Z) est.: -123.03260040283203 +- 0.2921864688396454
-------Num samples: 384000
Num likelihood evals: 4383634
Efficiency: 0.0018614692380651832log(L) contour: -80.79535675048828
log(Z) est.: -123.03145599365234 +- 0.2928124666213989
-------Num samples: 392160
Num likelihood evals: 4503757
Efficiency: 0.0018118206644430757log(L) contour: -80.09120178222656
log(Z) est.: -123.02919006347656 +- 0.2918206751346588
-------Num samples: 400320
Num likelihood evals: 4624488
Efficiency: 0.0017645196057856083log(L) contour: -79.43296813964844
log(Z) est.: -123.02909088134766 +- 0.2932030260562897
-------Num samples: 408480
Num likelihood evals: 4743437
Efficiency: 0.0017202716553583741log(L) contour: -78.82025909423828
log(Z) est.: -123.02926635742188 +- 0.2924213707447052
-------Num samples: 416640
Num likelihood evals: 4861215
Efficiency: 0.001678592641837895log(L) contour: -78.1926498413086
log(Z) est.: -123.0274429321289 +- 0.2920036315917969
-------Num samples: 424800
Num likelihood evals: 4981869
Efficiency: 0.0016379394801333547log(L) contour: -77.64949035644531
log(Z) est.: -123.02936553955078 +- 0.2935151159763336
-------Num samples: 432960
Num likelihood evals: 5103830
Efficiency: 0.001598799368366599log(L) contour: -77.11564636230469
log(Z) est.: -123.03079223632812 +- 0.2932030260562897
[10]:
# We can use the summary utility to display results
ns.summary(results)
# We plot useful diagnostics and a distribution cornerplot
ns.plot_diagnostics(results)
ns.plot_cornerplot(results)

--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 5104310
samples: 432960
phantom samples: 407040
likelihood evals / sample: 11.8
phantom fraction (%): 94.0%
--------
logZ=-122.661 +- 0.0014
H=-33.66
ESS=4029
--------
x[#]: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
x[0]: 14.09 +- 0.26 | 13.77 / 14.08 / 14.42 | 14.17 | 13.22
x[1]: 14.09 +- 0.25 | 13.77 / 14.08 / 14.41 | 14.22 | 13.17
x[2]: 14.09 +- 0.26 | 13.77 / 14.08 / 14.42 | 14.17 | 13.28
x[3]: 14.09 +- 0.26 | 13.78 / 14.08 / 14.43 | 14.16 | 13.05
x[4]: 14.09 +- 0.25 | 13.78 / 14.08 / 14.41 | 14.2 | 13.11
x[5]: 14.09 +- 0.25 | 13.79 / 14.08 / 14.41 | 14.24 | 13.15
x[6]: 14.09 +- 0.25 | 13.78 / 14.08 / 14.42 | 14.23 | 13.09
x[7]: 14.09 +- 0.26 | 13.77 / 14.08 / 14.43 | 14.22 | 13.24
x[8]: 14.09 +- 0.26 | 13.77 / 14.08 / 14.42 | 14.16 | 13.31
x[9]: 14.09 +- 0.25 | 13.78 / 14.08 / 14.42 | 14.23 | 13.1
x[10]: 14.09 +- 0.26 | 13.78 / 14.08 / 14.42 | 14.2 | 13.17
x[11]: 14.09 +- 0.26 | 13.77 / 14.08 / 14.43 | 14.24 | 13.1
x[12]: 14.09 +- 0.26 | 13.77 / 14.09 / 14.42 | 14.2 | 13.21
x[13]: 14.09 +- 0.25 | 13.78 / 14.08 / 14.42 | 14.19 | 13.1
x[14]: 14.09 +- 0.26 | 13.77 / 14.08 / 14.42 | 14.16 | 13.15
x[15]: 14.09 +- 0.26 | 13.78 / 14.09 / 14.43 | 14.27 | 13.18
--------
/home/albert/git/jaxns/jaxns/plotting.py:47: RuntimeWarning: divide by zero encountered in divide
  efficiency = 1. / num_likelihood_evaluations_per_sample
../_images/examples_mvn_data_mvn_prior_5_2.png
../_images/examples_mvn_data_mvn_prior_5_3.png