Parsiad Azimzadeh

The LogSumExp trick

Motivation

The softmax function $\sigma$ is used to transform a vector in $\mathbb{R}^n$ to a probability vector in a monotonicity-preserving way. Specifically, if $x_i \leq x_j$, then $\sigma(x)_i \leq \sigma(x)_j$.

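The softmax is typically parametrized by a "temperature" parameter $T > 0$ to yield $\sigma_T(x) \equiv \sigma(x / T)$, which concentrates more probability mass on the largest components of $x$ as $T \to 0$ and approaches the uniform distribution as $T \to \infty$.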

More details regarding the temperature can be found in a previous blog post.
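
As a rough illustration of the temperature's effect, the sketch below evaluates $\sigma(x / T)$ for a few values of $T$. The helper name `softmax_with_temperature` is made up for this example, and it uses the naive formula, which is only safe here because the inputs are small.

```python
import numpy as np


def softmax_with_temperature(x: np.ndarray, temperature: float) -> np.ndarray:
    """Hypothetical helper: naive softmax of x / temperature (small inputs only)."""
    scaled = x / temperature
    exp_scaled = np.exp(scaled)
    return exp_scaled / exp_scaled.sum()


x = np.array([1.0, 2.0, 3.0])
for temperature in (0.1, 1.0, 10.0):
    # Each row is a probability vector whose ordering matches that of x.
    print(temperature, softmax_with_temperature(x, temperature))
```

For this input, the smallest temperature should put almost all of the mass on the last component, while the largest should give something close to uniform.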

Algebraically, the softmax is defined as

$$ \sigma(x)_i \equiv \frac{\exp(x_i)}{\sum_j \exp(x_j)}. $$

Mathematically, every component of this quantity lies strictly between 0 and 1 for any $x \in \mathbb{R}^n$. However, in floating point arithmetic, computing it naively can overflow:

import numpy as np

x = np.array([768, 1024.0])
exp_x = np.exp(x)
exp_x / exp_x.sum()
RuntimeWarning: overflow encountered in exp
  exp_x = np.exp(x)
RuntimeWarning: invalid value encountered in divide
  exp_x / exp_x.sum()

array([nan, nan])
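
To see where the `nan` values come from, it may help to look at the limits of double precision. The sketch below assumes the inputs are `float64` (NumPy's default for the array above) and shows that `np.exp` overflows to `inf` once its argument exceeds roughly 709.78, after which `inf / inf` evaluates to `nan`.

```python
import numpy as np

# Largest finite float64 value (about 1.8e308).
float_max = np.finfo(np.float64).max
# Arguments above log(float_max) overflow np.exp to inf.
overflow_threshold = np.log(float_max)

# Silence the warnings that the overflow and the inf / inf division would raise.
with np.errstate(over="ignore", invalid="ignore"):
    below = np.exp(709.0)  # still finite
    above = np.exp(710.0)  # overflows to inf
    ratio = above / above  # inf / inf is nan

print(overflow_threshold, np.isfinite(below), np.isinf(above), np.isnan(ratio))
```

Both entries of the example input (768 and 1024) are above this threshold, which is why both outputs are `nan`.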

The LogSumExp trick is a clever way of reformulating this computation so that it is robust to floating point overflow.

The LogSumExp trick

First, let $\bar{x} = \max_i x_i$. Multiplying the numerator and denominator of the softmax by $\exp(-\bar{x})$ gives

$$ \sigma(x)_{i}=\frac{\exp(x_{i}-\bar{x})}{\sum_{j}\exp(x_{j}-\bar{x})}. $$

Taking logarithms,

$$ \log(\sigma(x)_{i})=x_{i}-\bar{x}-\log\biggl(\sum_{j}\exp(x_{j}-\bar{x})\biggr). $$

Exponentiating,

$$ \sigma(x)_{i}=\exp\biggl(x_{i}-\bar{x}-\log\biggl(\sum_{j}\exp(x_{j}-\bar{x})\biggr)\biggr). $$

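In particular, note that $x_j - \bar{x}$ is, by construction, nonpositive and hence has a value of at most one when exponentiated, so the exponentials cannot overflow. Moreover, the term with $x_j = \bar{x}$ contributes exactly one to the sum, so the argument of the logarithm is at least one.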

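def softmax(x: np.ndarray) -> np.ndarray:
    # Shift by the maximum so that every exponent is nonpositive.
    x_max = x.max(axis=-1, keepdims=True)
    delta = x - x_max
    # Log of the sum of the shifted exponentials (the LogSumExp term).
    lse = np.log(np.exp(delta).sum(axis=-1, keepdims=True))
    return np.exp(delta - lse)

x = np.array([768, 1024.0])
softmax(x)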
array([6.61626106e-112, 1.00000000e+000])
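
The logarithmic form derived above can also be useful on its own: when the gap between entries is large enough, the softmax itself underflows to zero, but its logarithm is still representable. The following is a minimal sketch of that idea. The name `log_softmax` is chosen for this example, and the function simply stops before the final exponentiation.

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: log of the softmax, computed without leaving log space."""
    # Shift by the maximum so that every exponent is nonpositive.
    x_max = x.max(axis=-1, keepdims=True)
    delta = x - x_max
    # LogSumExp of the shifted values; the sum is at least one, so the log is safe.
    lse = np.log(np.exp(delta).sum(axis=-1, keepdims=True))
    return delta - lse


x = np.array([0.0, 1000.0])
log_probs = log_softmax(x)
print(log_probs)
```

Here the first log-probability is about -1000, whereas exponentiating it gives exactly zero in double precision, so taking the logarithm of the softmax output afterwards would lose that information.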