# Source code for jaxns.plotting

import logging
from typing import Optional, List, Union

import jax.numpy as jnp
import numpy as np
import pylab as plt
from jax import random
from matplotlib.animation import FuncAnimation
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.stats import gaussian_kde

from jaxns.internals.log_semiring import cumulative_logsumexp, LogSpace, normalise_log_space
from jaxns.internals.shapes import tuple_prod
from jaxns.internals.types import NestedSamplerResults, int_type
from jaxns.utils import resample

logger = logging.getLogger('jaxns')

__all__ = ['plot_diagnostics',
           'plot_cornerplot']


def plot_diagnostics(results: NestedSamplerResults, save_name=None):
    """
    Draw a five-panel diagnostic figure for a nested sampling run.

    All panels share the x-axis ``-log X``. From top to bottom they show: the number of live points,
    the likelihood relative to its maximum, ``exp(log_dp_mean)``, the exponentiated cumulative
    log-sum-exp of ``log_dp_mean``, and the per-sample efficiency (one over the number of likelihood
    evaluations per sample). The figure is shown with ``plt.show()``.

    Args:
        results: NestedSamplerResult
        save_name: file to save figure to. Nothing is written when None.

    Raises:
        ValueError: if ``results.log_L_samples`` does not have ``results.total_num_samples`` rows.
    """
    expected_count = int(results.total_num_samples)
    actual_count = results.log_L_samples.shape[0]
    if actual_count != expected_count:
        raise ValueError(f"Expected all samples to have the same number of samples, "
                         f"got log_L_samples with {actual_count} samples, "
                         f"expected {expected_count} samples.")

    fig, panels = plt.subplots(5, 1, sharex=True, figsize=(8, 12))
    live_ax, likelihood_ax, weight_ax, evidence_ax, efficiency_ax = panels

    # Common x-axis of every panel.
    neg_log_X = -np.asarray(results.log_X_mean)

    # Quantities pulled from the results and converted to numpy for plotting.
    n_live = np.asarray(results.num_live_points_per_sample)
    log_L = np.asarray(results.log_L_samples)
    max_log_likelihood = np.max(log_L)
    log_dp_mean = np.asarray(results.log_dp_mean)
    cum_evidence = np.exp(cumulative_logsumexp(log_dp_mean))
    log_Z_mean = np.asarray(results.log_Z_mean)
    efficiency = 1. / np.asarray(results.num_likelihood_evaluations_per_sample)
    mean_efficiency = np.exp(results.log_efficiency)
    neg_H = -results.H_mean

    # Shared look of the dashed reference lines.
    reference = dict(color='black', ls='dashed')

    # Panel 1: number of live points
    live_ax.plot(neg_log_X, n_live, c='black')
    live_ax.set_ylabel(r'$n_{\rm live}$')

    # Panel 2: likelihood relative to the maximum (helps to detect too small log likelihood)
    likelihood_ax.plot(neg_log_X, np.exp(log_L - max_log_likelihood), c='black')
    likelihood_ax.axhline(1., label=rf"$\log L_{{\rm max}}={max_log_likelihood:.1f}$", **reference)
    likelihood_ax.set_ylabel(r'$L/L_{\rm max}$')
    likelihood_ax.legend()

    # Panel 3: exp(log_dp_mean) per sample, with a vertical marker at -H
    weight_ax.plot(neg_log_X, np.exp(log_dp_mean), c='black')
    weight_ax.axvline(neg_H, label=rf'$-H={neg_H:.1f}$', **reference)
    weight_ax.set_ylabel(r'$Z^{-1}L dX$')
    weight_ax.legend()

    # Panel 4: cumulative evidence
    evidence_ax.plot(neg_log_X, cum_evidence, c='black')
    evidence_ax.axhline(1., label=rf"$\log Z={log_Z_mean:.1f}$", **reference)
    evidence_ax.set_ylabel(r'$Z(\lambda > L)/Z$')
    evidence_ax.legend()

    # Panel 5: sampler efficiency, with the mean efficiency as reference
    efficiency_ax.scatter(neg_log_X, efficiency, s=2, c='black')
    efficiency_ax.axhline(mean_efficiency, label=f'avg. eff.={mean_efficiency:.3f}', **reference)
    efficiency_ax.set_ylabel("sampler efficiency")
    efficiency_ax.set_ylim(0., 1.05)
    efficiency_ax.legend()
    efficiency_ax.set_xlabel(r'$- \log X$')

    if save_name is not None:
        fig.savefig(save_name, bbox_inches='tight', dpi=300, pad_inches=0.0)
    plt.show()


def plot_cornerplot(results: NestedSamplerResults, variables: Optional[List[str]] = None,
                    with_parametrised: bool = False, save_name: Optional[str] = None,
                    kde_overlay: bool = False):
    """
    Plots a cornerplot of the posterior samples.

    The diagonal holds weighted 1D histograms, annotated with the maximum-likelihood sample (green),
    the MAP sample (solid red), the weighted mean (dashed red) and mean +- one weighted standard
    deviation (dotted red). The lower triangle holds weighted 2D histograms. All axes are limited to
    the weighted 1st..99th percentile range of the corresponding parameter. The figure is shown with
    ``plt.show()``.

    Args:
        results: NestedSamplerResult
        variables: list of variable names to plot. Plots all collected samples by default.
            Names not present in the samples are silently ignored.
        with_parametrised: whether to also include `results.parametrised_samples`.
        save_name: file to save result to.
        kde_overlay: whether to overlay a KDE on the histograms.

    Raises:
        ValueError: if there is nothing to plot, or if a variable does not have
            `results.total_num_samples` samples.
    """
    # Shallow copy, so that merging the parametrised samples never mutates `results.samples`.
    samples = dict(results.samples)
    if with_parametrised:
        samples.update(results.parametrised_samples)

    # Plot all variables by default
    if variables is None:
        variables = list(samples.keys())
    variables = sorted(v for v in variables if v in samples)
    ndims = sum(tuple_prod(samples[key].shape[1:]) for key in variables)
    if ndims == 0:
        # Fail early with a clear message, rather than inside numpy/matplotlib.
        raise ValueError("No variables to plot: none of the requested variables were found in the samples.")

    num_samples = int(results.total_num_samples)
    for key in variables:
        if samples[key].shape[0] != num_samples:
            raise ValueError(f"Expected all samples to have the same number of samples, "
                             f"got {key} with {samples[key].shape[0]} samples, "
                             f"expected {num_samples} samples.")

    # Get the leaves of the tree, and concatenate into [num_samples, ndims] shape
    leaves = np.concatenate(
        [np.asarray(samples[key]).reshape((num_samples, -1)) for key in variables],
        axis=-1
    )

    # Create a parameter for each column. For scalar parameters, we just use the name of the parameter.
    # For vector we use name[i,j,...] etc.
    parameters = []
    for key in variables:
        shape = samples[key].shape[1:]
        size = tuple_prod(shape)
        if size == 1:
            parameters.append(key)
            continue
        # Loop over each flat index of the parameter, and create a parameter for each multi-index
        for i in range(size):
            indices = np.unravel_index(i, shape)
            parameters.append(f"{key}[{','.join([str(j) for j in indices])}]")

    # Get the maximum likelihood and MAP samples
    log_L_samples = np.asarray(results.log_L_samples)
    log_posterior_density = np.asarray(results.log_posterior_density)
    max_like_sample = leaves[np.argmax(log_L_samples)]
    map_sample = leaves[np.argmax(log_posterior_density)]

    # Get the weight of each sample
    log_weights = np.asarray(normalise_log_space(LogSpace(results.log_dp_mean), norm_type='max').log_abs_val)

    figsize = min(20, max(4, int(2 * ndims)))
    # squeeze=False keeps `axs` two-dimensional, even for a single parameter
    fig, axs = plt.subplots(ndims, ndims, figsize=(figsize, figsize), squeeze=False)

    # Get the number of bins for the histograms based on the effective sample size
    nbins = max(10, int(jnp.sqrt(results.ESS)))

    # Plot the marginal distributions on the diagonal, setting a title above each plot with the
    # 5%/50%/95% percentiles, mean+-stddev, and MAP/ML values.
    param_limits = dict()  # Store the 1_per and 99_per for each parameter
    for row in range(ndims):
        ax = axs[row][row]
        _samples = leaves[:, row]  # [num_samples]
        _parameter = parameters[row]
        _log_weights = log_weights
        is_finite = np.isfinite(_samples)
        if not np.all(is_finite):
            logger.warning(f"Found {np.sum(np.bitwise_not(is_finite))} non-finite samples for {_parameter}")
            _samples = _samples[is_finite]
            _log_weights = _log_weights[is_finite]
        _weights = np.exp(_log_weights)

        # Percentiles
        per_1, per_5, per_50, per_95, per_99 = weighted_percentile(_samples, _log_weights, [1, 5, 50, 95, 99])

        # Plot the histogram, from 1_per to 99_per
        ax.hist(_samples, bins=nbins, fc='None', edgecolor='black', density=True, weights=_weights,
                range=(per_1, per_99))

        # Plot the maximum likelihood and MAP samples
        ax.axvline(max_like_sample[row], color='green')
        ax.axvline(map_sample[row], color='red')

        # Plot the weighted mean and standard deviation
        sample_mean = np.average(_samples, weights=_weights)
        sample_std = np.sqrt(np.average((_samples - sample_mean) ** 2, weights=_weights))
        ax.axvline(sample_mean, linestyle='dashed', color='red')
        ax.axvline(sample_mean + sample_std, linestyle='dotted', color='red')
        ax.axvline(sample_mean - sample_std, linestyle='dotted', color='red')

        # Set the title
        title = [
            rf"${per_50:.2f}_{{{per_5:.2f}}}^{{{per_95:.2f}}}$",
            rf"${sample_mean:.2f}\pm{sample_std:.2f}$",
            rf"MAP ${map_sample[row]:.2f}$ | ML ${max_like_sample[row]:.2f}$"
        ]
        ax.set_title("\n".join(title))

        # Set the limits to 1 to 99 percentiles
        ax.set_xlim(per_1, per_99)
        param_limits[_parameter] = (per_1, per_99)

    # Plot the 2D histograms on lower-diagonal.
    for row in range(ndims):
        for col in range(row):
            ax = axs[row][col]
            x_limits = param_limits[parameters[col]]
            y_limits = param_limits[parameters[row]]

            # Get the samples for the 2D histogram
            _samples = leaves[:, [row, col]]  # [num_samples, 2]
            _log_weights = log_weights
            is_finite = np.all(np.isfinite(_samples), axis=-1)  # [num_samples]
            if not np.all(is_finite):
                logger.warning(
                    f"Found {np.sum(np.bitwise_not(is_finite))} non-finite samples for {parameters[row]} and {parameters[col]}")
                _samples = _samples[is_finite]
                _log_weights = _log_weights[is_finite]
            _weights = np.exp(_log_weights)
            x = _samples[:, 1]
            y = _samples[:, 0]

            # Plot the 2D histogram, over ranges set by the 1_per and 99_per of each parameter.
            # The colormap is passed by name: `plt.cm.get_cmap` no longer exists in Matplotlib >= 3.9.
            ax.hist2d(x, y, bins=(nbins, nbins), density=True, cmap='bone_r', weights=_weights,
                      range=[x_limits, y_limits])

            if kde_overlay:
                # Put KDE contour on the 2D histograms, evaluated on a 128x128 grid over the axis limits
                x_array = np.linspace(*x_limits, 128)
                y_array = np.linspace(*y_limits, 128)
                X, Y = np.meshgrid(x_array, y_array)
                xy_eval = np.vstack([X.ravel(), Y.ravel()])
                z = gaussian_kde(np.vstack([x, y]), weights=_weights)(xy_eval)
                z = z.reshape(X.shape)
                ax.contour(X, Y, z, levels=6, alpha=0.5)

            # Plot the maximum likelihood and MAP samples
            ax.scatter(max_like_sample[col], max_like_sample[row], color='green', marker='x')
            ax.scatter(map_sample[col], map_sample[row], color='red', marker='x')

            # Set the limits to 1 to 99 percentiles
            ax.set_xlim(x_limits)
            ax.set_ylim(y_limits)

    # Remove spacing
    plt.subplots_adjust(wspace=0.0, hspace=0.0)

    # Remove x ticks for all but bottom row
    for row in range(ndims - 1):
        for col in range(ndims):
            axs[row][col].set_xticks([])
            axs[row][col].set_xticklabels([])

    # Remove y ticks for all but left column
    for row in range(ndims):
        for col in range(1, ndims):
            axs[row][col].set_yticks([])
            axs[row][col].set_yticklabels([])

    # Set the labels on the bottom row and left column
    for i in range(ndims):
        axs[-1][i].set_xlabel(parameters[i])
        axs[i][0].set_ylabel(parameters[i])

    # Remove upper diagonal
    for row in range(ndims):
        for col in range(row + 1, ndims):
            axs[row][col].remove()

    # Save the figure
    if save_name is not None:
        fig.savefig(save_name, bbox_inches='tight', dpi=300, pad_inches=0.0)
    plt.show()


def weighted_percentile(samples: np.ndarray, log_weights: np.ndarray,
                        percentiles: List[Union[float, int]]) -> np.ndarray:
    """
    Compute weighted percentiles of a set of samples.

    Args:
        samples: weighted samples
        log_weights: log weights of samples
        percentiles: list of percentiles to compute, given in percent (they are divided by 100).

    Returns:
        weighted percentiles, one per entry of `percentiles`.

    Raises:
        ValueError: if `percentiles` is empty.
    """
    if len(percentiles) == 0:
        raise ValueError("percentiles must be a non-empty list")

    # Wrap the log weights in LogSpace, subtracting the max to avoid overflow
    weights = LogSpace(log_weights - np.max(log_weights))
    weights = normalise_log_space(weights, norm_type='sum')  # Normalize weights

    # Sort samples and weights by sample value
    sorted_indices = np.argsort(samples)
    sorted_samples = samples[sorted_indices]
    sorted_weights = weights[sorted_indices]

    # Compute cumulative weights, then shift them so the smallest sample maps to 0 and rescale so the
    # largest sample maps to 1.
    # NOTE(review): with a single sample (or all weight on the smallest sample) the rescale divides by
    # zero; how LogSpace handles that is not visible from here -- confirm.
    cumulative_weights = sorted_weights.cumsum()
    cumulative_weights = cumulative_weights - cumulative_weights[0]
    cumulative_weights = cumulative_weights / cumulative_weights[-1]

    # Compute weighted percentiles by linear interpolation of the cumulative weight curve
    percentile_values = np.interp(np.asarray(percentiles) / 100.0, cumulative_weights.value, sorted_samples)
    return percentile_values


def plot_samples_development(results, variables=None, save_name=None):
    """
    Animate the live points in a corner plot, visualising how the algorithm proceeds.
    Caution, this can be very slow as it plots a frame per sample.

    Args:
        results: NestedSamplingResult
        variables: list of variable names to plot, or None to plot all of `results.samples`.
        save_name: '.mp4' file to save animation to.

    Raises:
        ValueError: if `save_name` is None.
    """
    if save_name is None:
        raise ValueError("In order to plot the animation we must save it.")

    # Plot all variables by default
    if variables is None:
        variables = list(results.samples.keys())
    variables = sorted(v for v in variables if v in results.samples)
    ndims = sum(tuple_prod(results.samples[key].shape[1:]) for key in variables)

    figsize = min(20, max(4, int(2 * ndims)))
    # squeeze=False keeps `axs` two-dimensional, even for a single parameter
    fig, axs = plt.subplots(ndims, ndims, figsize=(figsize, figsize), squeeze=False)

    # NOTE(review): `log_p_mean`, `n_per_sample` and `num_samples` are read from `results` here, while
    # the other plotting functions in this module read `log_dp_mean` and `total_num_samples` --
    # confirm these attributes still exist on the results type.
    weights = jnp.exp(results.log_p_mean)
    max_samples = weights.size
    norm = plt.Normalize(weights.min(), weights.max())

    def to_colour(w):
        # Map weights to colours of the `jet` colormap
        return plt.cm.jet(norm(w))

    def _flat_column(key, index):
        # Column `index` of the flattened variable `key`
        return results.samples[key].reshape((max_samples, -1))[:, index]

    def _clear_axes():
        for i in range(ndims):
            for j in range(ndims):
                axs[i][j].clear()

    def _draw_frame(start, stop):
        # Draw samples[start:stop] into the corner plot, and return the created artists
        artists = []
        lims = {}
        dim = 0
        for key in variables:
            n1 = tuple_prod(results.samples[key].shape[1:])
            for i in range(n1):
                samples1 = _flat_column(key, i)[start:stop]
                title1 = "{}[{}]".format(key, i) if n1 > 1 else "{}".format(key)
                dim2 = 0
                for key2 in variables:
                    n2 = tuple_prod(results.samples[key2].shape[1:])
                    for i2 in range(n2):
                        ax = axs[dim][dim2]
                        if dim2 > dim:
                            # Upper triangle: left empty, without ticks
                            dim2 += 1
                            ax.set_xticks([])
                            ax.set_xticklabels([])
                            ax.set_yticks([])
                            ax.set_yticklabels([])
                            continue
                        title2 = "{}[{}]".format(key2, i2) if n2 > 1 else "{}".format(key2)
                        if dim == dim2:
                            # Diagonal: histogram, whose x-limits are reused for the off-diagonal panels
                            _, _, new_patches = ax.hist(samples1)
                            artists.extend(new_patches)
                            lims[dim] = ax.get_xlim()
                        else:
                            samples2 = _flat_column(key2, i2)[start:stop]
                            sc = ax.scatter(samples2, samples1, marker='+', c=to_colour(weights[start:stop]),
                                            alpha=0.3)
                            artists.append(sc)
                        if dim == ndims - 1:
                            ax.set_xlabel("{}".format(title2))
                        if dim2 == 0:
                            ax.set_ylabel("{}".format(title1))
                        dim2 += 1
                dim += 1

        # Align the off-diagonal panels with the limits of the histograms
        for dim in range(ndims):
            for dim2 in range(ndims):
                if dim == dim2:
                    continue
                ax = axs[dim][dim2]
                if dim in lims:
                    ax.set_ylim(lims[dim])
                if dim2 in lims:
                    ax.set_xlim(lims[dim2])
        return artists

    def update(start):
        stop = start + results.n_per_sample[start].astype(int_type)
        _clear_axes()
        return _draw_frame(start, stop)

    def init():
        # The initial frame is the frame starting at sample 0
        return update(0)

    ani = FuncAnimation(fig, update, frames=jnp.arange(1, results.num_samples), init_func=init, blit=True)
    ani.save(save_name, fps=results.n_per_sample[0] / 2.)


def add_colorbar_to_axes(ax, cmap, norm=None, vmin=None, vmax=None, label=None):
    """
    Add colorbar to axes easily.

    Args:
        ax: Axes
        cmap: str or cmap
        norm: Normalize or None
        vmin: lower limit of color if norm is None
        vmax: upper limit of color if norm is None
        label: label of the colorbar, or None for no label
    """
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    if norm is None:
        norm = plt.Normalize(vmin=vmin, vmax=vmax)
    # ScalarMappable resolves colormap names itself; `plt.cm.get_cmap` no longer exists in Matplotlib >= 3.9.
    sm = plt.cm.ScalarMappable(norm, cmap=cmap)
    if label is None:
        ax.figure.colorbar(sm, cax=cax, orientation='vertical')
    else:
        ax.figure.colorbar(sm, cax=cax, orientation='vertical', label=label)


def corner_cornerplot(results: NestedSamplerResults):
    """
    Plot a corner plot of resampled posterior samples, using the `corner` and `arviz` packages.

    The samples are resampled with a fixed PRNG key (42) into `int(results.ESS)` samples, and the
    figure is shown with ``plt.show()``.

    Args:
        results: NestedSamplerResult
    """
    # NOTE(review): a missing optional dependency terminates the interpreter with exit status 0.
    # Kept as is for backward compatibility; raising ImportError would be friendlier to library users.
    try:
        import corner
    except ImportError:
        logger.warning("You must run `pip install corner`")
        exit(0)
    try:
        import arviz as az
    except ImportError:
        logger.warning("You must run `pip install arviz`")
        exit(0)
    # `jax.tree_map` was removed from the top-level namespace; `jax.tree_util.tree_map` is the stable location.
    from jax.tree_util import tree_map
    samples = resample(random.PRNGKey(42), results.samples, results.log_dp_mean, S=int(results.ESS))
    # x[None] adds a leading axis of size one to each array (presumably arviz's chain dimension -- confirm)
    corner.corner(az.from_dict(posterior=tree_map(lambda x: x[None], samples)), )
    plt.show()