Logic rules

Given a state of knowledge encoded as plausibility statements, e.g. \(P(S) = p_{S}\) or \(P(S) > \mu_{S}\), where \(S\) is any logical sentence, infer the posterior over the logical propositions that make up \(S\).
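One way to make such statements usable by a sampler is to turn each one into a log-likelihood term. The sketch below assumes the same convention as the model in this example. An equality statement \(P(S) = p_S\) contributes \(\log p_S\) when \(S\) is true and \(\log(1 - p_S)\) when it is false. A bound on a probability acts as a hard constraint that assigns \(-\infty\) outside the allowed interval. The helper names equality_term and bound_term are hypothetical and are not part of jaxns.

```python
import jax.numpy as jnp


def equality_term(sentence_holds, p):
    """Log-likelihood of P(S) = p, given a boolean array saying whether S is true."""
    # log(p) where S holds, log(1 - p) where it does not
    return jnp.where(sentence_holds, jnp.log(p), jnp.log1p(-p))


def bound_term(prob, lower, upper):
    """Hard constraint lower <= prob <= upper: 0 inside the interval, -inf outside."""
    return jnp.where((prob < lower) | (prob > upper), -jnp.inf, 0.0)


# Usage: the rule holds in the first case and fails in the second.
print(equality_term(jnp.array([True, False]), 0.9))  # [log 0.9, log 0.1]
print(bound_term(jnp.array([0.5, 0.7]), 0.6, 0.9))   # [-inf, 0.]
```

A one-sided statement such as \(P(S) > \mu_S\) fits the same helper with an upper bound of 1, e.g. bound_term(prob, mu, 1.0). The terms for independent statements are simply summed.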

In this example we let there be three Boolean variables, \(A\), \(B\), and \(C\), and define our state of knowledge to be:

  • \(P(A \implies ( B \iff C)) = 0.9\), and

  • \(0.6 < P(B) < 0.9\)
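With only three variables, the rule can be checked exhaustively. The illustrative sketch below is not part of the original notebook. It enumerates all \(2^3\) truth assignments with jax.numpy and evaluates \(A \implies (B \iff C)\) using the identity \(X \implies Y \equiv \lnot X \lor Y\), which is the same identity the likelihood below relies on. The rule fails only when \(A\) is true and \(B \neq C\), so 6 of the 8 assignments satisfy it.

```python
import itertools

import jax.numpy as jnp

# All 8 truth assignments of (A, B, C) as a boolean array of shape (8, 3).
assignments = jnp.array(list(itertools.product([False, True], repeat=3)))
a, b, c = assignments[:, 0], assignments[:, 1], assignments[:, 2]

# B <=> C holds exactly when both B => C and C => B hold.
b_iff_c = (~b | c) & (~c | b)
# A => X is equivalent to (not A) or X.
rule = ~a | b_iff_c

for row, holds in zip(assignments.tolist(), rule.tolist()):
    print(row, holds)
```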

[10]:
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp

tfpd = tfp.distributions
[11]:

from jaxns import Prior, Model, Bernoulli

num_predicates = 3


def prior_model():
    # Uniform prior on the probability that each predicate is true
    p = yield Prior(tfpd.Uniform(jnp.zeros(num_predicates), jnp.ones(num_predicates)), name='p')
    # Each predicate is a Bernoulli draw with its own probability
    predicates = yield Bernoulli(probs=p)
    return predicates, p[1]


def log_likelihood(predicates, p_B):
    """
    State of knowledge: P(a => (b <=> c)) = 0.9 and 0.6 <= P(b) <= 0.9
    """
    predicates = predicates.astype(jnp.bool_)
    a = predicates[0]
    b = predicates[1]
    c = predicates[2]
    # x => y is equivalent to (not x) or y
    b_imp_c = c | ~b
    c_imp_b = b | ~c
    # a => (b <=> c)
    imp_imp = (b_imp_c & c_imp_b) | ~a
    # Soft constraint: the rule holds with plausibility 0.9
    log_prob_1 = jnp.where(imp_imp, jnp.log(0.9), jnp.log(1. - 0.9))
    # Hard constraint: 0.6 <= P(b) <= 0.9
    log_prob_2 = jnp.where((p_B < 0.6) | (p_B > 0.9), -jnp.inf, 0.)
    log_prob = log_prob_1 + log_prob_2
    return log_prob


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

model.sanity_check(random.PRNGKey(0), S=100)
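To see how much the state of knowledge adds, it helps to know what the prior alone says. The Monte Carlo sketch below uses plain JAX and is independent of jaxns. It reimplements only the generative structure of prior_model (uniform probabilities, then Bernoulli predicates) and does not use jaxns' Bernoulli prior. It estimates the prior plausibility of the rule and of the band on \(P(B)\). Analytically, each predicate is marginally a fair coin, so \(P(A \implies (B \iff C)) = 1 - P(A)\,P(B \neq C) = 1 - \tfrac{1}{2}\cdot\tfrac{1}{2} = 0.75\). Likewise, \(P(0.6 \le p_B \le 0.9) = 0.3\).

```python
from jax import random

# Ancestral sampling with the same structure as prior_model:
# p ~ Uniform(0, 1)^3, then each predicate ~ Bernoulli(p_i).
key_p, key_x = random.split(random.PRNGKey(0))
num_samples = 100_000
p = random.uniform(key_p, (num_samples, 3))
x = random.bernoulli(key_x, p)  # boolean array of shape (num_samples, 3)

a, b, c = x[:, 0], x[:, 1], x[:, 2]
rule = ~a | (b == c)  # A => (B <=> C)

# Prior plausibilities before the state of knowledge is imposed.
prior_rule = float(rule.mean())  # analytic value 0.75
prior_band = float(((p[:, 1] >= 0.6) & (p[:, 1] <= 0.9)).mean())  # analytic value 0.3
print(prior_rule, prior_band)
```

The state of knowledge therefore raises the plausibility of the rule from 0.75 to 0.9. It also confines \(P(B)\) to an interval that carries only 30% of the prior mass.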
INFO[2024-01-10 00:53:07,631]: Sanity check...
INFO[2024-01-10 00:53:07,638]: Sanity check passed
[12]:
import jax
from jaxns import DefaultNestedSampler

# Run the nested sampling
ns = DefaultNestedSampler(model=model, max_samples=1e5, c=1000, difficult_model=True, parameter_estimation=True)

term_reason, state = jax.jit(ns)(random.PRNGKey(3452345))
results = ns.to_results(termination_reason=term_reason, state=state)
[13]:
# Inspect results
ns.summary(results)
ns.plot_diagnostics(results)
ns.plot_cornerplot(results)
--------
Termination Conditions:
Small remaining evidence
All live-points are on a single plateau (potential numerical errors, consider 64-bit)
--------
likelihood evals: 7340
samples: 2000
phantom samples: 0
likelihood evals / sample: 3.7
phantom fraction (%): 0.0%
--------
logZ=-0.386 +- 0.015
H=-0.19
ESS=541
--------
p[#]: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
p[0]: 0.46 +- 0.28 | 0.08 / 0.42 / 0.89 | 0.99 | 0.22
p[1]: 0.748 +- 0.088 | 0.629 / 0.749 / 0.871 | 0.862 | 0.871
p[2]: 0.54 +- 0.3 | 0.1 / 0.55 / 0.93 | 0.99 | 0.93
--------
[Figure: nested sampling diagnostics]
[Figure: corner plot of p]
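The problem is small enough to cross-check the posterior means without nested sampling. The sketch below is our own and assumes the uniform prior and likelihood from the model cell. It sums the prior times the likelihood over the 8 truth assignments, which marginalises the Bernoulli predicates exactly, and integrates over \(p\) on a midpoint grid. Done by hand, the same integral gives posterior means \(19/42 \approx 0.452\), \(3/4\) and \(11/21 \approx 0.524\) for \(p_A\), \(p_B\) and \(p_C\). These lie within the standard deviations in the summary above. The check covers posterior means only, not the evidence.

```python
import itertools

import jax.numpy as jnp

# Midpoint grid over each of p_A, p_B, p_C on [0, 1] (resolution chosen arbitrarily).
n = 100
grid = (jnp.arange(n) + 0.5) / n
pA, pB, pC = jnp.meshgrid(grid, grid, grid, indexing="ij")

# Sum prior x likelihood over the 8 truth assignments of (A, B, C),
# which marginalises out the Bernoulli predicates exactly.
weight = jnp.zeros_like(pA)
for a, b, c in itertools.product([False, True], repeat=3):
    rule_lik = 0.9 if ((not a) or (b == c)) else 0.1  # P(A => (B <=> C)) = 0.9
    bern = (pA if a else 1 - pA) * (pB if b else 1 - pB) * (pC if c else 1 - pC)
    weight = weight + rule_lik * bern

# Hard constraint 0.6 <= p_B <= 0.9; the uniform prior density is constant elsewhere.
weight = jnp.where((pB >= 0.6) & (pB <= 0.9), weight, 0.0)
posterior = weight / weight.sum()

means = [float((posterior * p).sum()) for p in (pA, pB, pC)]
print(means)  # close to 19/42, 3/4 and 11/21
```

The same calculation explains why the marginal posterior of \(p_B\) is flat on \([0.6, 0.9]\). Once \(p_A\) and \(p_C\) are integrated out, the rule factor equals 0.7 for every \(p_B\), which is consistent with the reported \(0.748 \pm 0.088\) (a uniform distribution on that interval has standard deviation \(0.3/\sqrt{12} \approx 0.087\)).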