tree_structure
========================

.. py:module:: jaxns.internals.tree_structure

.. rubric:: :code:`jaxns.internals.tree_structure`

.. rubric:: Module Contents

.. py:class:: SampleTreeGraph

   Bases: :py:obj:`NamedTuple`

   Represents the tree structure of samples.

   There are N+1 nodes and N edges. Each node has exactly one sender, except the root node, which has none. Each node has zero or more receivers. The root is always node 0.

   :param sender_node_idx: [N] with values in [0, N], the sender node of each non-root node.
   :param log_L: [N] with values in [-inf, +inf], the log-likelihood of each node.

   .. py:attribute:: sender_node_idx
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: log_L
      :type: jaxns.internals.types.MeasureType

.. py:class:: SampleLivePointCounts

   Bases: :py:obj:`NamedTuple`

   .. py:attribute:: samples_indices
      :type: jaxns.internals.types.IntArray

   .. py:attribute:: num_live_points
      :type: jaxns.internals.types.IntArray

.. py:function:: count_crossed_edges(sample_tree, num_samples = None)

.. py:function:: count_crossed_edges_less_fast(S)

.. py:function:: count_intervals_naive(S)

.. py:function:: fast_perfect_live_point_computation_jax(log_L_constraints, log_L_samples, num_samples = None)

.. py:function:: compute_num_live_points_from_unit_threads(log_L_constraints, log_L_samples, num_samples = None, sorted_collection = True)

   Compute the number of live points of the shrinkage distribution from an arbitrary list of samples with corresponding sampling constraints.

   :param log_L_constraints: [N] log-likelihood constraint within which each sample was uniformly sampled.
   :param log_L_samples: [N] log-likelihood of each sample.
   :param sorted_collection: bool, whether the sample collection is already sorted.
   :returns: if ``sorted_collection`` is true, ``num_live_points`` for the shrinkage distribution; otherwise, ``num_live_points`` for the shrinkage distribution and the sort indices.

.. py:function:: count_old(S)

.. py:function:: plot_tree(S)

   Plots the tree, where the x-position is log_L and the y-position is an integer unique to each branch, chosen so that no edges cross.

   A node's y-position should be the same as its sender's y-position, unless that would make an edge cross, in which case a new y-position is added for the node so that no edges cross.

   For example, consider the tree::

      S = SampleTree(
          sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
          log_L=jnp.asarray([-jnp.inf, 1, 2, 3, 4, 5, 6])
      )

   The root node connects to nodes 1, 2 and 3, and straight lines then run from 1, 2 and 3 to 4, 5 and 6, respectively. The ASCII plot is::

      0 -- 1 -- 4
       \-- 2 -- 5
       \-- 3 -- 6

   If we add a branch from 2 to 7, we get::

      0 --- 1 -- 4
       \--- 2 -- 5
       \     \-- 7
       \-- 3 -- 6

   Note that the 2-7 edge does not cross any other edge. The in-degree of each node is 1, except for the root node, which has in-degree 0. The out-degree can be anything.

   :param S: SampleTree

.. py:function:: concatenate_sample_trees(trees, num_samples = None)

   Concatenates a list of SampleTreeGraphs into a single SampleTreeGraph.

   The root nodes of all trees must be the same, as this is equivalent to adding each tree as a branch from the shared root node, which is node 0 in each tree.

   :param trees: list of SampleTreeGraphs
   :returns: a single tree
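The ``SampleTreeGraph`` invariants (N+1 nodes, N edges, one sender per non-root node, root at node 0) can be checked with a short NumPy sketch. It assumes, as in the ``plot_tree`` example, that ``sender_node_idx[i]`` is the sender of node ``i + 1``. The helpers ``out_degrees`` and ``is_valid_tree`` are hypothetical illustrations, not part of jaxns.

.. code-block:: python

   import numpy as np


   def out_degrees(sender_node_idx):
       """Number of receivers of each of the N+1 nodes (hypothetical helper)."""
       s = np.asarray(sender_node_idx, dtype=np.int64)
       return np.bincount(s, minlength=s.size + 1)


   def is_valid_tree(sender_node_idx):
       """True if every node reaches root 0 by following senders (hypothetical helper).

       Assumes sender_node_idx[i] is the sender of node i + 1.
       """
       s = np.asarray(sender_node_idx, dtype=np.int64)
       n = s.size
       if n and (s.min() < 0 or s.max() > n):
           return False  # a sender index lies outside [0, N]
       # ancestor[k] starts as the sender of node k; the root is its own ancestor.
       ancestor = np.concatenate([[0], s])
       # Pointer doubling: after m rounds ancestor[k] is the 2**m-th ancestor of k
       # (clamped at the root), so n.bit_length() rounds cover any path of length n.
       for _ in range(n.bit_length()):
           ancestor = ancestor[ancestor]
       # Nodes trapped in a cycle never reach the root.
       return bool(np.all(ancestor == 0))

For the example tree, ``out_degrees([0, 0, 0, 1, 2, 3])`` gives ``[3, 1, 1, 1, 0, 0, 0]``: the root has three receivers and the leaves 4, 5 and 6 have none.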
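The quantity returned by ``compute_num_live_points_from_unit_threads`` can be made concrete with a brute-force sketch. The counting rule here is an assumption, not taken from jaxns: sample ``j`` counts as live when the ``i``-th lowest-likelihood sample is removed if its constraint is strictly below that sample's log-likelihood and its own log-likelihood is at least as large. This O(N**2) version is only for intuition and is not claimed to match the library's implementation; ``num_live_points_naive`` is a hypothetical name.

.. code-block:: python

   import numpy as np


   def num_live_points_naive(log_L_constraints, log_L_samples):
       """Brute-force live-point counts (hypothetical helper, O(N**2) memory).

       Returns the counts in ascending order of log_L_samples, plus the sort indices.
       """
       constraints = np.asarray(log_L_constraints, dtype=float)
       samples = np.asarray(log_L_samples, dtype=float)
       sort_idx = np.argsort(samples, kind="stable")
       removed = samples[sort_idx]  # log-likelihood at each removal step
       # live[i, j]: sample j is live when the i-th lowest sample is removed, i.e.
       # it was drawn under a constraint below that level and has not yet died.
       live = (constraints[None, :] < removed[:, None]) & (samples[None, :] >= removed[:, None])
       return live.sum(axis=1), sort_idx

With two initial live points drawn from the prior (constraint ``-inf``) and each dead point replaced by a sample drawn above it, for instance ``log_L_constraints=[-inf, -inf, 1, 2]`` and ``log_L_samples=[1, 3, 2, 4]``, the counts are ``[2, 2, 2, 1]``: the live-point count stays at two until the final point is removed.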
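The row rule described for ``plot_tree`` can be sketched as a depth-first traversal: a node's first child stays on its sender's row, and every later child gets the next unused row, allocated only after the previous sibling's whole subtree has been placed. This is one possible layout consistent with the docstring's examples, not necessarily the one jaxns implements. ``assign_rows`` is a hypothetical name, children are assumed to be visited in increasing node index, and only the y-positions are computed (the x-position would be ``log_L``).

.. code-block:: python

   def assign_rows(sender_node_idx):
       """Row (y-position) for each of the N+1 nodes (hypothetical helper).

       Assumes sender_node_idx[i] is the sender of node i + 1 and that children
       are visited in increasing node index.
       """
       n_nodes = len(sender_node_idx) + 1
       parent = [-1] + [int(s) for s in sender_node_idx]  # the root has no sender
       children = [[] for _ in range(n_nodes)]
       for node in range(1, n_nodes):
           children[parent[node]].append(node)

       rows = [0] * n_nodes  # the root sits on row 0
       next_row = 1  # first unused row
       stack = [0]
       # Pre-order DFS: a child is popped only after the whole subtree of its
       # previous sibling, so new rows land below every row that subtree used.
       while stack:
           node = stack.pop()
           if node != 0:
               p = parent[node]
               if children[p][0] == node:
                   rows[node] = rows[p]  # first child continues its sender's line
               else:
                   rows[node] = next_row  # later children branch onto a fresh row
                   next_row += 1
           stack.extend(reversed(children[node]))  # so children[node][0] pops first
       return rows

On the first ``plot_tree`` example this puts nodes 0, 1 and 4 on row 0, nodes 2 and 5 on row 1, and nodes 3 and 6 on row 2. Appending node 7 with sender 2 puts 7 on row 2 and pushes 3 and 6 down to row 3, matching the second ASCII plot.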
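Grafting several trees onto a shared root, as ``concatenate_sample_trees`` describes, amounts to renumbering each tree's non-root nodes past those already placed while leaving edges from node 0 untouched. The sketch below assumes the ``[N]`` layout of ``SampleTreeGraph`` with an implicit root (``sender_node_idx[i]`` and ``log_L[i]`` describe node ``i + 1``) and ignores ``num_samples``; ``Tree`` and ``concat_trees`` are hypothetical stand-ins, not jaxns code.

.. code-block:: python

   from typing import NamedTuple

   import numpy as np


   class Tree(NamedTuple):
       """Hypothetical stand-in with the same fields as SampleTreeGraph."""
       sender_node_idx: np.ndarray  # [N], sender of node i + 1
       log_L: np.ndarray  # [N], log-likelihood of node i + 1


   def concat_trees(trees):
       """Merge trees that share root node 0 into one tree (hypothetical helper)."""
       senders, log_Ls = [], []
       offset = 0  # number of non-root nodes already placed
       for tree in trees:
           s = np.asarray(tree.sender_node_idx)
           # Edges from the shared root stay at 0; other senders are shifted past
           # the nodes of the trees that came before.
           senders.append(np.where(s == 0, 0, s + offset))
           log_Ls.append(np.asarray(tree.log_L, dtype=float))
           offset += s.size
       return Tree(np.concatenate(senders), np.concatenate(log_Ls))

For example, merging a tree with senders ``[0, 1]`` and one with senders ``[0, 0, 2]`` yields senders ``[0, 1, 0, 0, 4]``: the second tree's node 2 becomes global node 4.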