em_gmm
================

.. py:module:: jaxns.samplers.multi_ellipsoid.em_gmm


.. rubric:: :code:`jaxns.samplers.multi_ellipsoid.em_gmm`


.. rubric:: Module Contents


.. py:function:: initialize_params(key, data, n_components)

   Initialize the parameters of a Gaussian Mixture Model.

   :param key: the JAX PRNG key
   :param data: [n, d] array of data
   :param n_components: number of mixture components

   :returns: ``means``: [num_clusters, d] array of means


.. py:function:: e_step(data, means, covariances, log_weights, mask)

   Compute the log responsibilities of each Gaussian component for each data point (the E-step of EM).

   :param data: [n, d] array of data
   :param means: [num_clusters, d] array of means
   :param covariances: [num_clusters, d, d] array of covariances
   :param log_weights: [num_clusters] array of log mixture weights
   :param mask: [n] boolean array indicating which data points to use

   :returns: ``log_responsibilities``: [num_clusters, n] array of log responsibilities


.. py:function:: m_step(data, log_responsibilities)

   Update the parameters of the Gaussian Mixture Model from the current responsibilities (the M-step of EM).

   :param data: [n, d] array of data
   :param log_responsibilities: [num_clusters, n] array of log responsibilities

   :returns: ``means``: [num_clusters, d] array of means


.. py:function:: em_gmm(key, data, n_components, mask=None, n_iters=10, tol=1e-06)

   Fit a Gaussian Mixture Model to the data using the Expectation-Maximization (EM) algorithm.

   :param key: the JAX PRNG key
   :param data: [n, d] array of data
   :param n_components: number of mixture components
   :param mask: [n] boolean array indicating which data points to use
   :param n_iters: maximum number of EM iterations
   :param tol: convergence tolerance

   :returns: a tuple ``(cluster_id, params, total_iters)``:

      - ``cluster_id``: [n] array of cluster assignments
      - ``params``: tuple of ``(means, covariances, log_weights)``
      - ``total_iters``: total number of iterations used
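
The documented functions fit together as the usual EM loop: initialize the parameters, alternate ``e_step`` (log responsibilities) and ``m_step`` (parameter updates) until convergence or ``n_iters`` iterations, then derive cluster assignments. The sketch below is a minimal NumPy illustration of that loop using the documented array shapes; it is not the jaxns implementation. Assumptions not taken from the jaxns source: a ``numpy.random.Generator`` named ``rng`` stands in for the JAX ``key``; initialization uses randomly chosen data points as means, identity covariances and uniform weights; the sketch's ``m_step`` returns ``(means, covariances, log_weights)`` and adds a hypothetical ridge ``reg`` to keep covariances positive definite; "converged" means no mean moved by more than ``tol``; each point is assigned to its highest-responsibility component.

.. code-block:: python

   import numpy as np


   def _log_gaussian(data, mean, cov):
       """Log-density of N(mean, cov) at each row of data; returns shape [n]."""
       d = data.shape[1]
       chol = np.linalg.cholesky(cov)  # cov = L @ L.T
       # Solve L z = (x - mean) so that sum(z**2) is the squared Mahalanobis distance.
       z = np.linalg.solve(chol, (data - mean).T)  # [d, n]
       maha = np.sum(z ** 2, axis=0)
       log_det = 2.0 * np.sum(np.log(np.diag(chol)))
       return -0.5 * (maha + log_det + d * np.log(2.0 * np.pi))


   def _logsumexp(a, axis):
       """Numerically stable log(sum(exp(a))) along one axis."""
       m = np.max(a, axis=axis, keepdims=True)
       return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


   def e_step(data, means, covariances, log_weights, mask):
       # [num_clusters, n] joint terms log w_k + log N(x_i | mu_k, Sigma_k).
       log_prob = np.stack([
           lw + _log_gaussian(data, mu, cov)
           for mu, cov, lw in zip(means, covariances, log_weights)
       ])
       # Normalize over components so each point's responsibilities sum to 1.
       log_resp = log_prob - _logsumexp(log_prob, axis=0)
       # Masked-out points get zero responsibility (log 0 = -inf).
       return np.where(mask[None, :], log_resp, -np.inf)


   def m_step(data, log_responsibilities, reg=1e-6):
       resp = np.exp(log_responsibilities)  # [num_clusters, n]
       nk = resp.sum(axis=1) + 1e-12  # effective number of points per component
       means = resp @ data / nk[:, None]  # [num_clusters, d]
       d = data.shape[1]
       covariances = []
       for k in range(resp.shape[0]):
           diff = data - means[k]
           cov = (resp[k][:, None] * diff).T @ diff / nk[k]
           covariances.append(cov + reg * np.eye(d))  # hypothetical ridge term
       log_weights = np.log(nk / nk.sum())
       return means, np.stack(covariances), log_weights


   def em_gmm(rng, data, n_components, mask=None, n_iters=10, tol=1e-6):
       n, d = data.shape
       if mask is None:
           mask = np.ones(n, dtype=bool)
       # Assumed initialization: random unmasked data points as means,
       # identity covariances, uniform mixture weights.
       idx = rng.choice(np.flatnonzero(mask), size=n_components, replace=False)
       means = data[idx]
       covariances = np.stack([np.eye(d)] * n_components)
       log_weights = np.full(n_components, -np.log(n_components))
       total_iters = 0
       for _ in range(n_iters):
           log_resp = e_step(data, means, covariances, log_weights, mask)
           new_means, covariances, log_weights = m_step(data, log_resp)
           total_iters += 1
           # Assumed convergence test: largest change in any mean coordinate.
           converged = np.max(np.abs(new_means - means)) < tol
           means = new_means
           if converged:
               break
       # Assign each point to its most responsible component under the final parameters.
       log_resp = e_step(data, means, covariances, log_weights, mask)
       cluster_id = np.argmax(log_resp, axis=0)
       return cluster_id, (means, covariances, log_weights), total_iters


   # Usage: two well-separated 2-D blobs.
   rng = np.random.default_rng(0)
   data = np.concatenate([
       rng.normal(-5.0, 1.0, size=(100, 2)),
       rng.normal(5.0, 1.0, size=(100, 2)),
   ])
   cluster_id, (means, covs, log_weights), total_iters = em_gmm(rng, data, 2, n_iters=50)

In this sketch, masked-out points receive a log responsibility of ``-inf`` for every component, so they do not influence the M-step, and ``np.argmax`` assigns them ``cluster_id`` 0.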