tree_structure
jaxns.internals.tree_structure
Module Contents
- class SampleTreeGraph
Bases: NamedTuple
Represents the tree structure of samples. There are N+1 nodes and N edges. Every node except the root has exactly 1 sender, and every node has zero or more receivers. The root is always node 0.
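As a minimal sketch of the layout described above, the following uses a hypothetical `ToySampleTree` built on plain Python lists rather than the library's array types. The field names `sender_node_idx` and `log_L` are taken from the `plot_tree` example on this page; the helper `receivers_of` is an illustrative name, not part of jaxns. Entry `i` of `sender_node_idx` is read as the sender of node `i + 1`, which is consistent with that example.

```python
from typing import List, NamedTuple


class ToySampleTree(NamedTuple):
    """Hypothetical stand-in for SampleTreeGraph, using plain lists."""
    sender_node_idx: List[int]  # [N] sender of nodes 1..N (entry i belongs to node i + 1)
    log_L: List[float]          # [N + 1] log-likelihood of every node, root first


def receivers_of(tree: ToySampleTree) -> List[List[int]]:
    """Invert the sender list: for each node, list the nodes it sends to."""
    num_nodes = len(tree.sender_node_idx) + 1
    receivers: List[List[int]] = [[] for _ in range(num_nodes)]
    for edge, sender in enumerate(tree.sender_node_idx):
        receivers[sender].append(edge + 1)  # edge i ends at node i + 1
    return receivers


toy_tree = ToySampleTree(
    sender_node_idx=[0, 0, 0, 1, 2, 3],
    log_L=[float("-inf"), 1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
)
toy_receivers = receivers_of(toy_tree)
```

Here the root (node 0) has three receivers, nodes 1 to 3 have one each, and the leaves have none, so there are 7 nodes and 6 edges.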
- count_crossed_edges(sample_tree, num_samples=None)
- Parameters:
sample_tree (SampleTreeGraph) –
num_samples (Optional[jaxns.internals.types.IntArray]) –
- Return type:
- count_crossed_edges_less_fast(S)
- Parameters:
S (SampleTreeGraph) –
- Return type:
- count_intervals_naive(S)
- Parameters:
S (SampleTreeGraph) –
- Return type:
- fast_perfect_live_point_computation_jax(log_L_constraints, log_L_samples, num_samples=None)
- Parameters:
log_L_constraints (jax.numpy.ndarray) –
log_L_samples (jax.numpy.ndarray) –
num_samples (Union[jax.numpy.ndarray, None]) –
- compute_num_live_points_from_unit_threads(log_L_constraints, log_L_samples, num_samples=None, sorted_collection=True)
Compute the number of live points of the shrinkage distribution from an arbitrary list of samples and their corresponding sampling constraints.
- Parameters:
log_L_constraints (jaxns.internals.types.FloatArray) – [N] likelihood constraint that sample was uniformly sampled within
log_L_samples (jaxns.internals.types.FloatArray) – [N] likelihood of the sample
sorted_collection (bool) – bool, whether the sample collection was already sorted.
num_samples (jaxns.internals.types.IntArray) –
- Returns:
if sorted_collection is true: num_live_points for the shrinkage distribution
otherwise: num_live_points for the shrinkage distribution, and the sort indices
- Return type:
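To illustrate what a live-point count means for a list of samples and constraints, here is a naive O(N^2) sketch in plain Python. It is not the library's implementation. It assumes one common convention: a sample counts as live at a contour when its constraint lies strictly below the contour and its own likelihood is at or above it. The name `naive_num_live_points` is hypothetical.

```python
from typing import List, Tuple


def naive_num_live_points(
    log_L_constraints: List[float],
    log_L_samples: List[float],
) -> Tuple[List[int], List[int]]:
    """Count live points at each sample's contour, in order of increasing log_L.

    Assumed convention: sample j is live at contour c when
    log_L_constraints[j] < c <= log_L_samples[j].
    """
    # Indices that sort the samples by likelihood.
    sort_idx = sorted(range(len(log_L_samples)), key=lambda i: log_L_samples[i])
    counts = []
    for i in sort_idx:
        contour = log_L_samples[i]
        live = sum(
            1
            for constraint, sample in zip(log_L_constraints, log_L_samples)
            if constraint < contour <= sample
        )
        counts.append(live)
    return counts, sort_idx


neg_inf = float("-inf")
# Three samples drawn from the prior, plus one drawn above the contour at 1.0.
example_counts, example_sort_idx = naive_num_live_points(
    log_L_constraints=[neg_inf, neg_inf, neg_inf, 1.0],
    log_L_samples=[1.0, 2.0, 3.0, 2.5],
)
```

In this example the fourth sample replaces the point that dies at log_L = 1.0, so the count stays at 3 for one more contour before falling as the remaining points are consumed.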
- count_old(S)
- Parameters:
S (SampleTreeGraph) –
- Return type:
- plot_tree(S)
Plots the tree where the x-position is log_L and the y-position is a unique integer for each branch, chosen such that no edges cross. A node’s y-position is the same as its sender’s y-position unless that would make an edge cross, in which case the y-position is offset so that no edges cross.
e.g.
For the tree:
S = SampleTreeGraph(
    sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
    log_L=jnp.asarray([-jnp.inf, 1, 2, 3, 4, 5, 6])
)
The root node connects to nodes 1, 2, 3, and then it’s straight lines from 1, 2, 3 to 4, 5, 6.
The ASCII plot is:
0 -- 1 -- 4
  \-- 2 -- 5
  \-- 3 -- 6
If we add a branch from 2 to 7 then we get:
0 --- 1 -- 4
  \-- 2 -- 5
    \-- 7
  \-- 3 -- 6
Note that the 2-7 edge doesn’t cross any other edge.
Note the in-degree of each node is 1, except the root node which has in-degree 0. The out-degree can be anything.
- Parameters:
S (SampleTreeGraph) – SampleTree
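One way to arrive at the rows shown in the ASCII plots is a depth-first walk in which the first receiver of a node inherits its sender's row and every later receiver takes the next unused row. The sketch below is an assumption about how such a layout could be computed, not a description of what `plot_tree` does internally. The function name `assign_rows` is hypothetical, and the input is a plain list of sender indices.

```python
from typing import List


def assign_rows(sender_node_idx: List[int]) -> List[int]:
    """Assign an integer row (y-position) to every node of a sender-indexed tree.

    The first receiver of a node shares its sender's row; each further
    receiver opens a fresh row once the earlier subtrees have been placed.
    """
    num_nodes = len(sender_node_idx) + 1
    receivers: List[List[int]] = [[] for _ in range(num_nodes)]
    for edge, sender in enumerate(sender_node_idx):
        receivers[sender].append(edge + 1)

    rows = [0] * num_nodes
    next_free_row = 1  # row 0 is taken by the root
    # Stack entries: (node, sender, inherits_sender_row).
    stack = [(0, 0, True)]
    while stack:
        node, sender, inherits = stack.pop()
        if node != 0:
            if inherits:
                rows[node] = rows[sender]
            else:
                rows[node] = next_free_row
                next_free_row += 1
        # Push in reverse so the first receiver is visited first.
        for position in range(len(receivers[node]) - 1, -1, -1):
            stack.append((receivers[node][position], node, position == 0))
    return rows


# The second example above: node 7 branches off node 2.
rows_with_branch = assign_rows([0, 0, 0, 1, 2, 3, 2])
```

For the second example this places nodes 0, 1, 4 on row 0, nodes 2 and 5 on row 1, node 7 on row 2, and nodes 3 and 6 on row 3, matching the plot.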
- concatenate_sample_trees(trees, num_samples=None)
Concatenates a list of SampleTreeGraphs into a single SampleTreeGraph.
The root node of every tree must be the same, because concatenation is equivalent to attaching each tree as a branch of a single shared root node (node 0 in each tree).
- Parameters:
trees (List[SampleTreeGraph]) – list of SampleTreeGraphs
num_samples (Optional[List[jaxns.internals.types.IntArray]]) –
- Returns:
a single tree
- Return type:
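The index bookkeeping behind this kind of concatenation can be sketched with plain lists. This is an illustrative sketch only: it ignores `num_samples`, represents each tree as a hypothetical `(sender_node_idx, log_L)` pair, and the name `concatenate_toy_trees` is not part of jaxns. Each tree's non-root nodes are shifted by the number of non-root nodes already merged, while senders equal to 0 keep pointing at the shared root.

```python
from typing import List, Tuple

ToyTree = Tuple[List[int], List[float]]  # (sender_node_idx, log_L)


def concatenate_toy_trees(trees: List[ToyTree]) -> ToyTree:
    """Merge trees that share a root into one tree with a single root at node 0."""
    root_log_L = trees[0][1][0]
    merged_senders: List[int] = []
    merged_log_L: List[float] = [root_log_L]
    offset = 0  # number of non-root nodes merged so far
    for sender_node_idx, log_L in trees:
        if log_L[0] != root_log_L:
            raise ValueError("All trees must share the same root node.")
        for sender in sender_node_idx:
            # The root keeps index 0; every other node is shifted by the offset.
            merged_senders.append(0 if sender == 0 else sender + offset)
        merged_log_L.extend(log_L[1:])
        offset += len(sender_node_idx)
    return merged_senders, merged_log_L


neg_inf = float("-inf")
tree_a: ToyTree = ([0, 1], [neg_inf, 1.0, 2.0])
tree_b: ToyTree = ([0, 0, 2], [neg_inf, 0.5, 1.5, 3.0])
merged_senders, merged_log_L = concatenate_toy_trees([tree_a, tree_b])
```

After merging, the node with log_L 3.0 (node 5) still points at the node with log_L 1.5 (now node 4), so the branch structure of `tree_b` is preserved.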
- unbatch_state(batched_state)
Remove the batch dimension from the state. The returned samples are sorted by log_L, which assumes that log_L[i] == +inf implies i is not a sample.
- Parameters:
batched_state (jaxns.internals.types.StaticStandardNestedSamplerState) – the state with batch dimension
- Returns:
the state without batch dimension
- Return type:
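The sorting assumption can be illustrated with a small plain-Python sketch. It only covers a nested list of log_L values, not the real state object and its other fields, and the helper name `flatten_and_sort_log_L` is hypothetical. Because unused slots hold +inf, an ascending sort pushes them to the end and the valid samples occupy a contiguous prefix.

```python
import math
from typing import List, Tuple


def flatten_and_sort_log_L(batched_log_L: List[List[float]]) -> Tuple[List[float], int]:
    """Flatten a [batch, n] nested list of log_L and sort it ascending.

    Returns the sorted values and the number of entries that are real samples,
    assuming +inf marks an unused slot.
    """
    flat = [value for batch in batched_log_L for value in batch]
    flat.sort()
    num_valid = sum(1 for value in flat if value != math.inf)
    return flat, num_valid


sorted_log_L, num_valid = flatten_and_sort_log_L([[0.5, math.inf], [0.1, 2.0]])
```

The first `num_valid` entries of `sorted_log_L` are then the actual samples in increasing log_L.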