multi_ellipsoid_utils

jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils

Module Contents

class MultEllipsoidState

Bases: NamedTuple

params: EllipsoidParams
cluster_id: jaxns.internals.types.IntArray

sample_multi_ellipsoid(key, mu, radii, rotation, unit_cube_constraint=True)

Sample a point from the union of a set of possibly intersecting ellipsoids. When unit_cube_constraint=True, points outside the closed unit cube are rejected.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – random key used for sampling

  • mu (jaxns.internals.types.FloatArray) – [K, D] centres of the ellipsoids

  • radii (jaxns.internals.types.FloatArray) – [K, D] radii of the ellipsoids

  • rotation (jaxns.internals.types.FloatArray) – [K, D, D] rotation matrices of the ellipsoids

  • unit_cube_constraint (bool) – if True, reject points that fall outside the closed unit cube

Returns:

the index of the selected ellipsoid, and a point of shape [D] sampled i.i.d. from the union of the ellipsoids

Return type:

Tuple[jaxns.internals.types.IntArray, jaxns.internals.types.FloatArray]
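The docstring does not spell out the sampling scheme. A standard way to draw i.i.d. uniform points from a union of overlapping ellipsoids is sketched below in plain NumPy; it is an illustration of the technique, not the jaxns implementation. Ellipsoid k is chosen with probability proportional to its volume, a uniform point is drawn inside it, and the point is accepted with probability 1/n(x), where n(x) is the number of ellipsoids containing it, which cancels the over-counting of overlap regions. The helper names (sample_unit_ball, in_ellipsoids, sample_union_of_ellipsoids) and the convention x = mu[k] + rotation[k] @ (radii[k] * u), with u in the unit ball, are assumptions.

```python
import numpy as np


def sample_unit_ball(rng, dim):
    """Uniform sample from the D-dimensional unit ball."""
    direction = rng.standard_normal(dim)
    direction /= np.linalg.norm(direction)
    # The radius CDF is r^D, so invert it with U^(1/D).
    radius = rng.uniform() ** (1.0 / dim)
    return radius * direction


def in_ellipsoids(x, mu, radii, rotation):
    """Boolean [K] mask of ellipsoids containing x.

    Assumed convention: x = mu[k] + rotation[k] @ (radii[k] * u), |u| <= 1.
    """
    # Map x into each ellipsoid's unit-ball frame: u = rotation^T (x - mu) / radii.
    u = np.einsum("kji,kj->ki", rotation, x[None, :] - mu) / radii
    return np.sum(u ** 2, axis=-1) <= 1.0


def sample_union_of_ellipsoids(rng, mu, radii, rotation,
                               unit_cube_constraint=True, max_tries=10_000):
    """Return (k, x): the chosen ellipsoid and a uniform point in the union."""
    num_ellipsoids, dim = mu.shape
    # Ellipsoid volume is proportional to the product of its radii
    # (the unit-ball factor cancels); work in log space to avoid underflow.
    log_vol = np.sum(np.log(radii), axis=-1)
    probs = np.exp(log_vol - log_vol.max())
    probs /= probs.sum()
    for _ in range(max_tries):
        k = rng.choice(num_ellipsoids, p=probs)
        x = mu[k] + rotation[k] @ (radii[k] * sample_unit_ball(rng, dim))
        if unit_cube_constraint and np.any((x < 0.0) | (x > 1.0)):
            continue  # reject points outside the closed unit cube
        # Accept with probability 1/n(x) so overlaps are not over-sampled.
        n_inside = max(int(np.sum(in_ellipsoids(x, mu, radii, rotation))), 1)
        if rng.uniform() < 1.0 / n_inside:
            return k, x
    raise RuntimeError("no sample accepted within max_tries")
```

For example, sample_union_of_ellipsoids(np.random.default_rng(0), mu, radii, rotation) returns one (index, point) pair. The proposal density is n(x) divided by the summed volume, so the 1/n(x) acceptance step leaves a uniform density on the union (restricted to the unit cube when the constraint is on).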

ellipsoid_clustering(key, points, log_VS, max_num_ellipsoids, method='em_gmm')

Partition points into at most max_num_ellipsoids ellipsoids, splitting clusters in depth-first order.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – random key used by the clustering method

  • points (jaxns.internals.types.FloatArray) – [N, D] points to partition

  • log_VS (jaxns.internals.types.FloatArray) – log of the expected true volume enclosed by the points

  • max_num_ellipsoids (int) – the maximum number of ellipsoids

  • method (Literal['multinest', 'em_gmm']) – clustering method to use; defaults to 'em_gmm'

Returns:

the parameters of the multi-ellipsoid and the cluster id assigned to each point

Return type:

MultEllipsoidState
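Once points carry cluster ids, each cluster still needs an enclosing ellipsoid. The NumPy sketch below shows one common recipe in the spirit of MultiNest; it is not necessarily what jaxns does. For each cluster it takes the mean and covariance, scales the covariance ellipsoid until every point lies inside, and, if the resulting volume is below the cluster's share of the expected volume exp(log_VS), enlarges all axes by the same factor to match. Splitting log_VS across clusters in proportion to their point counts, the helper names (log_unit_ball_volume, fit_bounding_ellipsoid, fit_multi_ellipsoid), and the convention that the columns of rotation are the principal axes are all assumptions.

```python
import math

import numpy as np


def log_unit_ball_volume(dim):
    """log of the D-dimensional unit-ball volume, pi^(D/2) / Gamma(D/2 + 1)."""
    return 0.5 * dim * math.log(math.pi) - math.lgamma(0.5 * dim + 1.0)


def fit_bounding_ellipsoid(points, log_VS=None):
    """Return (mu [D], radii [D], rotation [D, D]) enclosing all of points [N, D].

    Assumes N > D so the sample covariance is non-singular.
    """
    dim = points.shape[1]
    mu = points.mean(axis=0)
    cov = np.atleast_2d(np.cov(points, rowvar=False))
    # Principal axes are the eigenvectors (columns of rotation).
    eigvals, rotation = np.linalg.eigh(cov)
    # Guard against tiny negative eigenvalues from round-off.
    eigvals = np.clip(eigvals, 1e-12, None)
    # Squared Mahalanobis distance of each point; the largest sets the scale.
    u = (points - mu) @ rotation / np.sqrt(eigvals)
    scale2 = np.max(np.sum(u ** 2, axis=-1))
    radii = np.sqrt(eigvals * scale2)
    if log_VS is not None:
        log_vol = log_unit_ball_volume(dim) + np.sum(np.log(radii))
        if log_vol < log_VS:
            # Enlarge every axis by one factor so the volume equals exp(log_VS).
            radii = radii * np.exp((log_VS - log_vol) / dim)
    return mu, radii, rotation


def fit_multi_ellipsoid(points, cluster_id, num_clusters, log_VS):
    """Fit one ellipsoid per cluster; returns arrays [K, D], [K, D], [K, D, D]."""
    n_total = points.shape[0]
    fitted = []
    for k in range(num_clusters):
        cluster_points = points[cluster_id == k]
        # Assumed split: cluster k gets a share n_k / N of the expected volume.
        log_VS_k = log_VS + math.log(len(cluster_points) / n_total)
        fitted.append(fit_bounding_ellipsoid(cluster_points, log_VS_k))
    mu, radii, rotation = (np.stack(arrs) for arrs in zip(*fitted))
    return mu, radii, rotation
```

Enlarging to the expected volume guards against the bounding ellipsoid shrinking too tightly around a finite set of points, which would otherwise cut off parts of the true constrained region.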