tree_structure

jaxns.internals.tree_structure

Module Contents

class SampleTreeGraph

Bases: NamedTuple

Represents the tree structure of samples. There are N+1 nodes and N edges. Every node except the root has exactly one sender, and each node has zero or more receivers. The root is always node 0.

sender_node_idx: jaxns.internals.types.IntArray
log_L: jaxns.internals.types.MeasureType
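To make the node/edge layout concrete, the sketch below uses a stand-in NamedTuple with plain Python lists in place of JAX arrays. It assumes, following the plot_tree example in this module's documentation, that sender_node_idx[i] is the sender of node i + 1 and that log_L holds one entry per node with the root first (set to -inf). The names `Tree` and `receivers` are hypothetical and are not part of jaxns.

```python
import math
from typing import List, NamedTuple


class Tree(NamedTuple):
    """Hypothetical stand-in for SampleTreeGraph using plain lists."""
    sender_node_idx: List[int]  # [N]: sender of node i + 1 (assumed convention)
    log_L: List[float]          # [N + 1]: root first (assumed convention)


def receivers(tree: Tree) -> List[List[int]]:
    """Return, for each of the N + 1 nodes, the nodes it sends to."""
    num_nodes = len(tree.sender_node_idx) + 1
    out: List[List[int]] = [[] for _ in range(num_nodes)]
    for i, sender in enumerate(tree.sender_node_idx):
        out[sender].append(i + 1)  # non-root node i + 1 has exactly one sender
    return out


tree = Tree(sender_node_idx=[0, 0, 0, 1, 2, 3],
            log_L=[-math.inf, 1, 2, 3, 4, 5, 6])
print(receivers(tree))
# [[1, 2, 3], [4], [5], [6], [], [], []]
```

Each non-root node appears in exactly one receiver list, so the receiver lists together contain exactly N edges.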
class SampleLivePointCounts

Bases: NamedTuple

samples_indices: jaxns.internals.types.IntArray
num_live_points: jaxns.internals.types.IntArray
count_crossed_edges(sample_tree, num_samples=None)
Parameters:
  • sample_tree (SampleTreeGraph)

  • num_samples (Optional[jaxns.internals.types.IntArray])

Return type:

SampleLivePointCounts
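This entry has no description. One plausible reading of "crossed edges", offered as an assumption rather than a description of the jaxns code, is that the number of live points at a node equals the number of sender-to-receiver edges whose log-likelihood span (log_L[sender], log_L[receiver]] contains that node's log_L. The sketch below computes this in one sweep over the nodes in order of increasing log_L: the running count starts at the root's out-degree, and each visited node removes the edge that ends at it and adds its own outgoing edges. It assumes sender_node_idx[i] is the sender of node i + 1, log_L[0] is the root, non-root log_L values are distinct, and every sender has a lower log_L than its receivers. The function name is hypothetical.

```python
import math
from typing import List


def live_points_by_sweep(sender_node_idx: List[int], log_L: List[float]) -> List[int]:
    """Hypothetical O(N log N) crossed-edge count for each non-root node.

    Entry i of the result is the count for node i + 1.
    """
    num_nodes = len(log_L)
    out_degree = [0] * num_nodes
    for sender in sender_node_idx:
        out_degree[sender] += 1

    # Visit non-root nodes in order of increasing log-likelihood.
    order = sorted(range(1, num_nodes), key=lambda node: log_L[node])
    crossing = out_degree[0]  # edges leaving the root start at -inf
    counts = [0] * len(sender_node_idx)
    for node in order:
        counts[node - 1] = crossing  # edges whose span contains log_L[node]
        # Past this node: its incoming edge ends, its outgoing edges begin.
        crossing += out_degree[node] - 1
    return counts


print(live_points_by_sweep([0, 0, 0, 1, 2, 3], [-math.inf, 1, 2, 3, 4, 5, 6]))
# [3, 3, 3, 3, 2, 1]
```

The sweep avoids comparing every node against every edge; the price is the ordering assumption, since a receiver visited before its sender would be counted incorrectly.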

count_crossed_edges_less_fast(S)
Parameters:

S (SampleTreeGraph)

Return type:

SampleLivePointCounts

count_intervals_naive(S)
Parameters:

S (SampleTreeGraph)

Return type:

SampleLivePointCounts

fast_perfect_live_point_computation_jax(log_L_constraints, log_L_samples, num_samples=None)
Parameters:
  • log_L_constraints (jax.numpy.ndarray)

  • log_L_samples (jax.numpy.ndarray)

  • num_samples (Union[jax.numpy.ndarray, None])

compute_num_live_points_from_unit_threads(log_L_constraints, log_L_samples, num_samples=None, sorted_collection=True)

Compute the number of live points of the shrinkage distribution from an arbitrary list of samples and their corresponding sampling constraints.

Parameters:
  • log_L_constraints (jaxns.internals.types.FloatArray) – [N] log-likelihood constraint within which each sample was uniformly sampled

  • log_L_samples (jaxns.internals.types.FloatArray) – [N] log-likelihood of each sample

  • num_samples (jaxns.internals.types.IntArray)

  • sorted_collection (bool) – whether the sample collection is already sorted

Returns:

If sorted_collection is true, num_live_points for the shrinkage distribution; otherwise, num_live_points for the shrinkage distribution together with the sort indices.
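As a brute-force illustration of the quantity described above, the sketch below assumes (this is not taken from the jaxns source) that sample i is live over the half-open interval (log_L_constraints[i], log_L_samples[i]], so the live-point count at a sample is the number of such intervals containing that sample's log-likelihood. It is O(N²) plain Python for clarity. The function name, the tie-breaking rule (sort by sample log-likelihood, then by constraint), and ignoring num_samples are all hypothetical choices.

```python
from typing import List, Tuple, Union


def num_live_points_naive(
    log_L_constraints: List[float],
    log_L_samples: List[float],
    sorted_collection: bool = True,
) -> Union[List[int], Tuple[List[int], List[int]]]:
    """Hypothetical O(N^2) live-point count for the shrinkage distribution.

    Returns counts in sorted order, plus the sort indices when the input was
    not already sorted (mirroring the documented return convention).
    """
    n = len(log_L_samples)
    if sorted_collection:
        sort_idx = list(range(n))
    else:
        # Sort by sample log-likelihood, breaking ties on the constraint.
        sort_idx = sorted(range(n), key=lambda i: (log_L_samples[i], log_L_constraints[i]))
    counts = []
    for j in sort_idx:
        level = log_L_samples[j]
        # Count samples whose live interval (constraint, sample] contains level.
        counts.append(sum(1 for i in range(n)
                          if log_L_constraints[i] < level <= log_L_samples[i]))
    return counts if sorted_collection else (counts, sort_idx)


inf = float("inf")
# Three points drawn from the prior, plus one replacement drawn above log_L = 1.
print(num_live_points_naive([-inf, -inf, -inf, 1.0], [1.0, 2.0, 3.0, 5.0]))
# [3, 3, 2, 1]
```

The replacement sample keeps the count at 3 for the second sample; after that no new points are added, so the count falls by one at each subsequent sample.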

count_old(S)
Parameters:

S (SampleTreeGraph)

Return type:

SampleLivePointCounts

plot_tree(S)

Plots the tree, where the x-position is log_L and the y-position is an integer row, unique to each branch, chosen so that no edges cross. A node's y-position is the same as its sender's y-position, unless that would make an edge cross, in which case a new row is added so that no edges cross.

For example, consider the tree:

    S = SampleTreeGraph(
        sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
        log_L=jnp.asarray([-jnp.inf, 1, 2, 3, 4, 5, 6])
    )

The root node connects to nodes 1, 2 and 3, which then continue in straight lines to nodes 4, 5 and 6, respectively.

The ASCII plot is:

    0 -- 1 -- 4
    |--- 2 -- 5
    |--- 3 -- 6

If we add a branch from 2 to 7, we get:

    0 -- 1 -- 4
    |--- 2 -- 5
    |    |--- 7
    |--- 3 -- 6

Note that the 2-7 edge does not cross any other edge.

Note that the in-degree of each node is 1, except for the root node, which has in-degree 0. The out-degree can be anything.

Parameters:

S (SampleTreeGraph) – the sample tree to plot
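One way to obtain the row rule described above, sketched here as an assumption rather than the jaxns implementation, is a depth-first traversal in which the first receiver of a node inherits its sender's row and every later receiver opens a fresh row. Each subtree then occupies a contiguous block of rows, so, provided log_L increases along every edge, no edge crosses another. The sketch orders receivers by node index, assumes sender_node_idx[i] is the sender of node i + 1 (as in the example above), and computes rows only, without drawing. The function name is hypothetical.

```python
from typing import List


def assign_rows(sender_node_idx: List[int]) -> List[int]:
    """Hypothetical row (y-position) assignment so that no tree edges cross.

    Returns one row index per node (N + 1 entries, root first).
    """
    num_nodes = len(sender_node_idx) + 1
    children: List[List[int]] = [[] for _ in range(num_nodes)]
    for i, sender in enumerate(sender_node_idx):
        children[sender].append(i + 1)

    rows = [0] * num_nodes
    next_row = 1  # first row not yet used by any branch
    # Preorder traversal with an explicit stack of
    # (node, sender's row, whether node is its sender's first receiver).
    stack = [(0, 0, True)]
    while stack:
        node, sender_row, is_first = stack.pop()
        if is_first:
            rows[node] = sender_row  # continue the sender's branch
        else:
            rows[node] = next_row  # new row below all earlier subtrees
            next_row += 1
        # Push receivers in reverse so the first receiver is visited first.
        for k in range(len(children[node]) - 1, -1, -1):
            stack.append((children[node][k], rows[node], k == 0))
    return rows


print(assign_rows([0, 0, 0, 1, 2, 3]))     # [0, 0, 1, 2, 0, 1, 2]
print(assign_rows([0, 0, 0, 1, 2, 3, 2]))  # [0, 0, 1, 3, 0, 1, 3, 2]
```

The second call reproduces the layout of the second ASCII plot: node 7 takes row 2, directly under node 5, and the 3-6 branch moves down to row 3.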

concatenate_sample_trees(trees, num_samples=None)

Concatenates a list of SampleTreeGraphs into a single SampleTreeGraph.

The root nodes of all trees must be the same (node 0 in each tree), since concatenation is equivalent to adding each tree as a branch from that common root node.

Parameters:
  • trees (List[SampleTreeGraph]) – list of SampleTreeGraphs

  • num_samples (Optional[List[jaxns.internals.types.IntArray]])

Returns:

a single tree

Return type:

SampleTreeGraph
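A minimal sketch of the re-indexing this implies, using plain lists and hypothetical names (`Tree`, `concatenate_trees`). It assumes, as in the plot_tree example, that sender_node_idx[i] is the sender of node i + 1 and that log_L[0] is the shared root. The num_samples argument, whose semantics are not documented here, is ignored. Each tree's non-root nodes are shifted past the non-root nodes of the trees before it, while edges from the root keep pointing at node 0.

```python
import math
from typing import List, NamedTuple


class Tree(NamedTuple):
    """Hypothetical stand-in for SampleTreeGraph using plain lists."""
    sender_node_idx: List[int]  # [N]: sender of node i + 1 (assumed)
    log_L: List[float]          # [N + 1]: log_L[0] is the root (assumed)


def concatenate_trees(trees: List[Tree]) -> Tree:
    """Attach every tree as a branch of one shared root (node 0)."""
    sender_node_idx: List[int] = []
    log_L: List[float] = [trees[0].log_L[0]]  # shared root
    offset = 0  # number of non-root nodes already placed
    for tree in trees:
        for sender in tree.sender_node_idx:
            # The root stays 0; other node indices move past earlier trees.
            sender_node_idx.append(0 if sender == 0 else sender + offset)
        log_L.extend(tree.log_L[1:])
        offset += len(tree.sender_node_idx)
    return Tree(sender_node_idx, log_L)


a = Tree([0, 1], [-math.inf, 1.0, 2.0])  # 0 -> 1 -> 2
b = Tree([0, 0], [-math.inf, 1.5, 2.5])  # 0 -> 1, 0 -> 2
print(concatenate_trees([a, b]))
# Tree(sender_node_idx=[0, 1, 0, 0], log_L=[-inf, 1.0, 2.0, 1.5, 2.5])
```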

unbatch_state(batched_state)

Remove the batch dimension from the state. The returned samples are sorted by log_L, which relies on the assumption that:

log_L[i] == +inf implies that i is not a sample

Parameters:

batched_state (jaxns.internals.types.StaticStandardNestedSamplerState) – the state with batch dimension

Returns:

the state without batch dimension

Return type:

jaxns.internals.types.StaticStandardNestedSamplerState
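The +inf assumption matters because sorting pushes padded slots to the end. The sketch below illustrates only that step, on plain lists: it flattens a [B, M] batch of log-likelihood values, sorts by log_L, and counts the finite entries as real samples. The function name is hypothetical, and the real state carries more fields than log_L, which would presumably be reordered with the same indices.

```python
import math
from typing import List, Tuple


def unbatch_log_L(batched_log_L: List[List[float]]) -> Tuple[List[float], List[int], int]:
    """Hypothetical sketch: flatten a [B, M] batch, sort by log_L, count samples.

    Slots holding +inf are assumed to be padding, not samples; sorting moves
    them to the end, so the first num_samples entries are the real samples.
    """
    flat = [value for row in batched_log_L for value in row]
    sort_idx = sorted(range(len(flat)), key=lambda i: flat[i])
    sorted_log_L = [flat[i] for i in sort_idx]
    num_samples = sum(1 for value in flat if value != math.inf)
    return sorted_log_L, sort_idx, num_samples


log_L, idx, n = unbatch_log_L([[0.5, 2.0, math.inf], [1.0, math.inf, math.inf]])
print(log_L, idx, n)
# [0.5, 1.0, 2.0, inf, inf, inf] [0, 3, 1, 2, 4, 5] 3
```

If a real sample could have log_L equal to +inf, it would be indistinguishable from padding after the sort, which is why the documented assumption is needed.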