Inference of Jones scalar observables (noisy angular quantities)

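This is a simple physics model in which the data are noisy complex gains, \(Y\), whose principal argument is the observed phase. The underlying phase at frequency \(\nu\) (MHz) is \(\phi_\nu = K \tau \nu^{-1} + M \eta \nu + \epsilon\). The symbols are:

- \(\tau\) is the differential TEC (dtec in the code, mTECU).
- \(\eta\) is the clock offset (clock, ns).
- \(\epsilon\) is a constant phase offset (const, rad).
- \(K\) and \(M\) are the conversion constants TEC_CONV and CLOCK_CONV.

White Gaussian noise is added to the unit-modulus complex gain, not to the phase itself:

\(Y \sim \mathcal{N}(e^{i \phi}, \sigma^2 \mathbf{I}_{\mathbb{C}})\), with observed phase \(\phi_{\rm obs} = \arg Y\).

The likelihood is therefore Gaussian in the real and imaginary parts of the observed gains \(Y_{\rm obs}\),

\(L(x) = p(Y_{\rm obs} \mid x) = \mathcal{N}[Y_{\rm obs} \mid e^{i \phi}, \sigma^2 \mathbf{I}_{\mathbb{C}}]\), where \(x = (\tau, \eta, \epsilon, \sigma)\),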

and we take the priors,

\(p(\tau) = \mathcal{U}[\tau \mid -300, 300]\) (Uniform)

\(p(\eta) = \mathcal{U}[\eta \mid -2, 2]\) (Uniform)

\(p(\epsilon) = \mathcal{U}[\epsilon \mid -\pi, \pi]\) (Uniform)

\(p(\sigma) = \mathcal{HN}[\sigma \mid 0.25]\) (Half-Normal)
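As a hedged illustration, here is a minimal prior-predictive sketch of the generative model and priors above, written in NumPy rather than JAX. It draws parameters from the priors, builds \(\phi_\nu\), adds complex white noise to \(e^{i\phi_\nu}\) and takes the principal argument. It reuses the TEC_CONV and CLOCK_CONV values from the notebook code. It also uses the half-normal scale 0.25 that prior_model passes to tfpd.HalfNormal. The helpers sample_prior, underlying_phase, observe and wrap are hypothetical names, not jaxns API.

```python
import numpy as np

TEC_CONV = -8.4479745      # K: rad * MHz / mTECU (value used in this notebook)
CLOCK_CONV = 2e-3 * np.pi  # M: rad / (MHz * ns)


def wrap(phi):
    """Map phases into [-pi, pi)."""
    return (phi + np.pi) % (2 * np.pi) - np.pi


def sample_prior(rng, n, sigma_scale=0.25):
    """Draw n samples of (tau, eta, eps, sigma) from the priors listed above."""
    tau = rng.uniform(-300.0, 300.0, size=n)
    eta = rng.uniform(-2.0, 2.0, size=n)
    eps = rng.uniform(-np.pi, np.pi, size=n)
    # A half-normal draw is the absolute value of a zero-mean normal draw.
    sigma = np.abs(rng.normal(0.0, sigma_scale, size=n))
    return tau, eta, eps, sigma


def underlying_phase(tau, eta, eps, nu):
    """phi_nu = K * tau / nu + M * eta * nu + eps (nu in MHz); broadcasts over leading axes."""
    return tau * TEC_CONV / nu + eta * CLOCK_CONV * nu + eps


def observe(rng, phi, sigma):
    """Add complex white noise (std sigma per component) to exp(i*phi); return the principal argument."""
    noise = sigma * (rng.standard_normal(phi.shape) + 1j * rng.standard_normal(phi.shape))
    return np.angle(np.exp(1j * phi) + noise)


rng = np.random.default_rng(0)
nu = np.linspace(110.0, 150.0, 24)  # MHz, as in generate_data

# Prior predictive: one simulated observed-phase curve per prior draw, shape (1000, 24).
tau, eta, eps, sigma = sample_prior(rng, 1000)
phi_prior = underlying_phase(tau[:, None], eta[:, None], eps[:, None], nu)
phase_obs_prior = observe(rng, phi_prior, sigma[:, None])

# With the ground truth of generate_data and no noise, the observed phase is just wrap(phi).
phi_true = underlying_phase(90.0, 0.5, 2.0, nu)
print(np.allclose(observe(rng, phi_true, 0.0), wrap(phi_true)))  # True
```

Even without noise, the observed phase is only known modulo \(2\pi\). This is the source of the wrapping difficulties discussed below.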

This is an example of a problem where the maximum a posteriori (MAP) estimate is not a good estimate of the ground truth. Because the phase wraps, the posterior can have several separated modes, so no single point estimate summarises it well. Likewise, the observed (wrapped) phases cannot be fed directly into a physical model, because they are biased.
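A minimal NumPy illustration of the bias follows. It assumes, for simplicity, that noise is added to the phase itself rather than to the complex gain. When the true phase sits near the wrap point, averaging principal values as ordinary numbers pulls the estimate far from the truth. Averaging unit phasors (the circular mean) does not suffer from this.

```python
import numpy as np

rng = np.random.default_rng(2)
true_phase = 3.0  # rad, just below the wrap point at +pi

# Phase measurements with 0.3 rad of noise, reported as principal values in [-pi, pi).
samples = (true_phase + 0.3 * rng.standard_normal(10_000) + np.pi) % (2 * np.pi) - np.pi

naive_mean = samples.mean()                              # treats wrapped phases as ordinary numbers
circular_mean = np.angle(np.mean(np.exp(1j * samples)))  # averages unit phasors instead

print(naive_mean)     # about 1.0: pulled far from the true phase
print(circular_mean)  # close to 3.0
```

Roughly a third of the samples wrap to values near \(-\pi\), which drags the naive average towards zero.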

[22]:


import pylab as plt
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp
from jaxns import NestedSampler
from jaxns import bruteforce_evidence

tfpd = tfp.distributions
[23]:

TEC_CONV = -8.4479745  # rad*MHz/mTECU
CLOCK_CONV = (2e-3 * jnp.pi)  # rad/MHz/ns


def wrap(phi):
    # Map phases into [-pi, pi).
    return (phi + jnp.pi) % (2 * jnp.pi) - jnp.pi


def generate_data(key, uncert):
    """
    Generate gain data whose phase has clock, const and TEC components.

    This is a model of the impact of the ionosphere on the propagation of radio waves,
    part of radio interferometry:

        phase[:] = tec * (tec_conv / freqs[:]) + clock * (clock_conv * freqs[:]) + const

    then the gains are:

        gains[:] ~ Normal[{cos(phase[:]), sin(phase[:])}, uncert^2 * I]
        phase_obs[:] = ArcTan[gains.imag, gains.real]

    Args:
        key: JAX PRNG key used to draw the noise
        uncert: uncertainty (std. dev.) of the real and imaginary parts of the gains

    Returns:
        Y_obs: noisy gains, stacked as [real parts, imaginary parts]
        phase: underlying (unwrapped) phase
        phase_obs: observed (wrapped) phase
        freqs: frequencies in MHz
    """
    freqs = jnp.linspace(110, 150, 24)  # MHz
    tec = 90.  # mTECU
    const = 2.  # rad
    clock = 0.5  # ns
    phase = tec * (TEC_CONV / freqs) + clock * (CLOCK_CONV * freqs) + const
    Y = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)
    Y_obs = Y + uncert * random.normal(key, shape=Y.shape)
    phase_obs = jnp.arctan2(Y_obs[..., freqs.size:], Y_obs[..., :freqs.size])
    return Y_obs, phase, phase_obs, freqs
[24]:
# Generate data
key = random.PRNGKey(43)
key, data_key = random.split(key)
Y_obs, phase_underlying, phase_obs, freqs = generate_data(data_key, 5./57.)  # uncert = 5/57 rad, about 5 degrees
plt.scatter(freqs, phase_obs, label='data')
plt.plot(freqs, phase_underlying, label='Underlying phase')
plt.legend()
plt.show()
# Note: the phase wrapping makes this a difficult problem to solve. As we'll see, the posterior is rather complicated.
[Figure: observed phases ('data', scatter) and the underlying phase (line) versus frequency (MHz)]
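One reason the posterior is complicated can be checked directly. The likelihood depends on the phase only through its cosine and sine, so shifting const by \(2\pi\) leaves it unchanged. This is why const's prior covers a single period. The sketch below is a NumPy re-implementation of the notebook's log_likelihood. The names model_gains and log_likelihood_np and the synthetic data are illustrative only.

```python
import numpy as np

TEC_CONV = -8.4479745      # rad * MHz / mTECU
CLOCK_CONV = 2e-3 * np.pi  # rad / (MHz * ns)
freqs = np.linspace(110.0, 150.0, 24)  # MHz


def model_gains(dtec, const, clock):
    """Noise-free gains stacked as [cos(phase), sin(phase)]."""
    phase = dtec * (TEC_CONV / freqs) + const + clock * (CLOCK_CONV * freqs)
    return np.concatenate([np.cos(phase), np.sin(phase)])


def log_likelihood_np(y_obs, dtec, const, clock, uncert):
    """Sum of independent Gaussian log-densities of the real and imaginary parts."""
    resid = (y_obs - model_gains(dtec, const, clock)) / uncert
    return np.sum(-0.5 * resid ** 2 - np.log(uncert) - 0.5 * np.log(2.0 * np.pi))


rng = np.random.default_rng(3)
uncert = 5.0 / 57.0
y_obs = model_gains(90.0, 2.0, 0.5) + uncert * rng.standard_normal(2 * freqs.size)

ll = log_likelihood_np(y_obs, 90.0, 2.0, 0.5, uncert)
ll_shifted = log_likelihood_np(y_obs, 90.0, 2.0 - 2.0 * np.pi, 0.5, uncert)
print(np.isclose(ll, ll_shifted))  # True: const is only identified modulo 2*pi
```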
[25]:

from jaxns import Prior, Model


def log_normal(x, mean, scale):
    return tfpd.Normal(mean, scale).log_prob(x)


def log_likelihood(dtec, const, clock, uncert):
    # Gaussian likelihood of the observed gains, with mean [cos(phase), sin(phase)].
    phase = dtec * (TEC_CONV / freqs) + const + clock * (CLOCK_CONV * freqs)
    mean = jnp.concatenate([jnp.cos(phase), jnp.sin(phase)], axis=-1)
    return jnp.sum(log_normal(Y_obs, mean, uncert))


def prior_model():
    tec = yield Prior(tfpd.Uniform(-300, 300.), name='dtec')
    const = yield Prior(tfpd.Uniform(-jnp.pi, jnp.pi), name='const')
    clock = yield Prior(tfpd.Uniform(-2., 2.), name='clock')
    uncert = yield Prior(tfpd.HalfNormal(0.25), name='uncert')
    return tec, const, clock, uncert


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

model.sanity_check(random.PRNGKey(0), S=100)

# log_Z_true = bruteforce_evidence(model=model, S=80)
# print(f"Approx. log(Z)={log_Z_true}")
# Unsure if this grid is sufficient to get a good estimate of the evidence.
INFO:jaxns:Sanity check...
INFO:jaxns:Sanity check passed
[26]:
import jax

# Create the nested sampler, in this case without any tuning.
ns = NestedSampler(model=model, s=10, k=model.U_ndims, num_live_points=model.U_ndims * 1000)
termination_reason, state = jax.jit(ns)(random.PRNGKey(432345987))
results = ns.to_results(termination_reason=termination_reason, state=state)

INFO:jaxns:Number of Markov-chains set to: 800
[27]:
# We can use the summary utility to display results
ns.summary(results)
--------
Termination Conditions:
Small remaining evidence
--------
likelihood evals: 3062437
samples: 56800
phantom samples: 44800
likelihood evals / sample: 53.9
phantom fraction (%): 78.9%
--------
logZ=-13.2 +- 0.14
max(logL)=-4.06
H=-7.61
ESS=1407
--------
clock: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
clock: 0.1 +- 1.2 | -1.5 / 0.0 / 1.6 | -1.6 | -1.7
--------
const: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
const: 0.2 +- 1.7 | -2.1 / 0.3 / 2.6 | -2.3 | -2.2
--------
dtec: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
dtec: 73.0 +- 16.0 | 51.0 / 73.0 / 95.0 | 94.0 | 95.0
--------
uncert: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
uncert: 0.272 +- 0.028 | 0.239 / 0.27 / 0.307 | 0.261 | 0.264
--------
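The summary above reports weighted posterior statistics and an ESS. As a hedged sketch of how such numbers can be computed from weighted nested-sampling samples, the helper below returns the weighted mean, the standard deviation and Kish's effective sample size. The helper name weighted_summary is hypothetical, and jaxns's own ESS estimator may use a different definition.

```python
import numpy as np


def weighted_summary(samples, log_weights):
    """Weighted mean, std. dev. and Kish effective sample size (ESS)."""
    w = np.exp(log_weights - np.max(log_weights))  # subtract the max for numerical stability
    w = w / np.sum(w)                               # normalise so the weights sum to 1
    mean = np.sum(w * samples)
    std = np.sqrt(np.sum(w * (samples - mean) ** 2))
    ess = 1.0 / np.sum(w ** 2)                      # Kish: (sum w)^2 / sum(w^2) for normalised w
    return mean, std, ess


x = np.array([1.0, 2.0, 3.0, 4.0])

# Equal weights: ordinary mean and std, and ESS equals the number of samples.
print(weighted_summary(x, np.zeros(4)))  # mean 2.5, std ~1.118, ESS 4.0

# One dominant weight: ESS collapses towards 1.
print(weighted_summary(x, np.array([0.0, -50.0, -50.0, -50.0])))
```

A low ESS relative to the number of samples means most of the posterior mass sits on a few samples. In that case, summaries such as percentiles are correspondingly noisy.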
[28]:
# Finally, let's look at the results.
ns.plot_diagnostics(results)
ns.plot_cornerplot(results)
# ns.plot_cornerplot(results, save_name='jones_corner.png', kde_overlay=True)

# The diagnostics show that the sampler concentrated its effort on the initial part of the enclosed prior volume, where -log(X) < 7.5.
# This is evident from the increased number of live points (n_live) in that region.
/home/albert/git/jaxns/src/jaxns/plotting.py:45: UserWarning: Found samples with zero likelihood evaluations.
  warnings.warn("Found samples with zero likelihood evaluations.")
[Figure: nested sampling diagnostics from ns.plot_diagnostics]
[Figure: posterior corner plot from ns.plot_cornerplot]
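The diagnostics comment links the number of live points to how finely the enclosed prior volume X is resolved. In standard nested sampling, each iteration with n live points decreases log X by 1/n in expectation. The expected -log X is therefore a running sum of 1/n. The sketch below is a hypothetical helper illustrating that relation, not jaxns's internal bookkeeping. It shows that more live points give smaller steps in -log X, and hence more samples per unit of enclosed volume.

```python
import numpy as np


def expected_neg_log_x(n_live):
    """Expected -log X after each iteration, given the live-point count at each iteration.

    With n live points the log of the enclosed prior volume decreases by 1/n in
    expectation per iteration, so the expected -log X is a running sum of 1/n.
    """
    return np.cumsum(1.0 / np.asarray(n_live, dtype=float))


sparse = expected_neg_log_x([100] * 750)   # 750 iterations with 100 live points
dense = expected_neg_log_x([1000] * 750)   # same number of iterations with 1000 live points
print(sparse[-1], dense[-1])  # about 7.5 and 0.75: more live points, finer steps in -log X
```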