em_gmm

jaxns.samplers.multi_ellipsoid.em_gmm

Module Contents

initialize_params(key, data, n_components)[source]

Initialize the parameters of a Gaussian Mixture Model.

Parameters:
  • key – the random key

  • data – [n, d] array of data

  • n_components (int) – number of components

Returns:

means – [num_clusters, d] array of means
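The initialization strategy is not described here, so the following is only a minimal sketch of one common choice: random data points as initial means, a shared sample covariance, and uniform weights. The name init_gmm_sketch, the jitter value, and the three-element return tuple (mirroring the params tuple documented for em_gmm) are assumptions for illustration, not the jaxns implementation.

```python
import jax
import jax.numpy as jnp


def init_gmm_sketch(key, data, n_components):
    """Hypothetical initializer; not the jaxns function."""
    n, d = data.shape
    # Pick n_components distinct rows of the data as the initial means.
    idx = jax.random.choice(key, n, shape=(n_components,), replace=False)
    means = data[idx]
    # Start every component from the overall sample covariance plus a small jitter.
    cov = jnp.atleast_2d(jnp.cov(data, rowvar=False)) + 1e-6 * jnp.eye(d)
    covariances = jnp.tile(cov[None, :, :], (n_components, 1, 1))
    # Uniform mixture weights, stored in log space.
    log_weights = jnp.full((n_components,), -jnp.log(n_components))
    return means, covariances, log_weights


key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (50, 2))
means, covariances, log_weights = init_gmm_sketch(key, data, 3)
```

Keeping the weights in log space matches the log_weights argument that e_step expects.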

e_step(data, means, covariances, log_weights, mask)[source]

Compute the responsibilities of each Gaussian for each data point.

Parameters:
  • data – [n, d] array of data

  • means – [num_clusters, d] array of means

  • covariances – [num_clusters, d, d] array of covariances

  • log_weights – [num_clusters] array of log weights

  • mask – [n] boolean array indicating which data points to use

Returns:

log_responsibilities – [num_clusters, n] array of log responsibilities
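As an illustration of the computation, the sketch below evaluates log w_k + log N(x_i | mu_k, Sigma_k) for every component and normalizes over components. The function name e_step_sketch and the treatment of masked points (set to -inf, i.e. zero responsibility) are assumptions; the library may handle the mask differently.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal


def e_step_sketch(data, means, covariances, log_weights, mask):
    """Hypothetical E-step; not the jaxns function."""

    # Weighted log density of every data point under one component: [n]
    def component_log_prob(mean, cov, log_w):
        return log_w + multivariate_normal.logpdf(data, mean, cov)

    # Stack over components: [num_clusters, n]
    log_joint = jax.vmap(component_log_prob)(means, covariances, log_weights)
    # Normalize over components so each column sums to one in linear space.
    log_resp = log_joint - logsumexp(log_joint, axis=0, keepdims=True)
    # Assumption: masked-out points receive zero responsibility.
    return jnp.where(mask[None, :], log_resp, -jnp.inf)


data = jnp.array([[0.0, 0.0], [0.1, -0.1], [5.0, 5.0], [5.2, 4.9]])
means = jnp.array([[0.0, 0.0], [5.0, 5.0]])
covariances = jnp.stack([jnp.eye(2), jnp.eye(2)])
log_weights = jnp.log(jnp.array([0.5, 0.5]))
mask = jnp.array([True, True, True, False])
log_resp = e_step_sketch(data, means, covariances, log_weights, mask)
```

Normalizing before applying the mask avoids the -inf minus -inf case that would otherwise produce NaN for masked columns.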

m_step(data, log_responsibilities)[source]

Update the parameters of the Gaussian Mixture Model.

Parameters:
  • data – [n, d] array of data

  • log_responsibilities – [num_clusters, n] array of log responsibilities

Returns:

means – [num_clusters, d] array of means
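A minimal sketch of the standard GMM update, given log responsibilities of shape [num_clusters, n], might look as follows. The name m_step_sketch, the jitter argument, and the choice to return covariances and log weights alongside the means are assumptions for illustration.

```python
import jax.numpy as jnp


def m_step_sketch(data, log_responsibilities, jitter=1e-6):
    """Hypothetical M-step; not the jaxns function."""
    d = data.shape[1]
    resp = jnp.exp(log_responsibilities)  # [num_clusters, n]
    nk = resp.sum(axis=1)  # effective number of points per component
    # Responsibility-weighted means: [num_clusters, d]
    means = (resp @ data) / nk[:, None]
    # Responsibility-weighted covariances: [num_clusters, d, d]
    diff = data[None, :, :] - means[:, None, :]
    covariances = jnp.einsum("kn,kni,knj->kij", resp, diff, diff) / nk[:, None, None]
    covariances = covariances + jitter * jnp.eye(d)  # keep matrices positive definite
    # Mixture weights in log space.
    log_weights = jnp.log(nk) - jnp.log(nk.sum())
    return means, covariances, log_weights


data = jnp.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]])
# Hard assignments: first two points to component 0, last two to component 1.
log_resp = jnp.log(jnp.array([[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]))
means, covariances, log_weights = m_step_sketch(data, log_resp)
```

Points with a log responsibility of -inf contribute exactly zero weight, so masked points drop out of the update without special handling.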

em_gmm(key, data, n_components, mask=None, n_iters=10, tol=1e-06)[source]

Fit a Gaussian Mixture Model to the data using the Expectation-Maximization algorithm.

Parameters:
  • key – the random key

  • data – [n, d] array of data

  • n_components – number of components

  • mask (Union[jax.numpy.ndarray, None]) – [n] boolean array indicating which data points to use

  • n_iters – maximum number of iterations

  • tol – convergence tolerance

Returns:
  • cluster_id – [n] array of cluster assignments

  • params – tuple of (means, covariances, log_weights)

  • total_iters – total number of iterations used
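To show how the pieces fit together, here is a self-contained sketch of an EM loop with the same argument names and return structure. It is not the library code: the initialization, the log-likelihood stopping rule for tol, the plain Python loop, and the name em_gmm_sketch are all assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal


def em_gmm_sketch(key, data, n_components, mask=None, n_iters=10, tol=1e-6):
    """Hypothetical EM loop; not the jaxns function."""
    n, d = data.shape
    if mask is None:
        mask = jnp.ones(n, dtype=bool)
    # Assumed initialization: random data points, identity covariances, uniform weights.
    idx = jax.random.choice(key, n, shape=(n_components,), replace=False)
    means = data[idx]
    covs = jnp.tile(jnp.eye(d), (n_components, 1, 1))
    log_w = jnp.full((n_components,), -jnp.log(n_components))

    prev_ll = -jnp.inf
    total_iters = 0
    for _ in range(n_iters):
        # E-step: normalized log responsibilities, [n_components, n].
        log_joint = jax.vmap(
            lambda m, c, lw: lw + multivariate_normal.logpdf(data, m, c)
        )(means, covs, log_w)
        log_norm = logsumexp(log_joint, axis=0)
        log_resp = jnp.where(mask, log_joint - log_norm, -jnp.inf)
        # M-step: responsibility-weighted means, covariances, and weights.
        resp = jnp.exp(log_resp)
        nk = resp.sum(axis=1)
        means = (resp @ data) / nk[:, None]
        diff = data[None] - means[:, None]
        covs = jnp.einsum("kn,kni,knj->kij", resp, diff, diff) / nk[:, None, None]
        covs = covs + 1e-6 * jnp.eye(d)
        log_w = jnp.log(nk) - jnp.log(nk.sum())
        total_iters += 1
        # Assumed stopping rule: change in log-likelihood over unmasked points.
        ll = float(jnp.sum(jnp.where(mask, log_norm, 0.0)))
        if abs(ll - prev_ll) < tol:
            break
        prev_ll = ll

    # Labels come from the last E-step; masked points default to label 0.
    cluster_id = jnp.argmax(log_resp, axis=0)
    return cluster_id, (means, covs, log_w), total_iters


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
data = jnp.concatenate(
    [jax.random.normal(k1, (40, 2)), jax.random.normal(k2, (40, 2)) + 10.0]
)
cluster_id, params, total_iters = em_gmm_sketch(k3, data, n_components=2)
```

A Python loop with an early break keeps the sketch readable; a jit-compatible version would more likely express the same iteration with jax.lax.while_loop.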