The LogSumExp trick


The softmax function $\sigma$ transforms a vector in $\mathbb{R}^n$ into a probability vector in a monotonicity-preserving way: if $x_i \leq x_j$, then $\sigma(x)_i \leq \sigma(x)_j$.

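The softmax is often parametrized by a “temperature” $T > 0$ to yield $\sigma_T(x) \equiv \sigma(x / T)$, which approaches the uniform distribution as $T \to \infty$ and concentrates all of its mass on the maximal entries of $x$ as $T \to 0^+$.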

Algebraically, the softmax is defined as

\[\sigma(x)_i \equiv \frac{\exp(x_i)}{\sum_j \exp(x_j)}.\]
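For inputs of moderate size, the definition can be evaluated directly. The sketch below is a literal transcription of $\sigma_T(x) = \sigma(x / T)$. The function name `naive_softmax` and the sample inputs are illustrative choices. It shows that the output preserves the order of the inputs and that the temperature moves the output between a nearly one-hot vector (small $T$) and a nearly uniform one (large $T$).

```python
import numpy as np

def naive_softmax(x: np.ndarray, T: float = 1.0) -> np.ndarray:
    """Literal transcription of sigma_T(x) = sigma(x / T); not overflow-safe."""
    z = np.exp(x / T)
    return z / z.sum()

x = np.array([1.0, 2.0, 3.0])
for T in (0.1, 1.0, 10.0):
    p = naive_softmax(x, T)
    # Entries sum to one and increase with x for every temperature.
    print(T, p.round(3))
```

With these inputs, the $T = 1$ output is roughly `[0.090, 0.245, 0.665]`. At $T = 0.1$, nearly all the mass sits on the last entry. At $T = 10$, the three entries are close to $1/3$.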

Mathematically, the softmax is well behaved: it is continuous on $\mathbb{R}^n$, and every component of $\sigma(x)$ lies in $(0, 1)$. In floating-point arithmetic, however, evaluating it naively can overflow:

```python
import numpy as np

x = np.array([768, 1024.])
exp_x = np.exp(x)    # exp(768) and exp(1024) both overflow to inf
exp_x / exp_x.sum()  # inf / inf
# array([nan, nan])
```
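The failure is easy to pin down. `np.exp` returns `inf` once its argument exceeds the logarithm of the largest finite double, which is about 709.78. Both entries of `exp_x` above therefore become `inf`, and `inf / inf` is `nan`. Very negative inputs fail in the opposite way: every exponential underflows to `0.0`, and the ratio becomes `0 / 0`. The following short check covers both cases. The variable names are illustrative.

```python
import numpy as np

# exp(t) overflows float64 for any t above log(largest finite double).
MAX_EXP_ARG = float(np.log(np.finfo(np.float64).max))  # about 709.78

# errstate silences the RuntimeWarnings so the results can be inspected.
with np.errstate(over="ignore", under="ignore", invalid="ignore"):
    big = np.exp(np.array([768.0, 1024.0]))       # [inf, inf]
    small = np.exp(np.array([-1000.0, -1000.0]))  # [0.0, 0.0]
    ratio_big = big / big.sum()        # inf / inf -> [nan, nan]
    ratio_small = small / small.sum()  # 0 / 0 -> [nan, nan]

print(MAX_EXP_ARG, ratio_big, ratio_small)
```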

The LogSumExp trick is a clever way of reformulating this computation so that none of the exponentials it evaluates can overflow.

The LogSumExp trick

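First, let $\bar{x} = \max_i x_i$ and note that

\[\sum_j \exp(x_j) = \exp(\bar{x}) \sum_j \exp(x_j - \bar{x}).\]

Taking logarithms,

\[\log \sum_j \exp(x_j) = \bar{x} + \log \sum_j \exp(x_j - \bar{x}),\]

so that

\[\log \sigma(x)_i = x_i - \log \sum_j \exp(x_j) = (x_i - \bar{x}) - \log \sum_j \exp(x_j - \bar{x}).\]

In particular, note that $x_j - \bar{x}$ is nonpositive by construction, so $\exp(x_j - \bar{x}) \leq 1$, and none of these exponentials can overflow. Moreover, the term with $x_j = \bar{x}$ equals one, so the sum is at least one and its logarithm is finite. Exponentiating the right-hand side recovers $\sigma(x)_i$.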
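The identity $\log \sum_j \exp(x_j) = \bar{x} + \log \sum_j \exp(x_j - \bar{x})$ is the LogSumExp trick in its own right. It is useful wherever log-probabilities need to be normalized. Below is a minimal NumPy sketch. The function name `logsumexp` and its `axis` argument are illustrative choices, not a reference to any particular library's implementation.

```python
import numpy as np

def logsumexp(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Compute log(sum(exp(x))) along `axis` using the max-shift identity."""
    x_max = x.max(axis=axis, keepdims=True)
    # Every shifted exponent is <= 0 and at least one equals 0, so the sum is >= 1.
    shifted_sum = np.exp(x - x_max).sum(axis=axis, keepdims=True)
    return np.squeeze(x_max + np.log(shifted_sum), axis=axis)

# Moderate inputs: agrees with the naive formula.
x = np.array([1.0, 2.0, 3.0])
print(logsumexp(x), np.log(np.exp(x).sum()))

# Large inputs: the naive formula overflows, the shifted one does not.
print(logsumexp(np.array([768.0, 1024.0])))
```

For production use, SciPy provides `scipy.special.logsumexp`, which also accepts per-term scale factors through its `b` argument.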

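Putting this together in NumPy:

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Shift by the maximum along the last axis so every exponent is <= 0.
    x_max = x.max(axis=-1, keepdims=True)
    delta = x - x_max
    # log-sum-exp of the shifted values; the sum is >= 1, so this is finite.
    lse = np.log(np.exp(delta).sum(axis=-1, keepdims=True))
    # exp(log sigma(x)) = sigma(x)
    return np.exp(delta - lse)

x = np.array([768, 1024.])
softmax(x)
# array([6.61626106e-112, 1.00000000e+000])
```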
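The derivation produces $\log \sigma(x)_i = (x_i - \bar{x}) - \log \sum_j \exp(x_j - \bar{x})$ before the final exponentiation. When log-probabilities are what is ultimately needed, as in a cross-entropy loss, it is often better to stop at that point. Probabilities that underflow to exactly `0.0` after exponentiation remain finite in log space. Below is a sketch under that assumption. The name `log_softmax` is illustrative, and SciPy ships its own version as `scipy.special.log_softmax`.

```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    """log(softmax(x)) along the last axis, without exponentiating the result."""
    delta = x - x.max(axis=-1, keepdims=True)                 # all entries <= 0
    lse = np.log(np.exp(delta).sum(axis=-1, keepdims=True))  # finite and >= 0
    return delta - lse

x = np.array([0.0, 1000.0])
print(log_softmax(x))          # the first entry stays finite in log space
print(np.exp(log_softmax(x)))  # exponentiating underflows it to exactly 0.0

# Works row by row on a batch of inputs.
batch = np.array([[768.0, 1024.0], [1.0, 1.0]])
print(np.exp(log_softmax(batch)))
```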