# Gaussian Mixture Model (GMM)
A probabilistic clustering algorithm that models data as a mixture of multiple Gaussian distributions, providing soft cluster assignments.
## What It Is
GMM assumes the data is generated from a mixture of several Gaussian (normal) distributions, each representing one cluster. Unlike K-Means, which assigns each point to exactly one cluster, GMM gives each point a probability of belonging to every cluster (soft clustering).
GMM is the go-to algorithm when clusters overlap or when you need probability estimates for cluster membership rather than hard assignments.
## How It Works
GMM uses the Expectation-Maximization (EM) algorithm to iteratively estimate the parameters of each Gaussian distribution.
### Algorithm Steps
1. Initialize parameters: choose the number of clusters K and randomly initialize the mean, covariance, and mixing weight of each Gaussian.
2. E-Step (Expectation): for each data point, compute the responsibility (posterior probability) that it belongs to each Gaussian, using the multivariate Gaussian PDF.
3. M-Step (Maximization): update each Gaussian's mean, covariance, and weight to maximize the likelihood of the observed data under the current responsibilities.
4. Repeat the E- and M-steps until the parameters converge (stop changing significantly).
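The loop above can be sketched in plain NumPy for a one-dimensional, two-component mixture. This is a minimal illustration under simplified assumptions (1-D data, scalar variances, fixed iteration count), not a production implementation; `sklearn`'s `GaussianMixture` used later does the same job with full covariances and numerical safeguards.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy 1-D data: two Gaussian clumps centered at -2 and +2
x = np.concatenate([rng.normal(-2, 0.5, 150), rng.normal(2, 0.8, 150)])

# Initialize parameters for K=2 components
mu = np.array([-1.0, 1.0])   # means
var = np.array([1.0, 1.0])   # variances
pi = np.array([0.5, 0.5])    # mixing weights

def gauss_pdf(x, mu, var):
    return np.exp(-(x - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

for _ in range(100):
    # E-step: responsibilities r[i, k] = P(component k | x_i)
    dens = pi * gauss_pdf(x[:, None], mu, var)   # shape (n, 2)
    r = dens / dens.sum(axis=1, keepdims=True)
    # M-step: re-estimate parameters from the responsibilities
    nk = r.sum(axis=0)
    mu = (r * x[:, None]).sum(axis=0) / nk
    var = (r * (x[:, None] - mu) ** 2).sum(axis=0) / nk
    pi = nk / len(x)

print(np.sort(mu))   # the recovered means should land near -2 and +2
```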
### Parameters of Each Gaussian
Each Gaussian cluster has three parameters:
- Mean (mu): the center of the cluster
- Covariance (Sigma): the shape and spread of the cluster
- Weight (pi): the proportion of data belonging to this cluster; the weights of all clusters sum to 1
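How these three parameters combine into a single density can be sketched with SciPy. The numbers below are hypothetical, chosen only to illustrate that the mixture density is the weight-averaged sum of the component densities:

```python
import numpy as np
from scipy.stats import multivariate_normal

# Hypothetical 2-component mixture in 2D (illustrative values, not fitted)
weights = np.array([0.6, 0.4])  # pi_k, must sum to 1
means = [np.array([0.0, 0.0]), np.array([3.0, 3.0])]
covs = [np.eye(2), np.array([[1.0, 0.5], [0.5, 1.0]])]

def mixture_pdf(x):
    """Mixture density: p(x) = sum_k pi_k * N(x; mu_k, Sigma_k)."""
    return sum(w * multivariate_normal(m, c).pdf(x)
               for w, m, c in zip(weights, means, covs))

print(mixture_pdf(np.array([1.0, 1.0])))  # density at one query point
```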
## Code: Fit GMM

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

# Generate synthetic data with 3 clusters
X, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.2, random_state=42)

# Standardize features (important for GMM)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Plot raw data
plt.scatter(X_scaled[:, 0], X_scaled[:, 1], s=30, alpha=0.6)
plt.title("Generated Data for Clustering")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()
```
## Train and Predict

```python
# Define and fit the GMM model
gmm = GaussianMixture(n_components=3, covariance_type='full', random_state=42)
gmm.fit(X_scaled)

# Predict cluster labels (hard assignment)
labels = gmm.predict(X_scaled)

# Get soft probabilities for each cluster
probabilities = gmm.predict_proba(X_scaled)

# Plot clusters with GMM labels
plt.scatter(X_scaled[:, 0], X_scaled[:, 1], c=labels, cmap='viridis', s=40, alpha=0.7)
plt.scatter(gmm.means_[:, 0], gmm.means_[:, 1], c='red', marker='X', s=200, label='Centroids')
plt.title("GMM Clustering Results")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.show()
```
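The "soft" part of soft clustering lives in `predict_proba`: each row is a probability distribution over the components and sums to 1, and points near a cluster boundary get split probabilities. A quick sanity check (repeated here as a self-contained snippet mirroring the setup above; the 90% cutoff is an arbitrary choice for illustration):

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

X, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.2, random_state=42)
X_scaled = StandardScaler().fit_transform(X)
gmm = GaussianMixture(n_components=3, random_state=42).fit(X_scaled)
probabilities = gmm.predict_proba(X_scaled)

# Each row is a distribution over the 3 components and sums to 1
print(np.allclose(probabilities.sum(axis=1), 1.0))

# Points with no dominant component are the soft boundary cases
uncertain = int((probabilities.max(axis=1) < 0.9).sum())
print(f"{uncertain} of {len(X_scaled)} points have ambiguous membership")
```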
## Inspect Learned Parameters

```python
print("Cluster Weights (pi):", gmm.weights_)
print("Cluster Means (mu):", gmm.means_)
print("Cluster Covariances (Sigma):", gmm.covariances_)
```
## K-Means vs GMM

| Feature | K-Means | GMM |
|---|---|---|
| Assignment | Hard (one cluster per point) | Soft (probability per cluster) |
| Cluster shape | Spherical only | Elliptical (any covariance shape) |
| Algorithm | Lloyd's iteration | Expectation-Maximization |
| Speed | Faster | Slower (more parameters to estimate) |
| Outlier handling | Assigns outliers to the nearest cluster | Outliers get low probability under all clusters |
## When to Use GMM

| Good For | Not Ideal For |
|---|---|
| Overlapping clusters | Very high-dimensional data (many covariance parameters) |
| Probability estimates, not just labels | Well-separated clusters (K-Means is simpler) |
| Elliptical/non-spherical cluster shapes | Very large datasets (EM can be slow) |
| Density estimation, anomaly detection | When you need deterministic results |
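The density-estimation use mentioned in the table follows directly from the fitted model: `score_samples` returns each point's log-likelihood under the mixture, and points with unusually low likelihood can be flagged as anomalies. A sketch on the same synthetic data (the 2% cutoff is an arbitrary choice for illustration):

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

X, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.2, random_state=42)
X_scaled = StandardScaler().fit_transform(X)
gmm = GaussianMixture(n_components=3, random_state=42).fit(X_scaled)

# Log-likelihood of each point under the fitted mixture
log_density = gmm.score_samples(X_scaled)

# Flag the lowest-density 2% of points as candidate anomalies
threshold = np.percentile(log_density, 2)
anomalies = X_scaled[log_density < threshold]
print(f"Flagged {len(anomalies)} candidate anomalies")
```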
GMM is sensitive to initialization and can converge to a local optimum. Use multiple initializations (the n_init parameter) and keep the run with the highest likelihood; scikit-learn does this automatically when n_init > 1. Standardize your features before fitting.
## Key Parameters

- n_components: number of Gaussian components (clusters)
- covariance_type: 'full' (each component has its own full covariance matrix), 'tied' (all components share one), 'diag' (diagonal covariances only), 'spherical' (a single variance per component)
- n_init: number of initializations (default 1; increase for stability)
- random_state: seed for reproducibility
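A practical question the parameters above raise is how to choose n_components. scikit-learn's GaussianMixture exposes bic() (Bayesian Information Criterion), which penalizes model size against fit quality; a common recipe is to fit a range of K values and pick the lowest BIC. A sketch on the same synthetic 3-cluster data:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

X, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.2, random_state=42)
X_scaled = StandardScaler().fit_transform(X)

# Fit GMMs for K = 1..6 and score each with BIC (lower is better)
bics = {}
for k in range(1, 7):
    gmm = GaussianMixture(n_components=k, n_init=5, random_state=42).fit(X_scaled)
    bics[k] = gmm.bic(X_scaled)

best_k = min(bics, key=bics.get)
print("BIC by K:", {k: round(v, 1) for k, v in bics.items()})
print("Best K:", best_k)
```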
Tags: Unsupervised · Clustering · Probabilistic · EM Algorithm