Using JAXNS to globally optimise Neural Networks
Neural network training is typically done with maximum likelihood estimation. Given the number of parameter invariances in neural network architectures, this often introduces a large number of local minima, making global optimisation very difficult.
JAXNS can navigate complex degeneracies during its calculation of the Bayesian evidence. This makes it a powerful tool for globally optimising neural networks.
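For reference, the two quantities mentioned above can be written out as follows. This is the standard textbook definition rather than anything specific to JAXNS, and the symbols \(\theta\), \(L\), \(\pi\) and \(Z\) are notation assumed here, not taken from the notebook.

```latex
\[
Z = \int L(\theta)\, \pi(\theta)\, \mathrm{d}\theta ,
\qquad
\hat{\theta}_{\mathrm{ML}} = \arg\max_{\theta} L(\theta)
\]
```

Here \(L(\theta)\) is the likelihood of the network parameters \(\theta\) and \(\pi(\theta)\) is their prior. Nested sampling explores the prior volume while estimating \(Z\), and the highest-likelihood sample it visits serves as the estimate of \(\hat{\theta}_{\mathrm{ML}}\).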
What we’ll do in this notebook
Define a neural network model
Find its maximum likelihood parameters using JAXNS
Data
We’ll use the N-bit parity problem as our data. This is a binary classification problem where the input is a sequence of bits and the output is 1 if the number of 1 bits is odd, and 0 otherwise. It is known that the \(n\) bit problem requires at least \(n\) hidden units to solve. We will show that JAXNS finds a solution that classifies every input correctly with \(n\) hidden units, but not with \(n-1\).
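To make the role of the \(n\) hidden units concrete, here is a minimal sketch of a hand-built one-hidden-layer ReLU network that solves \(n\)-bit parity with exactly \(n\) hidden units. It only illustrates that \(n\) units are sufficient; it is not the solution JAXNS finds, and the function name `parity_mlp_logit` and the `scale` parameter are hypothetical choices made for this example.

```python
import itertools


def relu(v):
    return max(v, 0.0)


def parity_mlp_logit(bits, scale=4.0):
    """Logit of a hand-built ReLU network with len(bits) hidden units (illustrative only)."""
    n = len(bits)
    s = float(sum(bits))
    # Hidden layer: every input weight is 1, and unit k has bias -k.
    hidden = [relu(s - k) for k in range(n)]
    # Output weights: +1 for unit 0, then alternating -2, +2, ...
    out_w = [1.0] + [2.0 * (-1) ** k for k in range(1, n)]
    # f is 0 when the bit sum is even and 1 when it is odd.
    f = sum(w * h for w, h in zip(out_w, hidden))
    # Shift and scale so that positive logits mean "odd parity".
    return scale * (f - 0.5)


n = 5
inputs = list(itertools.product([0, 1], repeat=n))
labels = [sum(b) % 2 for b in inputs]
preds = [int(parity_mlp_logit(b) > 0) for b in inputs]
accuracy = sum(p == t for p, t in zip(preds, labels)) / len(labels)
print(f"accuracy = {accuracy:.3f}")
```

With `scale=4.0`, every weight and bias in this construction lies inside \([-10, 10]\), so a solution of this kind is reachable under a uniform prior on that range.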
[1]:
import os
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
try:
    import haiku as hk
except ImportError:
    print("You must `pip install dm-haiku` first.")
    raise
import jax
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp
import numpy as np
import itertools
np.random.seed(42)
tfpd = tfp.distributions
[2]:
# Generate data
def generate_parity_dataset(N):
    """
    Generates a dataset for the N-bit parity problem.

    Args:
        N (int): Number of bits in each sample.

    Returns:
        X (np.ndarray): Input features of shape (num_samples, N).
        y (np.ndarray): Labels of shape (num_samples,).
    """
    # Enumerate all 2**N binary inputs
    X = np.asarray(list(itertools.product([0, 1], repeat=N)))
    # Compute parity (1 for an odd number of ones, 0 for even)
    y = np.mod(np.sum(X, axis=1), 2)
    return X.astype(np.float32), y.astype(np.float32)


num_variables = 5
x, y = generate_parity_dataset(num_variables)
print("Data:")
for bits, label in zip(x, y):
    print(f"{bits} -> {label}")
Data:
[0. 0. 0. 0. 0.] -> 0.0
[0. 0. 0. 0. 1.] -> 1.0
[0. 0. 0. 1. 0.] -> 1.0
[0. 0. 0. 1. 1.] -> 0.0
[0. 0. 1. 0. 0.] -> 1.0
[0. 0. 1. 0. 1.] -> 0.0
[0. 0. 1. 1. 0.] -> 0.0
[0. 0. 1. 1. 1.] -> 1.0
[0. 1. 0. 0. 0.] -> 1.0
[0. 1. 0. 0. 1.] -> 0.0
[0. 1. 0. 1. 0.] -> 0.0
[0. 1. 0. 1. 1.] -> 1.0
[0. 1. 1. 0. 0.] -> 0.0
[0. 1. 1. 0. 1.] -> 1.0
[0. 1. 1. 1. 0.] -> 1.0
[0. 1. 1. 1. 1.] -> 0.0
[1. 0. 0. 0. 0.] -> 1.0
[1. 0. 0. 0. 1.] -> 0.0
[1. 0. 0. 1. 0.] -> 0.0
[1. 0. 0. 1. 1.] -> 1.0
[1. 0. 1. 0. 0.] -> 0.0
[1. 0. 1. 0. 1.] -> 1.0
[1. 0. 1. 1. 0.] -> 1.0
[1. 0. 1. 1. 1.] -> 0.0
[1. 1. 0. 0. 0.] -> 0.0
[1. 1. 0. 0. 1.] -> 1.0
[1. 1. 0. 1. 0.] -> 1.0
[1. 1. 0. 1. 1.] -> 0.0
[1. 1. 1. 0. 0.] -> 1.0
[1. 1. 1. 0. 1.] -> 0.0
[1. 1. 1. 1. 0.] -> 0.0
[1. 1. 1. 1. 1.] -> 1.0
[3]:
from jaxns.internals.maps import pytree_unravel
from jaxns import Prior, Model
import jaxns.framework.context as ctx
from jaxns import NestedSampler, TerminationCondition


def run(n_hidden_units):
    def prior_model():
        def compute_logits(x):
            mlp = hk.Sequential([
                hk.Linear(n_hidden_units),
                jax.nn.relu,
                hk.Linear(1)
            ])
            return mlp(x)

        init, apply = hk.transform(compute_logits)
        init_params = init(random.PRNGKey(0), x)
        # Convert haiku to jaxns params
        ctx_params = ctx.convert_external_params(init_params, prefix='haiku_model')
        # Flatten, model, then unflatten to use
        ravel_fn, unravel_fn = pytree_unravel(ctx_params)
        ndims = ravel_fn(init_params).size
        # One Uniform(-10, 10) prior per network parameter
        flat_params = yield Prior(tfpd.Uniform(-10. * jnp.ones(ndims), 10. * jnp.ones(ndims)), name='flat_params')
        params = unravel_fn(flat_params)
        logits = apply(params, jax.random.PRNGKey(0), x)[:, 0]  # [n]
        return logits.astype(jnp.float32)

    def log_likelihood(logits):
        # Classification problem, so we use a Bernoulli likelihood
        return tfpd.Bernoulli(logits=logits).log_prob(y).mean()

    model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
    model.sanity_check(random.PRNGKey(0), S=100)

    ns = NestedSampler(model=model)
    term_reason, state = jax.jit(ns)(
        random.PRNGKey(42),
        TerminationCondition(atol=0.01)
    )
    results = ns.to_results(term_reason, state)
    ns.summary(results)
    ns.plot_diagnostics(results)
    ns.plot_cornerplot(results)

    # Take the sample with the highest likelihood as the solution
    solution = results.U_samples[jnp.argmax(results.log_L_samples)]
    logits = model.prepare_input(solution)[0]
    predictions = jax.nn.sigmoid(logits)
    for i in range(len(y)):
        pred = predictions[i]
        print(f"{i}: {x[i]} -> {y[i]} | pred: {pred} {'✓' if (pred > 0.5) == y[i] else '✗'}")
    accuracy = jnp.mean((predictions > 0.5) == y)
    print(f"Accuracy: {accuracy * 100:.1f}%")
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/internals/mixed_precision.py:14: UserWarning: JAX x64 is not enabled. Setting it now. Check for errors.
warnings.warn("JAX x64 is not enabled. Setting it now. Check for errors.")
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda':
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
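The model places one uniform prior variable on every weight and bias of the network, so the dimension of `flat_params` follows directly from the layer sizes. The short sketch below counts the parameters of a one-hidden-layer network with a single output; the helper name `mlp_param_count` is hypothetical and not part of JAXNS or Haiku.

```python
def mlp_param_count(n_inputs, n_hidden):
    """Number of weights and biases in an MLP with one hidden layer and one output."""
    hidden_layer = n_inputs * n_hidden + n_hidden  # weights + biases
    output_layer = n_hidden * 1 + 1                # weights + bias
    return hidden_layer + output_layer


for n_hidden in (5, 4):
    print(f"{n_hidden} hidden units -> {mlp_param_count(5, n_hidden)} parameters")
```

For 5 inputs this gives 36 parameters with 5 hidden units and 29 with 4, which matches the index ranges `flat_params[0]` to `flat_params[35]` and `flat_params[0]` to `flat_params[28]` in the two summaries that follow.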
[4]:
run(num_variables)
INFO:jaxns:Sanity check...
INFO:jaxns:Sanity check passed
Running over 12 devices.
--------
Termination Conditions:
absolute spread of live points < atol
--------
likelihood evals: 11301639
samples: 46980
phantom samples: 0
likelihood evals / sample: 240.6
phantom fraction (%): 0.0%
--------
logZ=-7.606 +- 0.078
max(logL)=-0.001
H=-4.46
ESS=2219
--------
flat_params[#]: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
flat_params[0]: -3.2 +- 5.1 | -9.0 / -4.2 / 4.7 | 7.7 | 7.7
flat_params[1]: -3.1 +- 5.1 | -9.1 / -4.0 / 5.0 | -0.6 | -0.6
flat_params[2]: -3.1 +- 5.2 | -8.9 / -4.3 / 4.9 | -6.0 | -6.0
flat_params[3]: -3.3 +- 5.0 | -8.9 / -4.5 / 4.3 | -7.4 | -7.4
flat_params[4]: -2.9 +- 5.2 | -8.9 / -3.9 / 5.0 | -8.9 | -8.9
flat_params[5]: -1.9 +- 5.4 | -8.5 / -2.6 / 6.3 | 9.3 | 9.3
flat_params[6]: -1.7 +- 5.3 | -8.4 / -2.3 / 6.2 | 8.1 | 8.1
flat_params[7]: -1.7 +- 5.4 | -8.4 / -2.3 / 6.4 | -4.1 | -4.1
flat_params[8]: -1.5 +- 5.3 | -8.5 / -2.0 / 6.1 | 4.7 | 4.7
flat_params[9]: -1.3 +- 5.3 | -8.2 / -1.7 / 6.7 | -4.6 | -4.6
flat_params[10]: -1.6 +- 5.5 | -8.7 / -2.4 / 6.6 | -9.2 | -9.2
flat_params[11]: -2.0 +- 5.2 | -8.7 / -2.7 / 5.6 | -8.9 | -8.9
flat_params[12]: -1.7 +- 5.2 | -8.3 / -2.3 / 5.8 | 5.4 | 5.4
flat_params[13]: -1.8 +- 5.3 | -8.6 / -2.5 / 6.0 | -8.6 | -8.6
flat_params[14]: -1.5 +- 5.4 | -8.5 / -1.9 / 6.3 | 5.4 | 5.4
flat_params[15]: -1.4 +- 5.3 | -8.3 / -1.6 / 6.4 | -8.2 | -8.2
flat_params[16]: -1.9 +- 5.1 | -8.4 / -2.4 / 5.4 | -6.7 | -6.7
flat_params[17]: -1.5 +- 5.3 | -8.4 / -2.1 / 6.1 | 5.1 | 5.1
flat_params[18]: -1.5 +- 5.4 | -8.4 / -2.2 / 6.5 | -9.7 | -9.7
flat_params[19]: -1.5 +- 5.4 | -8.4 / -2.1 / 6.4 | 3.9 | 3.9
flat_params[20]: -1.1 +- 5.4 | -8.3 / -1.4 / 6.7 | -8.1 | -8.1
flat_params[21]: -1.5 +- 5.3 | -8.4 / -1.7 / 6.4 | -8.2 | -8.2
flat_params[22]: -1.5 +- 5.2 | -8.3 / -1.7 / 5.9 | 5.4 | 5.4
flat_params[23]: -1.9 +- 5.3 | -8.5 / -2.7 / 5.8 | -6.8 | -6.8
flat_params[24]: -1.3 +- 5.5 | -8.5 / -2.0 / 6.7 | 4.6 | 4.6
flat_params[25]: -1.7 +- 5.5 | -8.6 / -2.2 / 6.6 | 8.6 | 8.6
flat_params[26]: -1.0 +- 5.4 | -8.3 / -1.5 / 6.5 | 8.2 | 8.2
flat_params[27]: -1.4 +- 5.5 | -8.3 / -2.1 / 6.9 | -6.1 | -6.1
flat_params[28]: -1.1 +- 5.4 | -8.2 / -1.6 / 6.9 | 7.3 | 7.3
flat_params[29]: -1.7 +- 5.3 | -8.7 / -2.2 / 5.9 | -5.9 | -5.9
flat_params[30]: -0.4 +- 4.5 | -6.7 / -0.4 / 5.6 | 8.9 | 8.9
flat_params[31]: -0.1 +- 4.7 | -7.0 / -0.1 / 6.5 | -2.3 | -2.3
flat_params[32]: -0.1 +- 4.6 | -6.5 / -0.1 / 6.6 | 5.3 | 5.3
flat_params[33]: -0.1 +- 4.7 | -6.9 / -0.0 / 6.7 | -5.2 | -5.2
flat_params[34]: -0.0 +- 4.7 | -6.7 / 0.1 / 6.6 | -9.7 | -9.7
flat_params[35]: 0.2 +- 4.5 | -6.1 / 0.2 / 6.6 | 9.9 | 9.9
--------
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/plotting.py:45: UserWarning: Found samples with zero likelihood evaluations.
warnings.warn("Found samples with zero likelihood evaluations.")
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/plotting.py:49: RuntimeWarning: divide by zero encountered in divide
1. / num_likelihood_evaluations_per_sample
0: [0. 0. 0. 0. 0.] -> 0.0 | pred: 0.0001603582495590672 ✓
1: [0. 0. 0. 0. 1.] -> 1.0 | pred: 0.9999916553497314 ✓
2: [0. 0. 0. 1. 0.] -> 1.0 | pred: 0.9998642206192017 ✓
3: [0. 0. 0. 1. 1.] -> 0.0 | pred: 4.826125586987473e-05 ✓
4: [0. 0. 1. 0. 0.] -> 1.0 | pred: 0.9998642206192017 ✓
5: [0. 0. 1. 0. 1.] -> 0.0 | pred: 0.006293473765254021 ✓
6: [0. 0. 1. 1. 0.] -> 0.0 | pred: 4.47046147655783e-07 ✓
7: [0. 0. 1. 1. 1.] -> 1.0 | pred: 0.9998515844345093 ✓
8: [0. 1. 0. 0. 0.] -> 1.0 | pred: 0.9998642206192017 ✓
9: [0. 1. 0. 0. 1.] -> 0.0 | pred: 0.0005624053883366287 ✓
10: [0. 1. 0. 1. 0.] -> 0.0 | pred: 0.007503319066017866 ✓
11: [0. 1. 0. 1. 1.] -> 1.0 | pred: 0.9998642206192017 ✓
12: [0. 1. 1. 0. 0.] -> 0.0 | pred: 2.489727921783924e-05 ✓
13: [0. 1. 1. 0. 1.] -> 1.0 | pred: 0.9998642206192017 ✓
14: [0. 1. 1. 1. 0.] -> 1.0 | pred: 0.9988038539886475 ✓
15: [0. 1. 1. 1. 1.] -> 0.0 | pred: 2.1107047359691933e-05 ✓
16: [1. 0. 0. 0. 0.] -> 1.0 | pred: 0.9999055862426758 ✓
17: [1. 0. 0. 0. 1.] -> 0.0 | pred: 7.628910680068657e-06 ✓
18: [1. 0. 0. 1. 0.] -> 0.0 | pred: 1.0083437700814102e-05 ✓
19: [1. 0. 0. 1. 1.] -> 1.0 | pred: 0.9997050166130066 ✓
20: [1. 0. 1. 0. 0.] -> 0.0 | pred: 0.000562924484256655 ✓
21: [1. 0. 1. 0. 1.] -> 1.0 | pred: 0.9999998807907104 ✓
22: [1. 0. 1. 1. 0.] -> 1.0 | pred: 0.9943498969078064 ✓
23: [1. 0. 1. 1. 1.] -> 0.0 | pred: 0.00018032571824733168 ✓
24: [1. 1. 0. 0. 0.] -> 0.0 | pred: 0.00011755365267163143 ✓
25: [1. 1. 0. 0. 1.] -> 1.0 | pred: 0.9991588592529297 ✓
26: [1. 1. 0. 1. 0.] -> 1.0 | pred: 0.9964148998260498 ✓
27: [1. 1. 0. 1. 1.] -> 0.0 | pred: 3.537770317052491e-05 ✓
28: [1. 1. 1. 0. 0.] -> 1.0 | pred: 0.9990290403366089 ✓
29: [1. 1. 1. 0. 1.] -> 0.0 | pred: 6.320502143353224e-05 ✓
30: [1. 1. 1. 1. 0.] -> 0.0 | pred: 1.786692038763249e-08 ✓
31: [1. 1. 1. 1. 1.] -> 1.0 | pred: 0.9997976422309875 ✓
Accuracy: 100.0%
[5]:
run(num_variables - 1)
INFO:jaxns:Sanity check...
INFO:jaxns:Sanity check passed
Running over 12 devices.
--------
Termination Conditions:
absolute spread of live points < atol
--------
likelihood evals: 6574962
samples: 33300
phantom samples: 0
likelihood evals / sample: 197.4
phantom fraction (%): 0.0%
--------
logZ=-6.456 +- 0.078
max(logL)=-0.085
H=-3.66
ESS=1724
--------
flat_params[#]: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.
flat_params[0]: -3.2 +- 5.2 | -9.1 / -4.4 / 4.5 | 9.8 | 9.8
flat_params[1]: -3.4 +- 4.8 | -9.0 / -4.4 / 4.1 | -8.8 | -8.8
flat_params[2]: -3.3 +- 5.1 | -9.1 / -4.4 / 4.5 | 0.7 | 0.7
flat_params[3]: -3.1 +- 5.0 | -8.9 / -4.2 / 4.6 | 7.6 | 7.6
flat_params[4]: -1.6 +- 5.3 | -8.3 / -2.2 / 6.3 | -4.9 | -4.9
flat_params[5]: -1.9 +- 5.1 | -8.5 / -2.3 / 5.5 | -6.0 | -6.0
flat_params[6]: -1.4 +- 5.4 | -8.4 / -2.0 / 6.7 | -9.2 | -9.2
flat_params[7]: -1.9 +- 5.4 | -8.7 / -2.8 / 6.2 | -6.9 | -6.9
flat_params[8]: -1.7 +- 5.3 | -8.4 / -2.2 / 6.3 | -7.7 | -7.7
flat_params[9]: -1.7 +- 5.1 | -8.2 / -2.0 / 5.6 | -8.0 | -8.0
flat_params[10]: -1.5 +- 5.5 | -8.6 / -2.2 / 6.9 | -9.1 | -9.1
flat_params[11]: -1.3 +- 5.4 | -8.3 / -1.6 / 6.6 | -9.0 | -9.0
flat_params[12]: -1.7 +- 5.3 | -8.5 / -2.4 / 6.1 | 4.1 | 4.1
flat_params[13]: -1.5 +- 5.4 | -8.4 / -2.2 / 6.5 | 6.9 | 6.9
flat_params[14]: -1.7 +- 5.4 | -8.5 / -2.4 / 6.3 | 9.4 | 9.4
flat_params[15]: -1.8 +- 5.3 | -8.5 / -2.4 / 6.0 | 7.0 | 7.0
flat_params[16]: -1.5 +- 5.2 | -8.3 / -2.2 / 6.1 | -5.0 | -5.0
flat_params[17]: -1.8 +- 5.3 | -8.6 / -2.4 / 5.9 | -6.5 | -6.5
flat_params[18]: -1.8 +- 5.4 | -8.5 / -2.5 / 6.4 | -8.7 | -8.7
flat_params[19]: -2.2 +- 5.2 | -8.5 / -2.9 / 5.3 | -6.3 | -6.3
flat_params[20]: -1.3 +- 5.5 | -8.4 / -1.7 / 6.8 | 7.4 | 7.4
flat_params[21]: -1.7 +- 5.1 | -8.4 / -2.2 / 5.7 | 8.2 | 8.2
flat_params[22]: -1.1 +- 5.4 | -8.2 / -1.2 / 6.7 | 8.5 | 8.5
flat_params[23]: -1.3 +- 5.5 | -8.3 / -1.9 / 6.6 | 9.2 | 9.2
flat_params[24]: -0.0 +- 4.2 | -5.6 / -0.1 / 5.6 | -1.5 | -1.5
flat_params[25]: -0.0 +- 4.8 | -7.1 / 0.0 / 6.9 | 6.2 | 6.2
flat_params[26]: -0.0 +- 4.8 | -7.0 / 0.0 / 6.8 | -9.7 | -9.7
flat_params[27]: -0.2 +- 4.6 | -7.0 / -0.1 / 6.1 | 7.4 | 7.4
flat_params[28]: 0.2 +- 4.7 | -6.4 / 0.2 / 6.6 | -9.4 | -9.4
--------
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/plotting.py:45: UserWarning: Found samples with zero likelihood evaluations.
warnings.warn("Found samples with zero likelihood evaluations.")
/home/albert/miniconda3/envs/jaxns_py/lib/python3.11/site-packages/jaxns/plotting.py:49: RuntimeWarning: divide by zero encountered in divide
1. / num_likelihood_evaluations_per_sample
0: [0. 0. 0. 0. 0.] -> 0.0 | pred: 0.0005937305395491421 ✓
1: [0. 0. 0. 0. 1.] -> 1.0 | pred: 0.9999994039535522 ✓
2: [0. 0. 0. 1. 0.] -> 1.0 | pred: 0.9999998807907104 ✓
3: [0. 0. 0. 1. 1.] -> 0.0 | pred: 6.42026598551837e-10 ✓
4: [0. 0. 1. 0. 0.] -> 1.0 | pred: 1.0 ✓
5: [0. 0. 1. 0. 1.] -> 0.0 | pred: 6.972756239065347e-08 ✓
6: [0. 0. 1. 1. 0.] -> 0.0 | pred: 1.7175752873299643e-06 ✓
7: [0. 0. 1. 1. 1.] -> 1.0 | pred: 0.9997764229774475 ✓
8: [0. 1. 0. 0. 0.] -> 1.0 | pred: 0.9999908208847046 ✓
9: [0. 1. 0. 0. 1.] -> 0.0 | pred: 1.0467486077914145e-07 ✓
10: [0. 1. 0. 1. 0.] -> 0.0 | pred: 0.18118071556091309 ✓
11: [0. 1. 0. 1. 1.] -> 1.0 | pred: 0.999995231628418 ✓
12: [0. 1. 1. 0. 0.] -> 0.0 | pred: 0.00027995381969958544 ✓
13: [0. 1. 1. 0. 1.] -> 1.0 | pred: 0.9999985694885254 ✓
14: [0. 1. 1. 1. 0.] -> 1.0 | pred: 0.9977697134017944 ✓
15: [0. 1. 1. 1. 1.] -> 0.0 | pred: 3.0262933825575544e-10 ✓
16: [1. 0. 0. 0. 0.] -> 1.0 | pred: 1.0 ✓
17: [1. 0. 0. 0. 1.] -> 0.0 | pred: 1.2974471275128963e-08 ✓
18: [1. 0. 0. 1. 0.] -> 0.0 | pred: 0.18118071556091309 ✓
19: [1. 0. 0. 1. 1.] -> 1.0 | pred: 0.9999678134918213 ✓
20: [1. 0. 1. 0. 0.] -> 0.0 | pred: 3.4708678867900744e-05 ✓
21: [1. 0. 1. 0. 1.] -> 1.0 | pred: 0.9999445676803589 ✓
22: [1. 0. 1. 1. 0.] -> 1.0 | pred: 0.9999856948852539 ✓
23: [1. 0. 1. 1. 1.] -> 0.0 | pred: 3.751090063564e-11 ✓
24: [1. 1. 0. 0. 0.] -> 0.0 | pred: 0.18118071556091309 ✓
25: [1. 1. 0. 0. 1.] -> 1.0 | pred: 1.0 ✓
26: [1. 1. 0. 1. 0.] -> 1.0 | pred: 0.18118071556091309 ✗
27: [1. 1. 0. 1. 1.] -> 0.0 | pred: 0.18118071556091309 ✓
28: [1. 1. 1. 0. 0.] -> 1.0 | pred: 0.9989967942237854 ✓
29: [1. 1. 1. 0. 1.] -> 0.0 | pred: 6.115756789881743e-09 ✓
30: [1. 1. 1. 1. 0.] -> 0.0 | pred: 0.18118071556091309 ✓
31: [1. 1. 1. 1. 1.] -> 1.0 | pred: 0.9993218183517456 ✓
Accuracy: 96.9%
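As a rough consistency check on the last run, the reported `max(logL)=-0.085` can be approximately reproduced from the printed predictions. Six inputs share the prediction 0.18118..., five of them labelled 0 and one labelled 1 (the misclassified input 26). The sketch below computes their contribution to the mean Bernoulli log-likelihood used in `log_likelihood`, and assumes the remaining 26 inputs contribute almost nothing because their predictions are very close to their labels.

```python
import math

# Values copied from the printed output of run(num_variables - 1).
shared_pred = 0.18118071556091309
n_label0, n_label1 = 5, 1  # labels of the six inputs sharing this prediction
n_data = 32                # total number of inputs for 5 bits


def bernoulli_log_prob(p, label):
    """Log-probability of a binary label under a Bernoulli with P(label=1) = p."""
    return math.log(p) if label == 1 else math.log1p(-p)


contribution = (
    n_label0 * bernoulli_log_prob(shared_pred, 0)
    + n_label1 * bernoulli_log_prob(shared_pred, 1)
) / n_data
print(f"mean log-likelihood contribution: {contribution:.3f}")
```

These six inputs therefore account for essentially all of the reported maximum log-likelihood, whereas the run with 5 hidden units reaches `max(logL)=-0.001`.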