ML Playground / GMM Clustering

Gaussian Mixture Model (GMM)

A probabilistic clustering algorithm that models data as a mixture of multiple Gaussian distributions, providing soft cluster assignments.

What It Is

GMM assumes data is generated from a mixture of several Gaussian (normal) distributions, each representing a cluster. Unlike K-Means, which assigns each point to exactly one cluster, GMM gives the probability of belonging to each cluster (soft clustering).

GMM is the go-to algorithm when clusters overlap or when you need probability estimates for cluster membership rather than hard assignments.

How It Works

GMM uses the Expectation-Maximization (EM) algorithm to iteratively estimate the parameters of each Gaussian distribution.

Algorithm Steps
  1. Initialize parameters — Set the number of clusters K. Randomly initialize the mean, covariance, and mixing weight for each Gaussian.
  2. E-Step (Expectation) — For each data point, compute the probability it belongs to each Gaussian cluster using the multivariate Gaussian PDF.
  3. M-Step (Maximization) — Update the mean, covariance, and weight of each Gaussian to maximize the likelihood of the observed data.
  4. Repeat until parameters converge (stop changing significantly).
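The EM loop above can be sketched in a few lines of NumPy. This is a minimal 1-D illustration, not the library implementation: it initializes the two means at the data extremes for stability and iterates E- and M-steps on synthetic data drawn from two known Gaussians.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy 1-D data drawn from two Gaussians centered at -2 and 3
x = np.concatenate([rng.normal(-2, 0.8, 150), rng.normal(3, 1.2, 100)])

K = 2
# 1. Initialize parameters (means at the data extremes, shared variance, equal weights)
mu = np.array([x.min(), x.max()])
var = np.full(K, x.var())
pi = np.full(K, 1.0 / K)

for _ in range(100):
    # 2. E-step: responsibility of each Gaussian for each point
    pdf = np.exp(-0.5 * (x[:, None] - mu) ** 2 / var) / np.sqrt(2 * np.pi * var)
    resp = pi * pdf
    resp /= resp.sum(axis=1, keepdims=True)

    # 3. M-step: re-estimate parameters from the responsibilities
    Nk = resp.sum(axis=0)
    mu = (resp * x[:, None]).sum(axis=0) / Nk
    var = (resp * (x[:, None] - mu) ** 2).sum(axis=0) / Nk
    pi = Nk / len(x)

print(sorted(mu))  # means should land near -2 and 3
```

In practice you would also monitor the log-likelihood and stop once it changes by less than a small tolerance, which is step 4 above.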

Parameters of Each Gaussian

Each Gaussian cluster has three parameters:
  - Mean (mu): center of the cluster
  - Covariance (Sigma): shape and spread of the cluster
  - Weight (pi): proportion of data belonging to this cluster

The weights of all clusters sum to 1.
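These three parameters fully define the mixture density p(x) = sum_k pi_k * N(x; mu_k, Sigma_k). A small sketch with hand-picked (hypothetical) parameters for a two-component mixture in 2-D:

```python
import numpy as np
from scipy.stats import multivariate_normal

# Hypothetical parameters for a 2-component mixture in 2-D
mu = [np.array([0.0, 0.0]), np.array([3.0, 3.0])]         # means
Sigma = [np.eye(2), np.array([[1.0, 0.5], [0.5, 1.0]])]   # covariances
pi = np.array([0.6, 0.4])                                 # mixing weights

assert np.isclose(pi.sum(), 1.0)  # weights must sum to 1

def mixture_pdf(x):
    # p(x) = sum_k pi_k * N(x; mu_k, Sigma_k)
    return sum(w * multivariate_normal(m, S).pdf(x)
               for w, m, S in zip(pi, mu, Sigma))

print(mixture_pdf(np.array([0.0, 0.0])))
```

A fitted `GaussianMixture` exposes exactly these quantities as `weights_`, `means_`, and `covariances_`, as shown later in the notebook.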

Code: Fit GMM

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

# Generate synthetic data with 3 clusters
X, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.2, random_state=42)

# Standardize features (important for GMM)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Plot raw data
plt.scatter(X_scaled[:, 0], X_scaled[:, 1], s=30, alpha=0.6)
plt.title("Generated Data for Clustering")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()
```

Train and Predict

```python
# Define and fit the GMM model
gmm = GaussianMixture(n_components=3, covariance_type='full', random_state=42)
gmm.fit(X_scaled)

# Predict cluster labels (hard assignment)
labels = gmm.predict(X_scaled)

# Get soft probabilities for each cluster
probabilities = gmm.predict_proba(X_scaled)

# Plot clusters with GMM labels
plt.scatter(X_scaled[:, 0], X_scaled[:, 1], c=labels, cmap='viridis', s=40, alpha=0.7)
plt.scatter(gmm.means_[:, 0], gmm.means_[:, 1], c='red', marker='X', s=200, label='Centroids')
plt.title("GMM Clustering Results")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.show()
```

Inspect Learned Parameters

```python
print("Cluster Weights (pi):", gmm.weights_)
print("Cluster Means (mu):", gmm.means_)
print("Cluster Covariances (Sigma):", gmm.covariances_)
```

K-Means vs GMM

| Feature | K-Means | GMM |
|---|---|---|
| Assignment | Hard (one cluster per point) | Soft (probability per cluster) |
| Cluster shape | Spherical only | Elliptical (any covariance shape) |
| Algorithm | Lloyd's iteration | Expectation-Maximization |
| Speed | Faster | Slower (more parameters to estimate) |
| Outlier handling | Assigns outliers to nearest cluster | Outliers get low probability for all clusters |
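The cluster-shape difference is easy to see on anisotropic data. The sketch below shears blob data into elongated ellipses (a setup similar to the scikit-learn clustering examples) and compares both algorithms with the adjusted Rand index; exact scores vary, but GMM's full covariances typically track the sheared clusters better than K-Means' spherical assumption.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
from sklearn.metrics import adjusted_rand_score

# Generate blobs, then shear them into elongated ellipses
X, y = make_blobs(n_samples=500, random_state=170)
X = X @ np.array([[0.6, -0.6], [-0.4, 0.8]])  # anisotropic transform

km = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
gmm = GaussianMixture(n_components=3, covariance_type='full',
                      random_state=0).fit(X)

# Compare hard labels against the true generating clusters
print("K-Means ARI:", adjusted_rand_score(y, km.labels_))
print("GMM ARI:    ", adjusted_rand_score(y, gmm.predict(X)))
```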

When to Use GMM

| Good For | Not Ideal For |
|---|---|
| Overlapping clusters | Very high-dimensional data (many parameters) |
| Need probability estimates, not just labels | Well-separated clusters (K-Means is simpler) |
| Elliptical/non-spherical cluster shapes | Very large datasets (EM can be slow) |
| Density estimation, anomaly detection | When you need deterministic results |
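Because a fitted GMM is a full density model, anomaly detection falls out almost for free via `score_samples`, which returns each point's log-likelihood under the mixture. The 2% cutoff below is an arbitrary illustrative choice:

```python
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.2, random_state=42)
gmm = GaussianMixture(n_components=3, random_state=42).fit(X)

# Per-point log-likelihood under the fitted mixture density
log_dens = gmm.score_samples(X)

# Flag the lowest-density 2% of points as anomalies (threshold is a modeling choice)
threshold = np.percentile(log_dens, 2)
outliers = X[log_dens < threshold]
print(f"{len(outliers)} points flagged as anomalies")
```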

GMM is sensitive to initialization and can converge to local optima. Always use multiple initializations (n_init parameter) and pick the result with the highest likelihood. Standardize your features before fitting.
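In scikit-learn this advice amounts to one argument: `n_init` restarts EM from several random initializations and keeps the run with the highest likelihood automatically. A short sketch:

```python
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

X, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.2, random_state=42)
X = StandardScaler().fit_transform(X)  # standardize before fitting

# n_init=10 runs EM from 10 random starts; the best fit is kept
gmm = GaussianMixture(n_components=3, n_init=10, random_state=42).fit(X)

# score() is the average per-sample log-likelihood; higher is better
print("avg log-likelihood:", gmm.score(X))
print("converged:", gmm.converged_)
```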

Key Parameters

  - n_components: number of Gaussian clusters (K)
  - covariance_type: 'full', 'tied', 'diag', or 'spherical'; controls the allowed cluster shapes
  - n_init: number of EM restarts; the run with the highest likelihood is kept
  - max_iter, tol: convergence controls for the EM loop
  - random_state: seed for reproducible initialization

Tags: Unsupervised, Clustering, Probabilistic, EM Algorithm