from typing import NamedTuple, List, Tuple, Union, Optional
from jax import numpy as jnp, lax, tree_map, core
from jax._src.numpy import lax_numpy
from jaxns.internals.maps import remove_chunk_dim
from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
from jaxns.internals.types import MeasureType, IntArray, float_type, FloatArray, StaticStandardNestedSamplerState, \
int_type
class SampleTreeGraph(NamedTuple):
    """
    Represents the tree structure of samples.

    There are N+1 nodes and N edges. Every node has exactly one sender
    (except the root node), and zero or more receivers. The root is always
    node 0.
    """
    sender_node_idx: IntArray  # [N] parent node of each sample, values in [0, N]
    log_L: MeasureType  # [N] log-likelihood at each (non-root) node
class SampleLivePointCounts(NamedTuple):
    """
    Per-sample live point counts for the shrinkage distribution.
    """
    samples_indices: IntArray  # [N] values in [0, N], points to the sample that the live point represents.
    num_live_points: IntArray  # [N] values in [0, N], number of live points that the sample represents.
[docs]
def count_crossed_edges(sample_tree: SampleTreeGraph, num_samples: Optional[IntArray] = None) -> SampleLivePointCounts:
    """
    Compute the number of live points for each sample by counting, for each
    likelihood contour, how many tree edges cross that contour.

    Args:
        sample_tree: tree with N edges over N+1 nodes (root is node 0).
        num_samples: if given, only the first num_samples entries are real samples;
            the remainder are assumed to be padding edges from the root with
            log_L=+inf (set that way by the caller -- see the note below).

    Returns:
        SampleLivePointCounts: sample indices in log_L-sorted order, and the
        number of live points for each sample.
    """

    def _argsort(
            a: jnp.ndarray
    ):
        # Stable argsort along axis 0 via lax.sort_key_val, mirroring the
        # internals of jnp.argsort.
        arr = jnp.asarray(a)
        axis_num = 0
        # Use a 64-bit iota when the axis size is dynamic or exceeds int32 range.
        use_64bit_index = not core.is_constant_dim(arr.shape[axis_num]) or arr.shape[axis_num] >= (1 << 31)
        iota = lax.broadcasted_iota(lax_numpy.int64 if use_64bit_index else lax_numpy.int_, arr.shape, axis_num)
        _, perm = lax.sort_key_val(arr, iota, dimension=axis_num)
        return perm

    N = sample_tree.sender_node_idx.size
    fake_edges = 0
    if num_samples is not None:
        # We put edges from root to +inf for the indices that are not used.
        # Since all lines will cross these injected edges, we subtract them from the total.
        # Note: these values are set by default, so we don't need to do anything.
        # Leave this here as a reminder of what's going on.
        # mask = jnp.arange(N) < num_samples
        # sample_tree = SampleTreeGraph(
        #     sender_node_idx=jnp.where(mask, sample_tree.sender_node_idx, 0),
        #     log_L=jnp.where(mask, sample_tree.log_L, jnp.inf)
        # )
        fake_edges = N - num_samples

    # Construct N edges from N+1 nodes; root likelihood is pinned to -inf.
    log_L_nodes = jnp.concatenate([jnp.asarray([-jnp.inf], float_type), sample_tree.log_L])  # [N+1]
    sender = sample_tree.sender_node_idx  # [N]
    sort_idx = _argsort(log_L_nodes)  # [N+1]

    # Count out-degree of each node, how many nodes have parent_idx==idx.
    # At least one node will have zero, but we don't know which.
    # Could just use sender (unsorted)
    out_degree = jnp.bincount(sender, length=N + 1)  # [N+1]

    def op(crossed_edges, last_node):
        # Sweeping nodes in likelihood order: each node closes the one edge
        # reaching it (-1) and opens out_degree[last_node] new edges.
        # init = 1
        # delta = degree(nodes[last_node]) - 1
        crossed_edges += out_degree[last_node] - 1
        return crossed_edges

    if num_samples is not None:
        # Dynamic stop: accumulate only over the valid prefix; the padded tail
        # is filled with fake_edges so the subtraction below zeroes it.
        _, crossed_edges_sorted = cumulative_op_dynamic(
            op=op,
            init=jnp.asarray(1, out_degree.dtype),
            xs=sort_idx,
            stop_idx=num_samples,
            pre_op=False,
            empty_fill=jnp.asarray(fake_edges, out_degree.dtype)
        )
    else:
        _, crossed_edges_sorted = cumulative_op_static(
            op=op,
            init=jnp.asarray(1, out_degree.dtype),
            xs=sort_idx,
            pre_op=False
        )

    if num_samples is not None:
        # Remove the contribution of the injected padding edges, which every
        # contour crosses.
        crossed_edges_sorted -= fake_edges

    # Since the root node is always 0, we need to slice and subtract 1 to get the sample index.
    samples_indices = sort_idx[1:] - 1  # [N]

    # The last node is the accumulation, which is always 0, so we drop it.
    num_live_points = crossed_edges_sorted[:-1]  # [N]
    return SampleLivePointCounts(
        samples_indices=samples_indices,
        num_live_points=num_live_points
    )
[docs]
def count_crossed_edges_less_fast(S: SampleTreeGraph) -> SampleLivePointCounts:
    """
    Reference implementation of the crossed-edge count using plain jnp.argsort
    and jnp.cumsum (no scan-based cumulative op).
    """
    num_edges = S.sender_node_idx.size
    # N+1 node likelihoods, with the root pinned at -inf.
    node_log_L = jnp.concatenate([-jnp.inf * jnp.ones(1), S.log_L])  # [N+1]
    order = jnp.argsort(node_log_L)  # [N+1]
    # Out-degree of every node, i.e. how many edges it sends. At least one node
    # has zero out-degree, but which one is data-dependent.
    children_count = jnp.bincount(S.sender_node_idx, length=num_edges + 1)  # [N+1]
    # Sweep nodes in likelihood order: each visited node closes one edge and
    # opens out-degree new ones; the running total is the crossing count.
    crossings = 1 + jnp.cumsum(children_count[order] - 1)
    return SampleLivePointCounts(
        # Root is node 0, so shift the sorted node ids down by one to get sample ids.
        samples_indices=order[1:] - 1,  # [N]
        # The final entry is the fully-accumulated value (always 0); drop it.
        num_live_points=crossings[:-1]  # [N]
    )
[docs]
def count_intervals_naive(S: SampleTreeGraph) -> SampleLivePointCounts:
    """
    Brute-force O(N^2) live point computation: for each contour, directly count
    the samples whose constraint interval straddles it.
    """
    n = S.sender_node_idx.size
    node_log_L = jnp.concatenate([-jnp.inf * jnp.ones(1), S.log_L])  # [N+1]
    sample_log_L = node_log_L[1:]  # [N]
    # Each sample's constraint is the likelihood of its sender node.
    constraint_log_L = node_log_L[S.sender_node_idx]  # [N]
    order = jnp.argsort(sample_log_L)  # [N]
    sorted_constraints = constraint_log_L[order]
    sorted_samples = sample_log_L[order]
    alive = jnp.ones(n, dtype=jnp.bool_)
    counts = jnp.zeros(n, dtype=jnp.int32)
    contour = node_log_L[0]  # start at the root contour (-inf)
    for step in range(n):
        # A sample is live at this contour if it was drawn inside it, exceeds it,
        # and has not yet been consumed.
        straddles = (sorted_constraints <= contour) & (sorted_samples > contour) & alive[order]  # [N]
        counts = counts.at[step].set(jnp.sum(straddles))
        # Advance the contour to the likelihood of the next consumed sample.
        contour = node_log_L[order[step] + 1]
        alive = alive.at[order[step]].set(False)
    return SampleLivePointCounts(
        samples_indices=order,
        num_live_points=counts
    )
[docs]
def fast_perfect_live_point_computation_jax(log_L_constraints: jnp.ndarray, log_L_samples: jnp.ndarray,
                                            num_samples: Union[jnp.ndarray, None] = None):
    """
    Compute the number of live points per sample from flat arrays of sampling
    constraints and sample likelihoods.

    Args:
        log_L_constraints: [N] likelihood constraint each sample was drawn within.
        log_L_samples: [N] likelihood of each sample.
        num_samples: optional number of valid samples; counts at indices >= num_samples
            are zeroed.

    Returns:
        num_live_points: [N] live point counts in sorted order (integer dtype).
        sort_idx: [N] permutation sorting by (log_L_samples, then log_L_constraints).
    """
    # Sort primarily by sample likelihood, breaking ties with the constraint.
    sort_idx = jnp.lexsort((log_L_constraints, log_L_samples))
    log_L_samples = log_L_samples[sort_idx]
    log_L_constraints = log_L_constraints[sort_idx]
    # The initial contour is the lowest constraint.
    log_L_contour = log_L_constraints[0]
    search_contours = jnp.concatenate([log_L_contour[None], log_L_samples], axis=0)
    # Each sample shrinks from the previous sample's likelihood (or the initial
    # contour for the first sample).
    contour_map_idx = jnp.searchsorted(search_contours, log_L_samples, side='left') - 1
    log_L_contours = search_contours[contour_map_idx]
    diag_i = jnp.arange(log_L_samples.size)
    # Live at a contour: constraint <= contour (right bound) and likelihood >
    # contour (left bound); count the window between the two bounds.
    right_most_idx = jnp.searchsorted(jnp.sort(log_L_constraints), log_L_contours, side='right') - 1
    left_most_idx = jnp.maximum(diag_i, jnp.searchsorted(log_L_samples, log_L_contours, side='right') - 1)
    num_live_points = jnp.maximum(0, right_most_idx - left_most_idx + 1)
    if num_samples is not None:
        # Zero the padding entries beyond num_samples. Use an integer zero so the
        # result keeps its integer dtype -- previously a float zero
        # (jnp.asarray(0., log_L_samples.dtype)) silently promoted the whole
        # count array to float.
        empty_mask = jnp.greater_equal(jnp.arange(log_L_samples.size), num_samples)
        num_live_points = jnp.where(empty_mask, jnp.zeros_like(num_live_points), num_live_points)
    return num_live_points, sort_idx
[docs]
def compute_num_live_points_from_unit_threads(log_L_constraints: FloatArray, log_L_samples: FloatArray,
                                              num_samples: Optional[IntArray] = None,
                                              sorted_collection: bool = True) \
        -> Union[FloatArray, Tuple[FloatArray, IntArray]]:
    """
    Compute the number of live points of shrinkage distribution, from an arbitrary list of samples with
    corresponding sampling constraints.

    Args:
        log_L_constraints: [N] likelihood constraint that sample was uniformly sampled within
        log_L_samples: [N] likelihood of the sample
        num_samples: optional number of valid samples; counts at indices >= num_samples are zeroed.
        sorted_collection: bool, whether the sample collection was already sorted.

    Returns:
        if sorted_collection is true:
            num_live_points for shrinkage distribution
        otherwise:
            num_live_points for shrinkage distribution, and sort indicies
    """
    # Delegate to the fast searchsorted-based implementation.
    num_live_points, sort_idx = fast_perfect_live_point_computation_jax(log_L_constraints=log_L_constraints,
                                                                        log_L_samples=log_L_samples,
                                                                        num_samples=num_samples)
    # If the caller's collection was unsorted, it needs the permutation to
    # relate the counts back to its own ordering.
    if not sorted_collection:
        return num_live_points, sort_idx
    return num_live_points
[docs]
def count_old(S: SampleTreeGraph) -> SampleLivePointCounts:
    """
    Legacy live-point computation kept for comparison; roughly 3x slower than
    the new crossed-edges method.
    """
    # Node likelihoods with the root pinned at -inf, so constraints can be
    # looked up by sender node index.
    node_log_L = jnp.concatenate([-jnp.inf * jnp.ones(1), S.log_L])  # [N+1]
    counts, order = compute_num_live_points_from_unit_threads(
        log_L_constraints=node_log_L[S.sender_node_idx],
        log_L_samples=S.log_L,
        num_samples=None,
        sorted_collection=False
    )
    return SampleLivePointCounts(
        samples_indices=order,
        num_live_points=counts
    )
[docs]
def plot_tree(S: SampleTreeGraph):
    r"""
    Plots the tree where x-position is log_L and y-position is a unique integer for each branch such that no edges
    cross. The y-position should be the same as it's sender's y-position, unless that would make an edge cross,
    in which case, an addition should be made so that no edges cross.

    e.g.

    For the tree:

    S = SampleTree(
        sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
        log_L=jnp.asarray([-jnp.inf, 1, 2, 3, 4, 5, 6])
    )

    The root node connects to nodes 1, 2, 3, and then it's straight lines from 1, 2, 3 to 4, 5, 6.

    The ASCII plot is:

    0 -- 1 -- 4
     \-- 2 -- 5
      \- 3 -- 6

    If we add a branch from 2 to 7 then we get:

    0 --- 1 -- 4
     \--- 2 -- 5
      \    \ -- 7
       \-- 3 -- 6

    See how 2-7 edge doesn't cross any other edges.

    Note the in-degree of each node is 1, except the root node which has in-degree 0.
    The out-degree can be anything.

    Args:
        S: SampleTree
    """
    import networkx as nx
    import pylab as plt

    # Initialize graph
    G = nx.DiGraph()

    # Add edges and nodes to the graph.
    # NOTE(review): S.log_L is indexed at idx+1 here, which implies log_L in this
    # function includes the root entry (length N+1, as in the docstring example);
    # confirm against the SampleTreeGraph convention where log_L is [N].
    for idx, sender in enumerate(S.sender_node_idx):
        node = idx + 1  # nodes are 1-based; 0 is the root
        sender = int(sender)
        G.add_node(node, x=float(S.log_L[node]), sender=sender)
        G.add_edge(sender, node)

    # Place the root just left of the smallest log_L so all edges point right.
    G.nodes[0]['x'] = float(jnp.min(S.log_L) - 1.)
    G.nodes[0]['y'] = 0
    G.nodes[0]['sender'] = -1

    out_degree = jnp.bincount(S.sender_node_idx, length=S.log_L.size)  # [N+1]

    # Dictionary to store the positions of each node
    visited = []
    branch = 0
    for node in nx.traversal.dfs_tree(G, 0):
        if node == 0:
            continue
        # print(G.nodes[node]['sender'], node, visited.count(G.nodes[node]['sender']))
        # Default: inherit the sender's y-position (stay on the same branch).
        G.nodes[node]['y'] = G.nodes[G.nodes[node]['sender']]['y']  # + visited.count(G.nodes[node]['sender'])
        # if out_degree[G.nodes[node]['sender']] > 1:
        # If the sender already has visited children, open a new branch row so
        # this edge does not cross the earlier sibling edges.
        if visited.count(G.nodes[node]['sender']) > 0:
            branch += visited.count(G.nodes[node]['sender'])
            G.nodes[node]['y'] = branch
        visited.append(G.nodes[node]['sender'])

    pos = dict((node, (G.nodes[node]['x'], G.nodes[node]['y'])) for node in G.nodes)

    # Draw the graph
    nx.draw(G, pos, with_labels=True, node_size=700, node_color='lightblue', linewidths=0.5, font_size=10, arrows=True)

    # Display the plot
    plt.show()
[docs]
def concatenate_sample_trees(trees: List[SampleTreeGraph],
                             num_samples: Optional[List[IntArray]] = None) -> SampleTreeGraph:
    """
    Concatenates a list of SampleTreeGraphs into a single SampleTreeGraph.

    The root nodes of each tree must be the same, as this is equivalent to
    attaching every tree as a branch of the shared root node, 0 in each tree.

    Args:
        trees: list of SampleTreeGraph's
        num_samples: optional per-tree sample counts used as node-index offsets;
            defaults to each tree's full size.

    Returns:
        a single tree
    """
    # Node indices must be renumbered so each tree occupies its own range:
    #   0 -- 1 -- 2
    #   0 -- 1 -- 2
    # becomes
    #   0 -- 1 -- 2
    #    \-- 3 -- 4
    if num_samples is None:
        num_samples = [tree.sender_node_idx.size for tree in trees]
    if len(num_samples) != len(trees):
        raise ValueError("num_samples must be same length as trees.")
    shifted_senders = []
    shift = 0
    for count, tree in zip(num_samples, trees):
        # Shift all non-root senders; edges from the root (sender 0) stay at 0.
        is_non_root = tree.sender_node_idx.astype(jnp.bool_)
        shifted_senders.append(
            jnp.where(is_non_root, tree.sender_node_idx + shift, tree.sender_node_idx)
        )
        shift += count
    return SampleTreeGraph(
        sender_node_idx=jnp.concatenate(shifted_senders),
        log_L=jnp.concatenate([tree.log_L for tree in trees])
    )
[docs]
def unbatch_state(batched_state: StaticStandardNestedSamplerState) -> StaticStandardNestedSamplerState:
    """
    Remove the batch dimension from the state. The returned samples will be sorted by log_L,
    so assumes,

    log_L[i]==+inf ==> i is not a sample

    Args:
        batched_state: the state with batch dimension

    Returns:
        the state without batch dimension
    """
    if len(batched_state.sample_collection.log_L.shape) == 1:
        # Already unbatched
        return batched_state
    if batched_state.sample_collection.log_L.shape[0] == 1:
        # Remove batch dimension is all that's needed
        return remove_chunk_dim(batched_state)
    key = batched_state.key[0]  # Take first key
    next_sample_idx = jnp.sum(batched_state.next_sample_idx)  # Next insert will be sum
    # Shifts are the cumulative sum of the number of samples per batch dimension.
    # NOTE(review): each shift advances by the full chunk size (shape[1]) rather
    # than by next_sample_idx, so padding slots keep their positions until the
    # final sort below; the commented alternative would pack samples tightly.
    shifts = [0]
    for i in range(len(batched_state.next_sample_idx) - 1):
        shifts.append(shifts[-1] + batched_state.sample_collection.log_L.shape[1])
    shifts = jnp.asarray(shifts, int_type)
    # shifts = jnp.concatenate([jnp.asarray([0], int_type), jnp.cumsum(batched_state.next_sample_idx[:-1])])
    # Shift non-root senders into their batch's index range; root senders (0) stay 0.
    sender_node_idx = jnp.where(
        batched_state.sample_collection.sender_node_idx.astype(jnp.bool_),
        batched_state.sample_collection.sender_node_idx + shifts[:, None],
        batched_state.sample_collection.sender_node_idx
    )
    # Front indices are shifted like senders
    front_idx = remove_chunk_dim(
        batched_state.front_idx + shifts[:, None]
    )
    unbatched_state = StaticStandardNestedSamplerState(
        key=key,
        next_sample_idx=next_sample_idx,
        sample_collection=remove_chunk_dim(
            batched_state.sample_collection._replace(sender_node_idx=sender_node_idx)
        ),
        front_idx=front_idx
    )
    # Some non-samples will interleave samples, so we sort by log_L, carefully adjusting sender_node_idx to match.
    # Non-samples have log_L == +inf (see docstring), so they sort to the end.
    sort_idx = jnp.argsort(unbatched_state.sample_collection.log_L)
    inverse_idx = jnp.argsort(sort_idx)  # maps old positions to new positions
    # Shift the front_idx and sender idx
    front_idx = inverse_idx[unbatched_state.front_idx]
    # Senders are 1-based node ids (0 is root): convert to 0-based sample index,
    # remap through the permutation, then convert back; root senders stay 0.
    sender_node_idx = jnp.where(
        unbatched_state.sample_collection.sender_node_idx.astype(jnp.bool_),
        inverse_idx[unbatched_state.sample_collection.sender_node_idx - 1] + 1,
        jnp.zeros_like(unbatched_state.sample_collection.sender_node_idx)
    )
    unbatched_state = unbatched_state._replace(
        sample_collection=unbatched_state.sample_collection._replace(sender_node_idx=sender_node_idx)
    )
    # Rearrange the samples
    unbatched_state = unbatched_state._replace(
        sample_collection=tree_map(lambda x: x[sort_idx], unbatched_state.sample_collection),
        front_idx=front_idx
    )
    return unbatched_state