K-Means Clustering
An unsupervised algorithm that partitions data into K distinct clusters by minimizing within-cluster variance.
What It Is
K-Means groups data points into K clusters based on their distance to cluster centers (centroids): each point belongs to the cluster whose centroid is nearest. The algorithm minimizes the within-cluster variance, i.e. the sum of squared distances between each point and its assigned centroid.
K-Means is one of the most widely used clustering algorithms: it is fast, scales well to large datasets, and is easy to interpret.
How It Works
Algorithm Steps
1. Choose K — Pick the number of clusters. Use the Elbow Method or Silhouette Score if unsure.
2. Initialize centroids — Randomly select K data points as starting centroids (k-means++ improves this).
3. Assign points — Compute the Euclidean distance from each point to every centroid. Assign each point to the nearest centroid.
4. Update centroids — Recompute each centroid as the mean of all points assigned to that cluster.
5. Repeat — Iterate steps 3-4 until the centroids stabilize or the maximum number of iterations is reached.
6. Output — K final clusters, with each point assigned to exactly one group.
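The steps above can be sketched as a minimal NumPy implementation. This is illustrative only (the function name `kmeans_simple` is made up for this sketch, and empty-cluster handling is omitted); in practice, use scikit-learn's `KMeans` as shown later in this article.

```python
import numpy as np

def kmeans_simple(X, k, max_iter=100, seed=0):
    """Minimal K-Means: random init, assign, update, repeat until stable."""
    rng = np.random.default_rng(seed)
    # Step 2: pick k distinct data points as starting centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(max_iter):
        # Step 3: Euclidean distance from every point to every centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 4: recompute each centroid as the mean of its cluster
        new_centroids = np.array([X[labels == j].mean(axis=0) for j in range(k)])
        # Step 5: stop once the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

# Tiny usage example: two obvious groups of three points each
X = np.array([[0, 0], [0, 1], [1, 0], [10, 10], [10, 11], [11, 10]], dtype=float)
labels, centers = kmeans_simple(X, k=2)
print(labels)
```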
Distance Formula
Euclidean Distance = sqrt( (x1 - x2)^2 + (y1 - y2)^2 + ... + (xn - yn)^2 )
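In NumPy, the formula above can be written out directly, or computed with the built-in `np.linalg.norm` (the two are equivalent):

```python
import numpy as np

p = np.array([1.0, 2.0, 3.0])
q = np.array([4.0, 6.0, 3.0])

# Explicit form of the formula above: sqrt of the sum of squared differences
d_manual = np.sqrt(np.sum((p - q) ** 2))
# Equivalent built-in
d_norm = np.linalg.norm(p - q)
print(d_manual)  # 5.0 (a 3-4-5 right triangle in the first two coordinates)
```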
Code: Generate Data and Cluster
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Generate synthetic data with 3 clusters (y holds the true labels; unused here)
X, y = make_blobs(n_samples=300, centers=3, cluster_std=1.05, random_state=42)

# Scatter plot of the dataset
plt.scatter(X[:, 0], X[:, 1], s=50, alpha=0.6)
plt.title("Generated Dataset for K-Means Clustering")
plt.show()
```
Code: Fit K-Means
```python
# Apply K-Means clustering with K=3 (n_init=10 runs the algorithm 10 times
# with different starting centroids and keeps the best result)
kmeans = KMeans(n_clusters=3, init='k-means++', n_init=10, max_iter=300, random_state=42)
y_kmeans = kmeans.fit_predict(X)

# Get cluster centers
centroids = kmeans.cluster_centers_
print(centroids)
```
Code: Visualize Clusters
```python
# Visualizing the clustered data
plt.scatter(X[y_kmeans == 0, 0], X[y_kmeans == 0, 1], s=50, c='red', label='Cluster 1')
plt.scatter(X[y_kmeans == 1, 0], X[y_kmeans == 1, 1], s=50, c='blue', label='Cluster 2')
plt.scatter(X[y_kmeans == 2, 0], X[y_kmeans == 2, 1], s=50, c='green', label='Cluster 3')

# Plot cluster centroids
plt.scatter(centroids[:, 0], centroids[:, 1], s=200, c='yellow', marker='X',
            edgecolors='black', label='Centroids')
plt.legend()
plt.title("K-Means Clustering Results")
plt.show()
```
Code: Predict New Points
```python
# New data points to classify
new_points = np.array([[2, 3], [-4, 7], [6, -2]])

# Predict the cluster for each new data point
predicted_clusters = kmeans.predict(new_points)

# Print results
for i, point in enumerate(new_points):
    print(f"Point {point} belongs to Cluster {predicted_clusters[i]}")

# Plot with new points
plt.scatter(X[y_kmeans == 0, 0], X[y_kmeans == 0, 1], s=50, c='red', label='Cluster 1')
plt.scatter(X[y_kmeans == 1, 0], X[y_kmeans == 1, 1], s=50, c='blue', label='Cluster 2')
plt.scatter(X[y_kmeans == 2, 0], X[y_kmeans == 2, 1], s=50, c='green', label='Cluster 3')
plt.scatter(centroids[:, 0], centroids[:, 1], s=200, c='yellow', marker='X',
            edgecolors='black', label='Centroids')
plt.scatter(new_points[:, 0], new_points[:, 1], s=150, c='purple', marker='D', label='New Points')
plt.legend()
plt.title("K-Means Clustering with Predictions")
plt.show()
```
When to Use K-Means
| Good For | Not Ideal For |
| --- | --- |
| Roughly spherical, evenly sized clusters | Non-spherical or irregularly shaped clusters |
| Large datasets (scales well) | Clusters with very different sizes or densities |
| When K is known or easily estimated | When the number of clusters is unknown |
| Customer segmentation, image compression | Data with many outliers (K-Means is sensitive to them) |
K-Means requires you to specify K upfront, and a poor choice of K produces misleading clusters. Use the Elbow Method (plot inertia against K and look for the bend) or the Silhouette Score to choose a sensible value.
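Both techniques are straightforward to run in scikit-learn. The sketch below reuses the same synthetic dataset from earlier and scans a range of candidate K values; the range (2-8) is an arbitrary choice for illustration:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Same synthetic dataset as earlier in the article
X, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.05, random_state=42)

inertias = []
silhouettes = {}
for k in range(2, 9):
    km = KMeans(n_clusters=k, init='k-means++', n_init=10, random_state=42).fit(X)
    inertias.append(km.inertia_)                 # plot these vs K for the Elbow Method
    silhouettes[k] = silhouette_score(X, km.labels_)  # higher is better

# The K with the highest silhouette score is a good candidate
best_k = max(silhouettes, key=silhouettes.get)
print(best_k)
```

For the Elbow Method, plot `inertias` against K and look for the point where the curve flattens; the silhouette score gives a single number per K, which is often easier to compare programmatically.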
Key Parameters
- n_clusters — Number of clusters (K)
- init — Centroid initialization method. Use 'k-means++' (the default) for smarter, well-spread starting positions
- n_init — Number of runs with different starting centroids; the run with the lowest inertia is kept
- max_iter — Maximum iterations per run before stopping (default 300)
- random_state — Seed for reproducibility
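A quick way to see the effect of the init parameter is to run a single initialization (n_init=1) with each method and compare the resulting inertia and iteration count; this is a sketch, and exact numbers will vary with the data and seed:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=3, cluster_std=1.05, random_state=42)

results = {}
for init in ('k-means++', 'random'):
    # n_init=1 so each method gets exactly one starting configuration
    km = KMeans(n_clusters=3, init=init, n_init=1, max_iter=300, random_state=0).fit(X)
    results[init] = km.inertia_
    # inertia_ = within-cluster sum of squares; n_iter_ = iterations to converge
    print(f"{init}: inertia={km.inertia_:.1f}, iterations={km.n_iter_}")
```

On harder datasets, 'k-means++' tends to start closer to a good solution, which is why it is the default.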
Tags: Unsupervised, Clustering, Centroid-based, sklearn