
Mean Shift Clustering

A density-based algorithm that finds cluster centers by iteratively shifting toward regions of highest data density. No need to specify the number of clusters.

What It Is

Mean Shift finds clusters by locating the peaks (modes) of the data's estimated density. A window of fixed radius (the bandwidth) starts at each data point and is repeatedly moved to the mean of the points inside it, so it climbs toward a density peak. Points whose windows converge to the same peak belong to the same cluster.
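
The update can be written as a single formula. The notation here is an assumption for illustration and does not come from the page itself: x is the current window center, x_i are the data points, h is the bandwidth, and K is a kernel.

```latex
m(x) = \frac{\sum_{i=1}^{n} K\!\left(\frac{x_i - x}{h}\right) x_i}
            {\sum_{i=1}^{n} K\!\left(\frac{x_i - x}{h}\right)},
\qquad x \leftarrow m(x)

% Flat kernel: K(u) = 1 if \lVert u \rVert \le 1, otherwise 0.
% With this choice, m(x) is the plain average of the points within distance h of x.
```

The difference m(x) - x is the "mean shift" vector. Each step moves the window by that vector, and the iteration stops once it is close to zero.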

Mean Shift automatically determines the number of clusters. You only need to set the bandwidth (window size), which controls the granularity of the clustering.

Key Concept: Bandwidth

The bandwidth defines the radius of the window used to compute the local mean. Small bandwidth = many small clusters. Large bandwidth = fewer, larger clusters.
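
A quick way to see this effect is to fit the same data with several bandwidths and count the resulting clusters. The sketch below assumes scikit-learn is installed. The bandwidth values are arbitrary choices for illustration, not recommendations.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# Synthetic data: three Gaussian blobs (illustrative values).
X, _ = make_blobs(n_samples=300, centers=3, cluster_std=0.8, random_state=42)

cluster_counts = {}
for bandwidth in (0.5, 1.0, 2.0, 4.0, 8.0):
    # Fit Mean Shift with this window radius and count the distinct labels.
    labels = MeanShift(bandwidth=bandwidth).fit_predict(X)
    cluster_counts[bandwidth] = len(np.unique(labels))
    print(f"bandwidth={bandwidth:>4}: {cluster_counts[bandwidth]} clusters")
```

The smallest bandwidth should report the most clusters, and the count should shrink as the window grows.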

How It Works

Algorithm Steps
  1. Place a window (a circle in 2D, a ball in higher dimensions, with radius = bandwidth) around each data point
  2. Compute the mean of all points inside the window
  3. Shift the window center to that mean position
  4. Repeat until the center stops moving (convergence)
  5. Group points that converge to the same peak into one cluster
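
The steps above can be sketched in plain NumPy. This is a minimal illustration, not how scikit-learn implements the algorithm. The function name mean_shift, the flat (uniform) kernel, the tolerance, and the rule of merging peaks closer than half the bandwidth are all assumptions made for the example.

```python
import numpy as np

def mean_shift(X, bandwidth, max_iter=100, tol=1e-4):
    """Flat-kernel mean shift sketch; returns (centers, labels)."""
    X = np.asarray(X, dtype=float)
    modes = X.copy()  # Step 1: one window per data point
    for _ in range(max_iter):
        # Distance from every window center to every data point.
        dists = np.linalg.norm(modes[:, None, :] - X[None, :, :], axis=2)
        inside = (dists <= bandwidth).astype(float)
        # Step 2: mean of the points inside each window.
        new_modes = (inside @ X) / inside.sum(axis=1, keepdims=True)
        largest_shift = np.linalg.norm(new_modes - modes, axis=1).max()
        modes = new_modes              # Step 3: move the windows
        if largest_shift < tol:        # Step 4: stop when nothing moves
            break
    # Step 5: windows that ended at (nearly) the same peak share a label.
    # The bandwidth / 2 merge threshold is an arbitrary illustrative choice.
    centers = []
    labels = np.empty(len(X), dtype=int)
    for i, mode in enumerate(modes):
        for k, center in enumerate(centers):
            if np.linalg.norm(mode - center) < bandwidth / 2:
                labels[i] = k
                break
        else:
            centers.append(mode)
            labels[i] = len(centers) - 1
    return np.array(centers), labels

# Usage: three tight synthetic blobs (made-up data).
rng = np.random.default_rng(0)
true_centers = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]])
X = np.vstack([c + 0.5 * rng.standard_normal((50, 2)) for c in true_centers])
centers, labels = mean_shift(X, bandwidth=3.0)
print(len(centers), "clusters found")
```

The pairwise distance matrix makes the O(n^2) cost per iteration visible: every window is compared against every data point.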

Code: Mean Shift Example

    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.cluster import MeanShift
    from sklearn.datasets import make_blobs

    # Generate sample data
    X, _ = make_blobs(n_samples=300, centers=3, cluster_std=0.8, random_state=42)

    # Apply Mean Shift
    ms = MeanShift(bandwidth=2)  # bandwidth = window size
    labels = ms.fit_predict(X)
    centers = ms.cluster_centers_

    print("Cluster centers:")
    print(centers)

    # Plot results
    plt.scatter(X[:, 0], X[:, 1], c=labels, cmap="plasma", s=30)
    plt.scatter(centers[:, 0], centers[:, 1], c="black", marker="x", s=200)
    plt.title("Mean Shift Clustering")
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.show()

Bandwidth Selection

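scikit-learn provides an automatic bandwidth estimator, estimate_bandwidth. Its quantile parameter (between 0 and 1) sets the scale: a larger quantile gives a larger bandwidth and typically fewer clusters.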

    from sklearn.cluster import MeanShift, estimate_bandwidth

    # Automatically estimate bandwidth (X is the data from the example above)
    bandwidth = estimate_bandwidth(X, quantile=0.2)
    print(f"Estimated bandwidth: {bandwidth}")

    ms = MeanShift(bandwidth=bandwidth)
    labels = ms.fit_predict(X)
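
Because the estimate depends on quantile, it can help to look at a few values before settling on one. The sketch below assumes scikit-learn is installed. The quantile values and the results list are illustrative choices.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=3, cluster_std=0.8, random_state=42)

results = []  # (quantile, bandwidth, number of clusters)
for quantile in (0.05, 0.1, 0.2, 0.3, 0.5):
    bandwidth = estimate_bandwidth(X, quantile=quantile)
    ms = MeanShift(bandwidth=bandwidth).fit(X)
    n_clusters = len(ms.cluster_centers_)
    results.append((quantile, bandwidth, n_clusters))
    print(f"quantile={quantile:.2f}  bandwidth={bandwidth:.2f}  clusters={n_clusters}")
```

The estimated bandwidth grows with the quantile, so small quantiles tend to split the data finely and large ones tend to merge it.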

Mean Shift vs K-Means

Feature            | K-Means                    | Mean Shift
Number of clusters | Must specify K             | Determined automatically
Cluster shape      | Spherical only             | Arbitrary shape
Parameters         | K (number of clusters)     | Bandwidth (window size)
Speed              | Fast (O(nKt))              | Slow (O(n^2) per iteration)
Outlier handling   | Assigns to nearest cluster | Forms tiny clusters for outliers
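
The outlier row can be checked on a toy dataset. The sketch below assumes scikit-learn is installed. The blob positions, the outlier at (5, 5), and the bandwidth are made-up values chosen so that the outlier sits well away from every blob.

```python
import numpy as np
from sklearn.cluster import KMeans, MeanShift
from sklearn.datasets import make_blobs

# Three blobs at fixed positions plus one isolated point (illustrative data).
blob_centers = [[0, 0], [10, 0], [0, 10]]
X, _ = make_blobs(n_samples=300, centers=blob_centers, cluster_std=0.8, random_state=0)
X_out = np.vstack([X, [[5.0, 5.0]]])  # the outlier is the last row

km_labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X_out)
ms_labels = MeanShift(bandwidth=2).fit_predict(X_out)

# Size of the cluster that contains the outlier under each method.
km_outlier_size = int(np.sum(km_labels == km_labels[-1]))
ms_outlier_size = int(np.sum(ms_labels == ms_labels[-1]))
print("K-Means cluster size for the outlier:   ", km_outlier_size)
print("Mean Shift cluster size for the outlier:", ms_outlier_size)
```

K-Means has only three centers to offer, so the outlier joins one of the blobs. The Mean Shift window around it contains no other points, so it stays put and becomes a cluster of its own.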

When to Use Mean Shift

Good For                            | Not Ideal For
Unknown number of clusters          | Large datasets (slow, O(n^2))
Non-spherical cluster shapes        | High-dimensional data
Image segmentation, object tracking | When speed is critical
Small to medium datasets            | Very different cluster densities

Mean Shift is computationally expensive (O(n^2) per iteration). For large datasets, use K-Means or DBSCAN instead. The bandwidth parameter heavily influences results; use estimate_bandwidth() as a starting point.
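
If Mean Shift is still the preferred method on a larger dataset, scikit-learn offers two options that reduce the work: estimate_bandwidth can run on a subsample (n_samples), and MeanShift can start from a coarse grid of seeds (bin_seeding=True) instead of from every point. The sketch below assumes scikit-learn is installed. The dataset size and parameter values are illustrative.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=2000, centers=3, cluster_std=0.8, random_state=42)

# Estimate the bandwidth from a 500-point subsample rather than all points.
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

# bin_seeding=True seeds the search from occupied grid bins, not every point.
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
labels = ms.fit_predict(X)
n_clusters = len(ms.cluster_centers_)
print(f"bandwidth={bandwidth:.2f}, clusters={n_clusters}")
```

Binned seeding trades some precision in the seed positions for fewer windows to shift, so it is worth comparing the result against a full run on a sample.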
