Naive k-means
This short post introduces the well-known naive k-means algorithm, a proof of its convergence, and an implementation in NumPy.
Within-cluster sum of squares (WCSS)
Consider partitioning points in Euclidean space such that “nearby” points are in the same subset. To measure the quality of such a partition, we need an objective function mapping each partition to a number. The within-cluster sum of squares (WCSS) is one such possible choice of objective.
Formally, given a finite subset $\mathcal{X}$ of $\mathbb{R}^{p}$, the WCSS of a partition $\Pi$ of these points is
\[\operatorname{WCSS}(\Pi)\equiv \sum_{\pi \in \Pi} \sum_{x \in \pi} \left\Vert x - \mathbb{E} X_\pi \right\Vert^2\]where $X_{\pi}\sim\operatorname{Uniform}(\pi)$, so that $\mathbb{E} X_\pi$ is the mean of the points in the cell $\pi$.
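For concreteness, the WCSS might be computed in NumPy as in the sketch below, where a partition of the rows of `data` is represented by an integer cluster label per point (this labelled representation and the helper name `wcss` are choices of the sketch, not something from the definition above):

import numpy as np
from numpy.typing import NDArray

def wcss(data: NDArray, labels: NDArray) -> float:
    """Within-cluster sum of squares of a labelled partition.

    `data` has shape (n_samples, n_features); `labels[i]` is the cluster
    index of `data[i]`.
    """
    total = 0.0
    for label in np.unique(labels):
        cluster = data[labels == label]    # the points x in the cell pi
        centroid = cluster.mean(axis=0)    # E[X_pi], the mean of the cell
        total += ((cluster - centroid) ** 2).sum()  # adds ||x - E[X_pi]||^2 over the cell
    return float(total)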
Naive k-means
Naive k-means is an algorithm that attempts to minimize the WCSS over all partitions into $k$ clusters. While not guaranteed to minimize the WCSS globally, it is guaranteed to converge to a locally optimal partition after finitely many iterations.
The algorithm is iterative: given a partition $\Pi$, the next partition $\Pi^{\prime}$ is produced by reassigning each observation to its nearest mean. Formally,
\[\Pi^{\prime}=\left\{ \left\{ x\in\mathcal{X}\colon\tau(x)=\pi\right\} \colon\pi\in\Pi\right\}\]where
\[\tau(x)\in\operatorname{argmin}_{\pi\in\Pi}\left\Vert x-\mathbb{E}X_{\pi}\right\Vert ^{2}.\]
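In NumPy, a single iteration (assign each point to its nearest mean, then recompute the means) might look like the following sketch. As above, representing the partition by one label per point is a convention of the sketch, and empty cells simply keep their old mean here; the full implementation below re-initializes them instead.

import numpy as np
from numpy.typing import NDArray

def k_means_iteration(data: NDArray, means: NDArray) -> tuple[NDArray, NDArray]:
    """One naive k-means step on `data` (n_samples, n_features) given `means` (k, n_features)."""
    # tau(x): index of the nearest mean for every point, via pairwise squared distances
    squared_dists = ((data[:, np.newaxis, :] - means) ** 2).sum(axis=-1)
    labels = squared_dists.argmin(axis=-1)
    # Recompute each cell's mean; keep the previous mean if the cell is empty
    new_means = np.stack(
        [
            data[labels == i].mean(axis=0) if np.any(labels == i) else means[i]
            for i in range(len(means))
        ]
    )
    return labels, new_means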
Proposition. The k-means algorithm converges (i.e., $\Pi^\prime = \Pi$) after a finite number of steps.
Proof. To establish convergence, it is sufficient to prove that no iteration increases the objective, $\operatorname{WCSS}(\Pi^{\prime})\leq\operatorname{WCSS}(\Pi)$, with strict inequality whenever $\Pi^{\prime}\neq\Pi$. (Here ties in the assignment step are broken in favor of a point's current cell, so a point is reassigned only when another mean is strictly closer; this makes the first inequality below strict whenever $\Pi^{\prime}\neq\Pi$.) Since there are only finitely many partitions of $\mathcal{X}$, the WCSS can strictly decrease only finitely many times, and so the algorithm must reach a partition with $\Pi^{\prime}=\Pi$ after finitely many steps. Indeed,
\[\begin{align*} \operatorname{WCSS}(\Pi) & =\sum_{\pi\in\Pi}\sum_{x\in\pi}\left\Vert x-\mathbb{E}X_{\pi}\right\Vert ^{2}\\ & \geq\sum_{\pi\in\Pi}\sum_{x\in\pi}\left\Vert x-\mathbb{E}X_{\tau(x)}\right\Vert ^{2}\\ & =\sum_{\pi\in\Pi^{\prime}}\sum_{x\in\pi}\left\Vert x-\mathbb{E}X_{\tau(x)}\right\Vert ^{2}\\ & \geq\sum_{\pi\in\Pi^{\prime}}\min_{c\in\mathbb{R}^{p}}\sum_{x\in\pi}\left\Vert x-c\right\Vert ^{2}\\ & =\sum_{\pi\in\Pi^{\prime}}\sum_{x\in\pi}\left\Vert x-\mathbb{E}X_{\pi}\right\Vert ^{2}\\ & =\operatorname{WCSS}(\Pi^{\prime}). \blacksquare \end{align*}\]
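The final equality in the chain uses the well-known fact that the mean of a finite set of points minimizes the sum of squared distances to those points: expanding the square, for any nonempty cell $\pi$ and any $c\in\mathbb{R}^{p}$,
\[\sum_{x\in\pi}\left\Vert x-c\right\Vert ^{2}=\sum_{x\in\pi}\left\Vert x-\mathbb{E}X_{\pi}\right\Vert ^{2}+\left|\pi\right|\left\Vert \mathbb{E}X_{\pi}-c\right\Vert ^{2},\]since the cross term $2\sum_{x\in\pi}\left\langle x-\mathbb{E}X_{\pi},\mathbb{E}X_{\pi}-c\right\rangle$ vanishes; the right-hand side is minimized by taking $c=\mathbb{E}X_{\pi}$.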
Implementation
Below is an implementation of the algorithm.
The initial partition is constructed by sampling $k$ points $x_1, \ldots, x_k$ (without replacement) from $\mathcal{X}$. Each point in $\mathcal{X}$ is then assigned to the closest of $x_1, \ldots, x_k$. This choice of initialization, while easy to code, can produce poor results (e.g., consider the case in which all $k$ points are picked close to one another). The reader interested in improving initialization is referred to k-means++.
from typing import NamedTuple

import numpy as np
from numpy.typing import NDArray


class KMeansResult(NamedTuple):
    """Result of running k-means.

    Attributes
    ----------
    centroids
        Array of shape (n_clusters, n_features)
    converged
        Whether the algorithm converged (False if it was stopped after `max_iters` iterations)
    n_iters
        Number of iterations performed
    """

    centroids: NDArray
    converged: bool
    n_iters: int


def k_means(
    data: NDArray,
    n_clusters: int,
    generator: np.random.Generator | None = None,
    max_iters: int = 1000,
    tolerance: float = 1e-3,
) -> KMeansResult:
    """Runs k-means.

    Parameters
    ----------
    data
        Array of shape (n_samples, n_features)
    n_clusters
        Number of clusters
    generator
        Random generator (if unspecified, `np.random.default_rng()` is used)
    max_iters
        Maximum number of iterations before giving up
    tolerance
        Convergence threshold: stop once the mean squared change in the
        centroids between consecutive iterations is at most this value

    Returns
    -------
    KMeansResult object
    """
    n, _ = data.shape
    k = n_clusters
    if n < k:
        msg = f"The number of points ({n}) should be at least as large as the number of centroids ({k})"
        raise RuntimeError(msg)
    if generator is None:
        generator = np.random.default_rng()

    def init_centroids(n_points: int) -> NDArray:
        # TODO: Improve by using k-means++ initialization
        return data[generator.choice(n, size=(n_points,), replace=False)]

    centroids = init_centroids(k)  # (k, p)
    prev_centroids = np.full_like(centroids, np.nan)
    n_iters = 0
    converged = False
    while n_iters < max_iters:
        if converged := ((centroids - prev_centroids) ** 2).mean() <= tolerance:
            break
        # For each point, find the closest centroid
        squared_dists = ((data[:, np.newaxis, :] - centroids) ** 2).sum(axis=-1)
        closest = np.argmin(squared_dists, axis=-1)
        # Update centroids (an empty cluster yields a row of NaNs, handled below)
        prev_centroids = centroids
        centroids = np.stack([data[closest == i].mean(axis=0) for i in range(k)])
        # If a centroid has no points, re-initialize it
        mask = np.any(np.isnan(centroids), axis=-1)
        centroids[mask] = init_centroids(mask.sum())
        n_iters += 1
    return KMeansResult(centroids=centroids, converged=converged, n_iters=n_iters)
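The driver code that produced the output line below is not shown; a minimal usage sketch along the following lines would exercise the function and print something similar (the synthetic Gaussian blobs, the seed, and the print format are assumptions of the sketch, which also presumes the definitions above are in scope):

rng = np.random.default_rng(seed=0)
# Three well-separated synthetic Gaussian blobs in the plane (assumed test data)
blobs = np.concatenate(
    [rng.normal(loc=center, scale=0.5, size=(100, 2)) for center in ((0.0, 0.0), (5.0, 5.0), (0.0, 5.0))]
)
result = k_means(blobs, n_clusters=3, generator=rng)
status = "Converged" if result.converged else "Did not converge"
print(f"{status} in {result.n_iters} iterations")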
Converged in 6 iterations