em_gmm
jaxns.samplers.multi_ellipsoid.em_gmm
Module Contents
- initialize_params(key, data, n_components)
Initialize the parameters of a Gaussian Mixture Model.
- Parameters:
key – the random key
data – [n, d] array of data
n_components (int) – number of mixture components (the num_clusters dimension in the array shapes below)
- Returns:
means – [num_clusters, d] array of means
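A minimal sketch of one common way to initialize a Gaussian mixture, given only the signature above. It is not the jaxns implementation: the function name `init_params_sketch`, the choice of random data points as means, the shared data covariance, and the uniform weights are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_params_sketch(key, data, n_components):
    """Hypothetical initializer (assumption, not the jaxns code)."""
    n, d = data.shape
    # Assumption: use n_components distinct data points as the initial means.
    idx = jax.random.choice(key, n, shape=(n_components,), replace=False)
    means = data[idx]  # [num_clusters, d]
    # Assumption: every component starts from the overall data covariance.
    cov = jnp.atleast_2d(jnp.cov(data, rowvar=False))
    covariances = jnp.tile(cov[None, :, :], (n_components, 1, 1))  # [num_clusters, d, d]
    # Assumption: uniform mixture weights, stored in log space.
    log_weights = jnp.full((n_components,), -jnp.log(n_components))
    return means, covariances, log_weights


key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (100, 2))
means, covariances, log_weights = init_params_sketch(key, data, 3)
```

Storing the weights as logs matches the `log_weights` argument that `e_step` takes, so the output of an initializer like this can be passed along without conversion.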
- e_step(data, means, covariances, log_weights, mask)
Compute the responsibilities of each Gaussian for each data point.
- Parameters:
data – [n, d] array of data
means – [num_clusters, d] array of means
covariances – [num_clusters, d, d] array of covariances
log_weights – [num_clusters] array of log weights
mask – [n] boolean array indicating which data points to use
- Returns:
log_responsibilities – [num_clusters, n] array of log responsibilities
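The E-step described above can be sketched as follows. This is an illustrative version under stated assumptions, not the jaxns source: the names `e_step_sketch` and `log_gaussian` are hypothetical, and assigning `-inf` to masked-out points (so they carry zero weight later) is an assumption about how `mask` might be applied.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import logsumexp


def log_gaussian(x, mu, cov):
    """Log density of N(mu, cov) at each row of x; returns shape [n]."""
    d = mu.shape[0]
    chol = jnp.linalg.cholesky(cov)
    # Solve chol @ z = (x - mu)^T so that sum(z**2) is the squared Mahalanobis distance.
    z = solve_triangular(chol, (x - mu).T, lower=True)  # [d, n]
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    return -0.5 * (d * jnp.log(2.0 * jnp.pi) + log_det + jnp.sum(z**2, axis=0))


def e_step_sketch(data, means, covariances, log_weights, mask):
    # Per-component log densities, shape [num_clusters, n].
    log_pdf = jax.vmap(lambda mu, cov: log_gaussian(data, mu, cov))(means, covariances)
    log_joint = log_weights[:, None] + log_pdf
    # Normalize over components so responsibilities sum to one per data point.
    log_resp = log_joint - logsumexp(log_joint, axis=0, keepdims=True)
    # Assumption: masked-out points get -inf, i.e. zero responsibility.
    return jnp.where(mask[None, :], log_resp, -jnp.inf)


data = jnp.array([[0.0, 0.0], [5.0, 5.0], [0.1, -0.1]])
means = jnp.array([[0.0, 0.0], [5.0, 5.0]])
covariances = jnp.stack([jnp.eye(2), jnp.eye(2)])
log_weights = jnp.log(jnp.array([0.5, 0.5]))
mask = jnp.array([True, True, False])
log_resp = e_step_sketch(data, means, covariances, log_weights, mask)
```

Working in log space and normalizing with `logsumexp` avoids underflow when a point lies far from every component.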
- m_step(data, log_responsibilities)
Update the parameters of the Gaussian Mixture Model.
- Parameters:
data – [n, d] array of data
log_responsibilities – [num_clusters, n] array of log responsibilities
- Returns:
means – [num_clusters, d] array of means
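For orientation, here is a sketch of the standard GMM M-step computed from log responsibilities of shape [num_clusters, n]. It is not taken from jaxns: the name `m_step_sketch`, the `jitter` argument, and returning covariances and log weights alongside the means are assumptions.

```python
import jax.numpy as jnp


def m_step_sketch(data, log_responsibilities, jitter=1e-6):
    """Hypothetical M-step (assumption, not the jaxns code)."""
    n, d = data.shape
    resp = jnp.exp(log_responsibilities)  # [num_clusters, n]
    # Effective number of points per component; assumed non-zero here.
    nk = resp.sum(axis=1)  # [num_clusters]
    # Responsibility-weighted means.
    means = (resp @ data) / nk[:, None]  # [num_clusters, d]
    # Responsibility-weighted covariances around the new means.
    diff = data[None, :, :] - means[:, None, :]  # [num_clusters, n, d]
    covariances = jnp.einsum("kn,kni,knj->kij", resp, diff, diff) / nk[:, None, None]
    # Assumption: small diagonal jitter keeps the covariances positive definite.
    covariances = covariances + jitter * jnp.eye(d)
    # Mixture weights in log space.
    log_weights = jnp.log(nk) - jnp.log(nk.sum())
    return means, covariances, log_weights


data = jnp.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [12.0, 10.0]])
# Hard assignments: first two points to component 0, last two to component 1.
log_resp = jnp.log(jnp.array([[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]))
means, covariances, log_weights = m_step_sketch(data, log_resp)
```

With hard assignments as in this usage example, each mean reduces to the plain average of the points assigned to that component.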
- em_gmm(key, data, n_components, mask=None, n_iters=10, tol=1e-06)
Fit a Gaussian Mixture Model to the data using the Expectation-Maximization algorithm.
- Parameters:
key – the random key
data – [n, d] array of data
n_components – number of components
mask (Union[jax.numpy.ndarray, None]) – [n] boolean array indicating which data points to use
n_iters – maximum number of iterations
tol – convergence tolerance
- Returns:
cluster_id – [n] array of cluster assignments
params – tuple of (means, covariances, log_weights)
total_iters – total number of iterations used
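The overall loop can be sketched end to end as below. This is a self-contained illustration that mirrors the documented signature and return values, not the jaxns implementation. The names `em_gmm_sketch` and `_log_joint`, the initialization, the jitter, the log-likelihood stopping rule, and the plain Python loop (which is not `jit`-compatible) are all assumptions.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import logsumexp


def _log_joint(data, means, covariances, log_weights):
    """log w_k + log N(x_i | mu_k, Sigma_k), shape [num_clusters, n]."""
    def log_gaussian(mu, cov):
        chol = jnp.linalg.cholesky(cov)
        z = solve_triangular(chol, (data - mu).T, lower=True)
        log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
        d = mu.shape[0]
        return -0.5 * (d * jnp.log(2.0 * jnp.pi) + log_det + jnp.sum(z**2, axis=0))
    return log_weights[:, None] + jax.vmap(log_gaussian)(means, covariances)


def em_gmm_sketch(key, data, n_components, mask=None, n_iters=10, tol=1e-6):
    n, d = data.shape
    if mask is None:
        mask = jnp.ones((n,), dtype=bool)
    # Assumption: initialize means from distinct unmasked data points.
    idx = jax.random.choice(key, n, shape=(n_components,), replace=False,
                            p=mask / mask.sum())
    means = data[idx]
    covariances = jnp.tile(jnp.eye(d)[None], (n_components, 1, 1))
    log_weights = jnp.full((n_components,), -jnp.log(n_components))
    prev_ll = -jnp.inf
    total_iters = 0
    for _ in range(n_iters):
        # E-step: responsibilities, with masked-out points given zero weight.
        log_joint = _log_joint(data, means, covariances, log_weights)
        log_norm = logsumexp(log_joint, axis=0)
        resp = jnp.where(mask[None, :], jnp.exp(log_joint - log_norm), 0.0)
        # M-step: weighted means, covariances and weights.
        nk = resp.sum(axis=1) + 1e-12  # guard against empty components
        means = (resp @ data) / nk[:, None]
        diff = data[None] - means[:, None]
        covariances = jnp.einsum("kn,kni,knj->kij", resp, diff, diff) / nk[:, None, None]
        covariances = covariances + 1e-6 * jnp.eye(d)
        log_weights = jnp.log(nk) - jnp.log(nk.sum())
        total_iters += 1
        # Assumption: stop when the masked log-likelihood changes by less than tol.
        ll = jnp.sum(jnp.where(mask, log_norm, 0.0))
        if jnp.abs(ll - prev_ll) < tol:
            break
        prev_ll = ll
    # Assign each point to its most probable component.
    cluster_id = jnp.argmax(_log_joint(data, means, covariances, log_weights), axis=0)
    return cluster_id, (means, covariances, log_weights), total_iters


k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
data = jnp.concatenate([jax.random.normal(k1, (50, 2)),
                        jax.random.normal(k2, (50, 2)) + 10.0])
cluster_id, params, total_iters = em_gmm_sketch(k0, data, 2, n_iters=20)
```

The returned tuple follows the documented order: cluster assignments, then the parameter tuple, then the iteration count. EM can converge to a local optimum, so the assignments depend on the random key used for initialization.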