The LogSumExp trick
Motivation
The softmax function $\sigma$ is used to transform a vector in $\mathbb{R}^n$ to a probability vector in a monotonicity-preserving way. Specifically, if $x_i \leq x_j$, then $\sigma(x)_i \leq \sigma(x)_j$.
The softmax is typically parametrized by a “temperature” parameter $T$ to yield $\sigma_T(x) \equiv \sigma(x / T)$ which
- shifts more probability mass to the largest component of $x$ as the temperature decays to zero and
- distributes the mass more evenly among the components of $x$ as the temperature grows.
More details regarding the temperature can be found in a previous blog post.
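To make these two limiting behaviours concrete, here is a minimal sketch (the helper name softmax_T and the example vector are purely illustrative):

import numpy as np

def softmax_T(x: np.ndarray, T: float) -> np.ndarray:
    # sigma_T(x) = sigma(x / T), computed naively; fine here because the inputs are small.
    e = np.exp(x / T)
    return e / e.sum()

x = np.array([1., 2., 4.])
softmax_T(x, T=0.5)   # roughly [0.0024, 0.0179, 0.9796]: mass concentrates on the largest entry
softmax_T(x, T=10.)   # roughly [0.2894, 0.3199, 0.3907]: mass is spread more evenly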
Algebraically, the softmax is defined as
\[\sigma(x)_i \equiv \frac{\exp(x_i)}{\sum_j \exp(x_j)}.\]
This quantity is well defined and finite for every $x \in \mathbb{R}^n$, since the denominator is a sum of strictly positive terms. In floating point arithmetic, however, computing it naively can overflow:
x = np.array([768, 1024.])
exp_x = np.exp(x)
exp_x / exp_x.sum()
/tmp/ipykernel_117792/4003806838.py:1: RuntimeWarning: overflow encountered in exp
exp_x = np.exp(x)
/tmp/ipykernel_117792/4003806838.py:2: RuntimeWarning: invalid value encountered in divide
exp_x / exp_x.sum()
array([nan, nan])
The LogSumExp trick is a clever way of reformulating this computation so that it is robust to floating point error.
The LogSumExp trick
First, let $\bar{x} = \max_i x_i$ and note that
\[\sigma(x)_{i}=\frac{\exp(x_{i}-\bar{x})}{\sum_{j}\exp(x_{j}-\bar{x})}.\]
Taking logarithms,
\[\log(\sigma(x)_{i})=x_{i}-\bar{x}-\log\biggl(\sum_{j}\exp(x_{j}-\bar{x})\biggr).\]
Exponentiating,
\[\sigma(x)_{i}=\exp\biggl(x_{i}-\bar{x}-\log\biggl(\sum_{j}\exp(x_{j}-\bar{x})\biggr)\biggr).\]
In particular, note that $x_j - \bar{x}$ is, by construction, nonpositive, so each term in the sum is at most one and the sum itself lies between one and $n$. Neither the exponentials nor the logarithm can overflow.
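Translating this into NumPy: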
def softmax(x: np.ndarray) -> np.ndarray:
    # Shift by the (row-wise) maximum so that every exponent is nonpositive.
    x_max = x.max(axis=-1, keepdims=True)
    delta = x - x_max
    # Log-sum-exp of the shifted values; the sum lies in [1, n], so it cannot overflow.
    lse = np.log(np.exp(delta).sum(axis=-1, keepdims=True))
    return np.exp(delta - lse)
x = np.array([768, 1024.])
softmax(x)
array([6.61626106e-112, 1.00000000e+000])
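As a quick sanity check (assuming SciPy is available), the result can be compared against scipy.special.softmax, which also computes the softmax in a numerically stable way:

from scipy.special import softmax as scipy_softmax

print(np.allclose(softmax(x), scipy_softmax(x)))  # expected: True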