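"""Multi-ellipsoid utilities: bounding ellipsoids, uniform sampling from unions of ellipsoids, and ellipsoidal clustering (source of ``jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils``)."""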

from typing import NamedTuple, Tuple, Literal

import numpy as np
import pylab as plt
from jax import numpy as jnp, vmap, random, tree_map, lax
from jax._src.scipy.special import gammaln

from jaxns.internals.log_semiring import LogSpace
from jaxns.samplers.multi_ellipsoid.em_gmm import em_gmm
from jaxns.internals.types import IntArray, FloatArray, PRNGKey, BoolArray
from jaxns.internals.types import UType, int_type, float_type

# Names exported by `from ... import *`; everything else in this module is internal.
__all__ = ['ellipsoid_clustering', 'sample_multi_ellipsoid', 'MultEllipsoidState']


class EllipsoidParams(NamedTuple):
    """
    Centres, radii and rotations of a batch of K ellipsoids in D dimensions.

    A point x lies inside ellipsoid k when
    ``|| diag(1 / radii[k]) @ rotation[k].T @ (x - mu[k]) ||^2 <= 1``.
    Functions that operate on a single ellipsoid use the same fields without the leading K axis.
    """
    mu: FloatArray  # [K, D] ellipsoid centres
    radii: FloatArray  # [K, D] ellipsoid radii (semi-axis lengths)
    rotation: FloatArray  # [K, D, D] ellipsoid rotation matrices


class MultEllipsoidState(NamedTuple):
    """
    Result of multi-ellipsoid clustering: the ellipsoid parameters and the point-to-ellipsoid assignment.
    """
    params: EllipsoidParams  # parameters of the K ellipsoid slots (leading axis K)
    cluster_id: IntArray  # [N] the cluster index of each point
def log_ellipsoid_volume(radii: FloatArray) -> FloatArray:
    """
    Compute the log-volume of a D-dimensional ellipsoid.

    Uses V = 2 pi^(D/2) / (D Gamma(D/2)) * prod(radii), i.e. the volume of the unit D-ball,
    pi^(D/2) / Gamma(D/2 + 1), scaled by the product of the radii.

    Args:
        radii: [D] radii of the ellipsoid

    Returns:
        log of the ellipsoid volume
    """
    D = radii.shape[0]
    return jnp.log(2.) - jnp.log(D) + 0.5 * D * jnp.log(jnp.pi) - gammaln(0.5 * D) + jnp.sum(jnp.log(radii))


def bounding_ellipsoid(points: UType, mask: FloatArray) -> Tuple[FloatArray, FloatArray]:
    """
    Use the empirical (masked) mean and covariance as an approximation to the bounding ellipsoid.

    Args:
        points: [N, D] points to fit the ellipsoid to
        mask: [N] weights selecting which points to consider (boolean or 0/1)

    Returns:
        mu [D], cov [D, D]
    """
    mu = jnp.average(points, weights=mask, axis=0)
    dx = points - mu
    cov = jnp.average(dx[:, :, None] * dx[:, None, :], weights=mask, axis=0)
    return mu, cov


def covariance_to_rotational(cov: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Decompose a covariance matrix into ellipsoid radii and a rotation matrix.

    We want (x - mu)^T inv(cov) (x - mu) = (x - mu)^T J @ J.T (x - mu), where
    J.T = diag(1/radii) @ rotation.T (un-rotate, then un-scale), i.e. J = rotation @ diag(1/radii).
    With the SVD cov = U @ diag(s) @ V^H (for a symmetric positive semi-definite cov, U can be taken
    equal to V) we have inv(cov) = U @ diag(1/s) @ U^H, hence radii = sqrt(s) and rotation = U.

    Args:
        cov: [D, D] covariance matrix

    Returns:
        radii [D] (clipped below at machine epsilon), rotation [D, D]
    """
    u, s, vh = jnp.linalg.svd(cov)
    # Avoid zero radii for degenerate (e.g. rank-deficient) covariances.
    radii_min = jnp.finfo(s.dtype).eps
    radii = jnp.maximum(jnp.sqrt(s), radii_min)
    rotation = u
    return radii, rotation


def ellipsoid_params(points: UType, mask: FloatArray) -> EllipsoidParams:
    """
    Fit an ellipsoid that encloses all masked points.

    The ellipsoid is {x : (x - mu)^T C (x - mu) <= 1} with C = L.T @ L and
    L = diag(1/radii) @ rotation.T. Its shape is taken from the empirical covariance of the masked
    points, and its radii are scaled so that the masked point with the largest Mahalanobis distance
    lies on the boundary.

    Args:
        points: [N, D] points to fit the ellipsoid to
        mask: [N] mask of which points to consider

    Returns:
        EllipsoidParams with mu [D], radii [D], rotation [D, D]
    """
    # Ellipsoid mean and covariance.
    mu, Sigma = bounding_ellipsoid(points=points, mask=mask)
    radii, rotation = covariance_to_rotational(Sigma)
    # Scale the radii so that all masked points are enclosed:
    #   for all i: (points[i] - mu) @ inv(Sigma) @ (points[i] - mu) <= scale**2,
    # i.e. scale**2 is the largest squared Mahalanobis distance among the masked points.
    rho = vmap(lambda x: maha_ellipsoid(x=x, mu=mu, radii=radii, rotation=rotation))(points)
    rho_max = jnp.max(jnp.where(mask, rho, 0.))
    radii *= jnp.sqrt(rho_max)
    return EllipsoidParams(mu=mu, radii=radii, rotation=rotation)


def ellipsoid_to_circle(point: FloatArray, mu: FloatArray, radii: FloatArray, rotation: FloatArray) -> FloatArray:
    """
    Apply the affine map that turns the ellipsoid into the unit sphere.

    Computes diag(1/radii) @ rotation.T @ (point - mu). The diagonal scaling is applied elementwise,
    which avoids materialising (and multiplying by) a dense D x D diagonal matrix.

    Args:
        point: [D] point to transform
        mu: [D] centre of the ellipsoid
        radii: [D] radii of the ellipsoid
        rotation: [D, D] rotation matrix of the ellipsoid

    Returns:
        the transformed point, of shape [D]
    """
    return jnp.reciprocal(radii) * (rotation.T @ (point - mu))


def circle_to_ellipsoid(point: FloatArray, mu: FloatArray, radii: FloatArray, rotation: FloatArray) -> FloatArray:
    """
    Apply the affine map that turns the unit sphere into the ellipsoid (inverse of `ellipsoid_to_circle`).

    Computes mu + rotation @ diag(radii) @ point, applying the diagonal scaling elementwise.

    Args:
        point: [D] point to transform
        mu: [D] centre of the ellipsoid
        radii: [D] radii of the ellipsoid
        rotation: [D, D] rotation matrix of the ellipsoid

    Returns:
        the transformed point, of shape [D]
    """
    return mu + rotation @ (radii * point)


def maha_ellipsoid(x: FloatArray, mu: FloatArray, radii: FloatArray, rotation: FloatArray) -> FloatArray:
    """
    Compute the squared Mahalanobis distance of `x` to the ellipsoid centre.

    Values <= 1 mean `x` is inside the (closed) ellipsoid.

    Args:
        x: [D] point
        mu: [D] centre of the ellipsoid
        radii: [D] radii of the ellipsoid
        rotation: [D, D] rotation matrix of the ellipsoid

    Returns:
        the squared Mahalanobis distance of `x` to `mu`
    """
    u_circ = ellipsoid_to_circle(x, mu, radii, rotation)
    return u_circ @ u_circ


def point_in_ellipsoid(x: FloatArray, mu: FloatArray, radii: FloatArray, rotation: FloatArray) -> BoolArray:
    """
    Determine whether a point lies inside a closed ellipsoid.

    Args:
        x: [D] point
        mu: [D] centre of the ellipsoid
        radii: [D] radii of the ellipsoid
        rotation: [D, D] rotation matrix of the ellipsoid

    Returns:
        True iff x is inside the closed ellipsoid
    """
    return jnp.less_equal(maha_ellipsoid(x, mu, radii, rotation), jnp.asarray(1., x.dtype))


def sample_ellipsoid(key: PRNGKey, mu: FloatArray, radii: FloatArray, rotation: FloatArray,
                     unit_cube_constraint: bool = False) -> FloatArray:
    """
    Sample uniformly inside an ellipsoid.

    A uniform point in the unit ball is drawn (uniform direction, radius U^(1/D)) and mapped onto the
    ellipsoid. When `unit_cube_constraint` is True, samples outside the closed unit cube are rejected
    and redrawn.

    Args:
        key: PRNGKey
        mu: [D] centre of the ellipsoid
        radii: [D] radii of the ellipsoid
        rotation: [D, D] rotation matrix of the ellipsoid
        unit_cube_constraint: whether to restrict to the closed unit cube

    Returns:
        i.i.d. sample from the ellipsoid, of shape [D]
    """

    def _single_sample(key):
        direction_key, radii_key = random.split(key, 2)
        direction = random.normal(direction_key, shape=radii.shape)
        direction = direction / jnp.linalg.norm(direction)
        # Radius distribution of a uniform point in the D-ball: P(r <= t) = t^D.
        t = random.uniform(radii_key) ** (1. / radii.size)
        u_circ = direction * t
        # Broadcasting scales column j by radii[j], i.e. R = rotation @ diag(radii).
        R = rotation * radii
        u = R @ u_circ + mu
        return u

    def body(state):
        (key, _, _) = state
        key, sample_key = random.split(key, 2)
        u = _single_sample(sample_key)
        done = jnp.all((u <= 1) & (u >= 0))
        return (key, done, u)

    if unit_cube_constraint:
        (_, _, u) = lax.while_loop(lambda s: ~s[1], body, (key, jnp.asarray(False), mu))
    else:
        u = _single_sample(key)
    return u


def compute_depth_ellipsoids(point: FloatArray, mu: FloatArray, radii: FloatArray, rotation: FloatArray,
                             constraint_unit_cube: bool = False) -> IntArray:
    """
    Count how many ellipsoids contain a point.

    When `constraint_unit_cube` is True, points outside the closed unit cube are given the maximum
    representable integer depth (effectively infinite).

    Args:
        point: [D] point to compute the depth at
        mu: [K, D] centres of the ellipsoids
        radii: [K, D] radii of the ellipsoids
        rotation: [K, D, D] rotation matrices of the ellipsoids
        constraint_unit_cube: whether the domain is clipped to the closed unit cube

    Returns:
        integer scalar: the number of ellipsoids containing the point
    """
    contained_in = vmap(lambda mu, radii, rotation: point_in_ellipsoid(point, mu, radii, rotation))(
        mu, radii, rotation)
    depth = jnp.sum(contained_in)
    if constraint_unit_cube:
        outside_unit_cube = jnp.any(point < 0.) | jnp.any(point > 1.)
        depth = jnp.where(outside_unit_cube, jnp.iinfo(depth.dtype).max, depth)
    return depth
def sample_multi_ellipsoid(key: PRNGKey, mu: FloatArray, radii: FloatArray, rotation: FloatArray,
                           unit_cube_constraint: bool = True) -> Tuple[IntArray, FloatArray]:
    """
    Draw one point uniformly from the union of (possibly overlapping) ellipsoids.

    Each proposal picks an ellipsoid with probability proportional to its volume, draws a point uniformly
    inside it, and accepts the point with probability 1/depth, where depth is the number of ellipsoids
    containing it. This cancels the over-representation of overlapping regions. When
    `unit_cube_constraint` is True, candidates outside the closed unit cube get the maximum integer depth,
    so they are (essentially) always rejected.

    Args:
        key: PRNGKey
        mu: [K, D] centres of the ellipsoids
        radii: [K, D] radii of the ellipsoids
        rotation: [K, D, D] rotation matrices of the ellipsoids
        unit_cube_constraint: whether to restrict samples to the closed unit cube

    Returns:
        the index of the ellipsoid the accepted point was drawn from, and the accepted point of shape [D]
    """
    _, num_dims = radii.shape
    # Unnormalised log-probabilities for choosing an ellipsoid: its log-volume.
    log_weights = vmap(log_ellipsoid_volume)(radii)

    def propose(carry):
        num_iters, _, loop_key, _, _ = carry
        loop_key, accept_key, sample_key, select_key = random.split(loop_key, 4)
        chosen = random.categorical(select_key, log_weights)
        candidate = sample_ellipsoid(sample_key, mu[chosen, :], radii[chosen, :], rotation[chosen, :, :],
                                     unit_cube_constraint=False)
        depth = compute_depth_ellipsoids(candidate, mu, radii, rotation,
                                         constraint_unit_cube=unit_cube_constraint)
        accepted = random.uniform(accept_key) < jnp.reciprocal(depth)
        return (num_iters + 1, chosen, loop_key, accepted, candidate)

    def keep_going(carry):
        return jnp.logical_not(carry[3])

    init_carry = (jnp.array(0), jnp.array(0), key, jnp.array(False), jnp.zeros(num_dims))
    _, chosen, _, _, point = lax.while_loop(keep_going, propose, init_carry)
    return chosen, point
def log_coverage_scale(log_VE, log_VS, D):
    """
    Compute the log of the per-radius scale factor that makes V(E) = max(V(E), V(S)).

    Scaling every radius by exp(s) multiplies the volume by exp(s * D). The returned
    s = max(0, (log V(S) - log V(E)) / D) therefore enlarges E to exactly V(S) when it is too small,
    and leaves it unchanged (s = 0) otherwise.

    Args:
        log_VE: log-volume of the ellipsoid E
        log_VS: (estimated) log-volume of the point set S
        D: dimension of the space

    Returns:
        non-negative log of the scale to apply to each radius
    """
    return jnp.maximum(0., (log_VS - log_VE) / D)


class ClusterSplitResult(NamedTuple):
    unsorted_cluster_id: IntArray  # [N] unsorted cluster id, using 0/1 to indicate the child.
    log_VS0: FloatArray  # estimated log-volume of the points in child 0
    params0: EllipsoidParams  # (enlarged) enclosing ellipsoid of child 0
    log_VS1: FloatArray  # estimated log-volume of the points in child 1
    params1: EllipsoidParams  # (enlarged) enclosing ellipsoid of child 1
    successful_split: BoolArray  # whether the split should be accepted


def _fit_child(points: FloatArray, mask: BoolArray, cluster_id: IntArray, child: int, log_VS: FloatArray,
               n_S: IntArray):
    """
    Estimate the volume of one child of a two-way split and fit its enlarged enclosing ellipsoid.

    Args:
        points: [N, D] points
        mask: [N] points belonging to the parent
        cluster_id: [N] 0/1 child assignment
        child: which child (0 or 1) to fit
        log_VS: estimated log-volume of the parent point set
        n_S: number of points in the parent

    Returns:
        (number of points in the child, estimated child log-volume, enlarged ellipsoid params,
         log-volume of the enlarged ellipsoid)
    """
    D = points.shape[1]
    child_mask = mask & (cluster_id == child)
    n_child = jnp.sum(child_mask)
    # V(S_child) is estimated by the fraction of the parent's points assigned to the child.
    log_VS_child = log_VS + jnp.log(n_child) - jnp.log(n_S)
    params = ellipsoid_params(points=points, mask=child_mask)
    log_VE = log_ellipsoid_volume(params.radii)
    # Enlarge to at least cover V(S_child).
    log_scale = log_coverage_scale(log_VE, log_VS_child, D)
    params = params._replace(radii=jnp.exp(jnp.log(params.radii) + log_scale))
    return n_child, log_VS_child, params, log_VE + log_scale * D


def _multinest_split(key: PRNGKey, params: EllipsoidParams, points: FloatArray, mask: BoolArray,
                     log_VS: FloatArray, em_init: bool = False, patience: int = 1):
    """
    Split the masked points in two, using an adaptation of MultiNest's point-reassignment scheme.

    Starting from an initial two-way partition (an EM Gaussian mixture if `em_init`, otherwise a cut
    through the parent centre perpendicular to its longest axis), each iteration:
      1. fits an enclosing ellipsoid E_k to each child, enlarged to cover the child's estimated volume
         V(S_k) = V(S) n_k / n_S;
      2. computes h_k(x) = V(E_k) d_k(x) / V(S_k), with d_k the squared Mahalanobis distance to E_k;
      3. moves the single worst-assigned point -- the one maximising h_current / h_other > 1 -- to the
         other child.
    Iteration stops when no point prefers the other child, the total enclosed volume V(E_0) + V(E_1) has
    not decreased for `patience` iterations, a child has fewer than D + 1 points, or the volume is NaN.
    The assignment with the smallest total enclosed volume seen is returned, together with ellipsoids
    fitted to exactly that assignment.

    Args:
        key: PRNGKey
        params: ellipsoid params of the points being split (same as used for log V(E))
        points: [N, D] points to split
        mask: [N] mask of the points which should be split
        log_VS: estimate of log V(S) of the set of points
        em_init: whether to initialise the partition with an EM Gaussian mixture
        patience: number of iterations without improvement before stopping

    Returns:
        cluster_id, log_VS0, params0, log_VS1, params1
    """
    init_key, _ = random.split(key, 2)
    N, D = points.shape
    n_S = jnp.sum(mask)

    if em_init:
        cluster_id, (_, _, _), _ = em_gmm(
            key=init_key,
            data=points,
            mask=mask,
            n_components=2,
            n_iters=100
        )
    else:
        # Cut the parent ellipsoid through its centre, perpendicular to its longest axis.
        j_max = jnp.argmax(params.radii)
        axis = params.rotation[:, j_max] * params.radii[j_max]
        proj = (points - params.mu) @ axis
        cluster_id = jnp.where(proj >= jnp.asarray(0., float_type),
                               jnp.asarray(0, int_type),
                               jnp.asarray(1, int_type))

    class CarryState(NamedTuple):
        iter: IntArray
        done: BoolArray
        cluster_id: IntArray  # assignment to evaluate in the next iteration
        best_cluster_id: IntArray  # assignment with the smallest total enclosed volume so far
        min_loss: FloatArray
        iters_no_improvement: IntArray

    def body(body_state: CarryState) -> CarryState:
        current_id = body_state.cluster_id
        n0, log_VS0, params0, log_VE0 = _fit_child(points, mask, current_id, 0, log_VS, n_S)
        n1, log_VS1, params1, log_VE1 = _fit_child(points, mask, current_id, 1, log_VS, n_S)

        maha0 = vmap(lambda point: maha_ellipsoid(point, mu=params0.mu, radii=params0.radii,
                                                  rotation=params0.rotation))(points)
        maha1 = vmap(lambda point: maha_ellipsoid(point, mu=params1.mu, radii=params1.radii,
                                                  rotation=params1.rotation))(points)
        # log h_k(x) = log V(E_k) + log d_k(x) - log V(S_k)
        log_h0 = log_VE0 + jnp.log(maha0) - log_VS0
        log_h1 = log_VE1 + jnp.log(maha1) - log_VS1

        # A violator is a point whose current child has a larger h than the alternative child.
        in_child0 = current_id == 0
        log_h_current = jnp.where(in_child0, log_h0, log_h1)
        log_h_other = jnp.where(in_child0, log_h1, log_h0)
        is_violator = mask & (log_h_current > log_h_other)
        any_violator = jnp.any(is_violator)
        # Rank violators by log(h_current / h_other) > 0; non-violators can never be selected.
        log_ratio = jnp.where(is_violator, log_h_current - log_h_other, -jnp.inf)
        reassign_idx = jnp.argmax(log_ratio)
        flipped_id = jnp.asarray(1, int_type) - current_id[reassign_idx]
        new_cluster_id = jnp.where(any_violator, current_id.at[reassign_idx].set(flipped_id), current_id)

        log_V_sum = jnp.logaddexp(log_VE0, log_VE1)
        new_loss = log_V_sum - log_VS  # zero if both children were enlarged to exactly V(S_k)
        loss_decreased = new_loss < body_state.min_loss
        iters_no_improvement = jnp.where(loss_decreased, 0, body_state.iters_no_improvement + 1)
        min_loss = jnp.where(loss_decreased, new_loss, body_state.min_loss)
        best_cluster_id = jnp.where(loss_decreased, current_id, body_state.best_cluster_id)

        done = jnp.logical_not(any_violator) \
               | (iters_no_improvement >= patience) \
               | (n0 < D + 1) \
               | (n1 < D + 1) \
               | jnp.isnan(log_V_sum)
        return CarryState(
            iter=body_state.iter + jnp.asarray(1, int_type),
            done=done,
            cluster_id=new_cluster_id,
            best_cluster_id=best_cluster_id,
            min_loss=min_loss,
            iters_no_improvement=iters_no_improvement
        )

    # Done to start with if there cannot be at least D+1 points per ellipsoid.
    init_state = CarryState(
        iter=jnp.array(0),
        done=n_S < 2 * (D + 1),
        cluster_id=cluster_id,
        best_cluster_id=cluster_id,
        min_loss=jnp.asarray(jnp.inf),
        iters_no_improvement=jnp.asarray(0, int_type)
    )
    output_state: CarryState = lax.while_loop(lambda state: jnp.logical_not(state.done), body, init_state)

    # Fit the children of the returned assignment so params and cluster ids are consistent.
    best_id = output_state.best_cluster_id
    _, log_VS0, params0, _ = _fit_child(points, mask, best_id, 0, log_VS, n_S)
    _, log_VS1, params1, _ = _fit_child(points, mask, best_id, 1, log_VS, n_S)
    return best_id, log_VS0, params0, log_VS1, params1


def _em_gmm_split(key: PRNGKey, points: FloatArray, mask: BoolArray, log_VS: FloatArray):
    """
    Split the masked points in two with an EM Gaussian mixture model.

    Args:
        key: PRNGKey
        points: [N, D] points to split
        mask: [N] mask of the points which should be split
        log_VS: estimate of log V(S) of the set of points

    Returns:
        cluster_id, log_VS0, params0, log_VS1, params1 (ellipsoids enlarged to cover V(S0), V(S1))
    """
    n_S = jnp.sum(mask)
    cluster_id, (_, _, _), _ = em_gmm(
        key=key,
        data=points,
        mask=mask,
        n_components=2,
        n_iters=100
    )
    _, log_VS0, params0, _ = _fit_child(points, mask, cluster_id, 0, log_VS, n_S)
    _, log_VS1, params1, _ = _fit_child(points, mask, cluster_id, 1, log_VS, n_S)
    return cluster_id, log_VS0, params0, log_VS1, params1


def cluster_split(key: PRNGKey, params: EllipsoidParams, points: FloatArray, mask: BoolArray,
                  log_VS: FloatArray, method: Literal['multinest', 'em_gmm']) -> ClusterSplitResult:
    """
    Split a set of points into two ellipsoids whose total volume is as close to V(S) as possible without
    being less.

    V(S) should be an estimate of the true volume occupied by the points.

    Args:
        key: PRNGKey
        params: ellipsoid params of the points being split (used for the parent volume V(E))
        points: [N, D] points to split
        mask: [N] mask of the points which should be split
        log_VS: estimate of log V(S) of the set of points
        method: splitting method, 'multinest' or 'em_gmm'

    Returns:
        cluster split result

    Raises:
        ValueError: if `method` is not recognised
    """
    _, D = points.shape
    # Parent ellipsoid volume; E is already scaled so that V(E) >= V(S).
    log_VE = log_ellipsoid_volume(params.radii)

    if method == 'em_gmm':
        cluster_id, log_VS0, params0, log_VS1, params1 = _em_gmm_split(
            key=key, points=points, mask=mask, log_VS=log_VS)
    elif method == 'multinest':
        cluster_id, log_VS0, params0, log_VS1, params1 = _multinest_split(
            key=key, params=params, points=points, mask=mask, log_VS=log_VS, em_init=False, patience=1)
    else:
        raise ValueError(f"Invalid method {method}")

    # Notation: S is the parent point set, S1/S2 its children with enclosing ellipsoids E_A/E_B (enlarged
    # so that V(S1) <= V(E_A) and V(S2) <= V(E_B)), and A/B the true regions the children sample.
    # For a disjoint partition V(S1) + V(S2) = V(S), so
    #   V(A) + V(B) <= V(S1) + V(S2) = V(S) <= V(E_A) + V(E_B)   and   V(S) <= V(E).
    # A good split makes V(E_A) + V(E_B) a tighter upper bound on V(S) than V(E) is, so we accept the
    # split when V(E_A) + V(E_B) < V(E), both volumes are valid, and each child has at least D + 1 points.
    mask0 = mask & (cluster_id == 0)
    mask1 = mask & (cluster_id == 1)
    log_VE_A = log_ellipsoid_volume(params0.radii)
    log_VE_B = log_ellipsoid_volume(params1.radii)
    V_sum = LogSpace(log_VE_A) + LogSpace(log_VE_B)
    good_split = V_sum.log_abs_val < log_VE
    successful_split = good_split \
                       & jnp.bitwise_not(jnp.isnan(log_VE_A)) \
                       & jnp.bitwise_not(jnp.isnan(log_VE_B)) \
                       & (jnp.sum(mask0) >= (D + 1)) \
                       & (jnp.sum(mask1) >= (D + 1))
    return ClusterSplitResult(unsorted_cluster_id=cluster_id,
                              log_VS0=log_VS0,
                              params0=params0,
                              log_VS1=log_VS1,
                              params1=params1,
                              successful_split=successful_split)


def plot_ellipses(params: EllipsoidParams, show: bool = True):
    """
    Plot the boundaries of 2-dimensional ellipses, each in a random colour.

    The unit circle is mapped onto each ellipse, so `params` must describe 2-dimensional ellipsoids.

    Args:
        params: ellipsoid parameters to plot (leading axis K)
        show: whether to show the figure
    """
    theta = jnp.linspace(0., 2 * jnp.pi, 100)
    circle = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1)
    for mu, radii, rotation in zip(params.mu, params.radii, params.rotation):
        ellipse = vmap(lambda point: circle_to_ellipsoid(point, mu, radii, rotation))(circle)
        plt.plot(ellipse[:, 0], ellipse[:, 1], c=np.random.uniform(size=3))
    if show:
        plt.show()
def ellipsoid_clustering(key: PRNGKey, points: FloatArray, log_VS: FloatArray, max_num_ellipsoids: int,
                         method: Literal['multinest', 'em_gmm'] = 'em_gmm') -> MultEllipsoidState:
    """
    Partition points into at most `max_num_ellipsoids` ellipsoids by repeated two-way splitting.

    Starting from one bounding ellipsoid (enlarged so its volume is at least exp(log_VS)), each iteration
    selects the unfinished cluster with the smallest split depth (breadth-first order) and splits it with
    `cluster_split`. A successful split keeps child 0 in the selected slot and puts child 1 in the next
    free slot; an unsuccessful split marks the selected cluster as finished. Splitting stops when every
    cluster is finished or all slots are in use.

    Args:
        key: PRNGKey
        points: [N, D] points to partition
        log_VS: log of the expected true volume of the points
        max_num_ellipsoids: the maximum number of ellipsoids, K
        method: splitting method passed to `cluster_split`, 'multinest' or 'em_gmm'

    Returns:
        params of the K ellipsoid slots (slots never filled keep zero radii) and the cluster id of each point

    Raises:
        ValueError: if max_num_ellipsoids < 1
    """
    N, D = points.shape
    if max_num_ellipsoids < 1:
        raise ValueError(f"max_num_ellipsoids should be >= 1, got {max_num_ellipsoids}.")
    K = max_num_ellipsoids

    # Initial ellipsoid enclosing all points, enlarged to cover V(S).
    init_ellipsoid = ellipsoid_params(points=points, mask=jnp.ones(N, jnp.bool_))
    log_VE = log_ellipsoid_volume(init_ellipsoid.radii)
    log_scale = log_coverage_scale(log_VE, log_VS, D)
    init_ellipsoid = init_ellipsoid._replace(radii=jnp.exp(jnp.log(init_ellipsoid.radii) + log_scale))

    # All slots are zero except the first, which holds the initial ellipsoid.
    params = EllipsoidParams(
        mu=jnp.zeros((K, D), float_type),
        radii=jnp.zeros((K, D), float_type),
        rotation=jnp.zeros((K, D, D), float_type)
    )
    params: EllipsoidParams = tree_map(lambda x, y: x.at[0].set(y), params, init_ellipsoid)
    state = MultEllipsoidState(
        cluster_id=jnp.zeros(N, dtype=int_type),
        params=params
    )

    # Tracking: unused slots have log V(S) = -inf and are therefore marked done.
    log_VS_subclusters = jnp.asarray([log_VS] + [-jnp.inf] * (K - 1))
    done_splitting = jnp.isneginf(log_VS_subclusters)
    split_depth = jnp.zeros([K], int_type)

    # TODO: compare performance with scan
    class CarryType(NamedTuple):
        key: PRNGKey
        next_k: IntArray  # next free slot
        state: MultEllipsoidState
        done_splitting: BoolArray  # [K] whether each slot is finished (or unused)
        split_depth: IntArray  # [K] number of successful splits along each slot's lineage
        log_VS_subclusters: FloatArray  # [K] estimated log-volume of each slot's points

    def body(body_state: CarryType) -> CarryType:
        carry_key, split_key = random.split(body_state.key, 2)
        free_slot = body_state.next_k

        # Breadth-first: pick the unfinished slot with the smallest split depth.
        select_split = jnp.argmin(
            jnp.where(body_state.done_splitting,
                      jnp.iinfo(body_state.split_depth.dtype).max,
                      body_state.split_depth)
        )
        mask = body_state.state.cluster_id == select_split
        parent_log_VS = body_state.log_VS_subclusters[select_split]
        parent_params = tree_map(lambda x: x[select_split], body_state.state.params)

        result: ClusterSplitResult = cluster_split(
            key=split_key,
            params=parent_params,
            points=points,
            mask=mask,
            log_VS=parent_log_VS,
            method=method
        )
        success = result.successful_split

        # On success: child 0 replaces the parent in `select_split`, child 1 goes to `free_slot`.
        params = tree_map(
            lambda x, y0, y1: jnp.where(success, x.at[select_split].set(y0).at[free_slot].set(y1), x),
            body_state.state.params, result.params0, result.params1)
        cluster_id = jnp.where(success & (result.unsorted_cluster_id == 1) & mask,
                               free_slot,
                               body_state.state.cluster_id)
        state = body_state.state._replace(params=params, cluster_id=cluster_id)

        # On success the free slot becomes active (the selected slot stays active);
        # otherwise the selected slot is finished.
        done_splitting = jnp.where(success,
                                   body_state.done_splitting.at[free_slot].set(False),
                                   body_state.done_splitting.at[select_split].set(True))

        new_depth = body_state.split_depth[select_split] + jnp.asarray(1, int_type)
        split_depth = jnp.where(success,
                                body_state.split_depth.at[select_split].set(new_depth).at[free_slot].set(new_depth),
                                body_state.split_depth)

        log_VS_subclusters = jnp.where(
            success,
            body_state.log_VS_subclusters.at[select_split].set(result.log_VS0).at[free_slot].set(result.log_VS1),
            body_state.log_VS_subclusters)

        # Only consume a slot when a split actually happened.
        next_k = jnp.where(success, free_slot + jnp.asarray(1, int_type), free_slot)

        return CarryType(
            key=carry_key,
            next_k=next_k,
            state=state,
            done_splitting=done_splitting,
            split_depth=split_depth,
            log_VS_subclusters=log_VS_subclusters
        )

    def cond(body_state: CarryType) -> BoolArray:
        done = jnp.all(body_state.done_splitting) | (body_state.next_k == K)
        return jnp.bitwise_not(done)

    init_body_state = CarryType(
        key=key,
        next_k=jnp.asarray(1, int_type),
        state=state,
        done_splitting=done_splitting,
        split_depth=split_depth,
        log_VS_subclusters=log_VS_subclusters
    )
    output_state = lax.while_loop(cond, body, init_body_state)
    return output_state.state