Naive k-means
This short post introduces the well-known naive k-means algorithm, a proof of its convergence, and an implementation in NumPy.
Within-cluster sum of squares (WCSS)
Consider partitioning points in Euclidean space such that “nearby” points end up in the same subset. To measure the quality of such a partition, we need an objective function mapping each partition to a number. The within-cluster sum of squares (WCSS) is one possible choice of objective.
Formally, given a finite subset $X$ of a finite-dimensional Euclidean space $\mathbb{R}^p$ and a partition $\mathcal{C} = \{C_1, \ldots, C_k\}$ of $X$, the WCSS of the partition is
$$\mathrm{WCSS}(\mathcal{C}) = \sum_{i=1}^{k} \sum_{x \in C_i} \lVert x - \mu_i \rVert^2,$$
where $\mu_i = \frac{1}{\lvert C_i \rvert} \sum_{x \in C_i} x$ denotes the centroid (mean) of cluster $C_i$.
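As a small sketch of this definition in NumPy (the helper name `wcss` and its array arguments are illustrative, not part of the implementation later in the post):

import numpy as np

def wcss(X, labels, centroids):
    """Within-cluster sum of squares for a given assignment.

    X is (n_samples, n_features), labels is (n_samples,) with values in
    {0, ..., k - 1}, and centroids is (k, n_features).
    """
    # Squared distance from each point to the centroid of its own cluster
    return float(((X - centroids[labels]) ** 2).sum())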
Naive k-means
Naive k-means is an algorithm that attempts to optimize WCSS over all partitions of $X$ of size $k$. The algorithm is iterative: given a partition $\mathcal{C}^{(t)} = \{C_1^{(t)}, \ldots, C_k^{(t)}\}$, the next partition assigns each point to its nearest current centroid,
$$C_i^{(t+1)} = \{\, x \in X : \lVert x - \mu_i^{(t)} \rVert \le \lVert x - \mu_j^{(t)} \rVert \text{ for all } j \,\},$$
where $\mu_i^{(t)} = \frac{1}{\lvert C_i^{(t)} \rvert} \sum_{x \in C_i^{(t)}} x$ is the centroid of $C_i^{(t)}$ and ties are broken arbitrarily. The two steps (recomputing centroids and reassigning points) are repeated until the centroids stop changing.
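For concreteness, a single iteration can be sketched in NumPy as follows (a minimal illustration; the full implementation further down handles initialization, empty clusters, and stopping):

import numpy as np

def kmeans_step(data, centroids):
    """One assignment + update step; data is (n, p), centroids is (k, p)."""
    # Assignment step: index of the nearest centroid for every point
    sq_dists = ((data[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)  # (n, k)
    closest = sq_dists.argmin(axis=-1)  # (n,)
    # Update step: move each centroid to the mean of its assigned points
    # (an empty cluster keeps its previous centroid in this sketch)
    return np.stack([
        data[closest == i].mean(axis=0) if np.any(closest == i) else centroids[i]
        for i in range(len(centroids))
    ])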
Proposition. The k-means algorithm converges (i.e., the sequence $\mathrm{WCSS}(\mathcal{C}^{(0)}), \mathrm{WCSS}(\mathcal{C}^{(1)}), \ldots$ converges).
Proof.
To establish convergence of k-means, it is sufficient to prove that each iteration produces an improvement:
$$\mathrm{WCSS}(\mathcal{C}^{(t+1)}) \le \mathrm{WCSS}(\mathcal{C}^{(t)}).$$
The sequence of WCSS values is then non-increasing and bounded below by zero, and therefore converges. The improvement itself decomposes across the two steps of an iteration, as shown in the display below: reassigning a point to its nearest centroid cannot increase its squared distance, and replacing a centroid by the mean of its new cluster cannot increase that cluster's sum of squared distances, since the mean minimizes $\mu \mapsto \sum_{x \in C} \lVert x - \mu \rVert^2$.
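Written out as a chain of inequalities (the two-argument notation $\mathrm{WCSS}(\mathcal{C}; \mu)$, the sum of squared distances from the points of each cluster in $\mathcal{C}$ to the corresponding reference point in $\mu$, is introduced here only for this display):
$$\mathrm{WCSS}(\mathcal{C}^{(t)}) = \mathrm{WCSS}\bigl(\mathcal{C}^{(t)}; \mu^{(t)}\bigr) \ge \mathrm{WCSS}\bigl(\mathcal{C}^{(t+1)}; \mu^{(t)}\bigr) \ge \mathrm{WCSS}\bigl(\mathcal{C}^{(t+1)}; \mu^{(t+1)}\bigr) = \mathrm{WCSS}(\mathcal{C}^{(t+1)}),$$
where the first inequality is the assignment step and the second is the centroid update.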
Implementation
Below is an implementation of the algorithm.
The initial partition is induced by sampling $k$ points uniformly at random (without replacement) from the data to serve as the initial centroids.
from typing import NamedTuple
import numpy as np
from numpy.typing import NDArray
class KMeansResult(NamedTuple):
"""Result of running k-means.
Attributes
----------
centroids
Array of shape (n_clusters, n_features)
converged
        Whether the algorithm converged (`False` if it stopped after `max_iters` iterations)
n_iters
Number of iterations
"""
centroids: NDArray
converged: bool
n_iters: int
def k_means(
data: NDArray,
n_clusters: int,
generator: np.random.Generator | None = None,
max_iters: int = 1000,
tolerance: float = 1e-3,
) -> KMeansResult:
"""Runs k-means.
Parameters
----------
data
Array of shape (n_samples, n_features)
n_clusters
Number of clusters
generator
Random generator (if unspecified, `np.random.default_rng()` is used)
max_iters
Maximum number of iterations before giving up
tolerance
Convergence tolerance threshold
Returns
-------
KMeansResult object
"""
n, _ = data.shape
k = n_clusters
if n < k:
msg = f"The number of points ({n}) should be at least as large as the number of centroids ({k})"
raise RuntimeError(msg)
if generator is None:
generator = np.random.default_rng()
def init_centroids(n_points: int) -> NDArray:
# TODO: Improve by using k-means++ initialization
return data[generator.choice(n, size=(n_points,), replace=False)]
centroids = init_centroids(k) # (k, p)
prev_centroids = np.full_like(centroids, np.nan)
n_iters = 0
converged = False
while n_iters < max_iters:
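        # Stop once the mean squared displacement of the centroids falls below the tolerance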
if converged := ((centroids - prev_centroids) ** 2).mean() <= tolerance:
break
# For each point, find the closest centroid
squared_dists = ((data[:, np.newaxis, :] - centroids) ** 2).sum(axis=-1)
closest = np.argmin(squared_dists, axis=-1)
# Update centroids
prev_centroids = centroids
centroids = np.stack([data[closest == i].mean(axis=0) for i in range(k)])
# If a centroid has no points, re-initialize it
mask = np.any(np.isnan(centroids), axis=-1)
centroids[mask] = init_centroids(mask.sum())
n_iters += 1
    return KMeansResult(centroids=centroids, converged=converged, n_iters=n_iters)
Converged in 6 iterations
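For reference, here is a hypothetical usage sketch on synthetic data (the two Gaussian blobs and the seed are assumptions chosen for illustration; they are not the data behind the run reported above):

import numpy as np

# Hypothetical example data: two well-separated Gaussian blobs in the plane
rng = np.random.default_rng(0)
blob_a = rng.normal(loc=(-2.0, 0.0), scale=0.5, size=(100, 2))
blob_b = rng.normal(loc=(2.0, 0.0), scale=0.5, size=(100, 2))
points = np.concatenate([blob_a, blob_b])

result = k_means(points, n_clusters=2, generator=rng)
status = "Converged" if result.converged else "Stopped"
print(f"{status} in {result.n_iters} iterations")
print(result.centroids)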