gumbel max trick

1. goal

We have a categorical random variable taking values in \(\{1,\dots,k\}\), with associated (possibly unnormalized) log-probabilities \(\{x_1,\dots,x_k\}\), and we want to sample from it.

2. usual way

Usually, we would exponentiate and then normalize: \[ \pi_i = \frac{\exp(x_i)}{\sum_{j=1}^{k} \exp(x_j)} \] and then sample \(y\) with \(P(y = i) = \pi_i\). (This normalization is exactly the softmax of \(x\).)
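A minimal NumPy sketch of this usual route, for comparison with the trick below. The helper names `softmax` and `sample_softmax` and the example values are illustrative assumptions, and indices are 0-based rather than \(1,\dots,k\):

```python
import numpy as np

def softmax(x):
    """pi_i = exp(x_i) / sum_j exp(x_j), computed in a numerically stable way."""
    x = np.asarray(x, dtype=float)
    # Subtracting max(x) cancels in the ratio but keeps exp() from overflowing.
    w = np.exp(x - x.max())
    return w / w.sum()

def sample_softmax(x, rng, n=1):
    """Draw n category indices (0-based) with P(y = i) = softmax(x)[i]."""
    pi = softmax(x)
    return rng.choice(pi.size, size=n, p=pi)

rng = np.random.default_rng(0)
x = np.log([0.2, 0.5, 0.3])   # example log-probabilities, chosen for illustration
pi = softmax(x)               # approximately [0.2, 0.5, 0.3]
y = sample_softmax(x, rng, n=10)
```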

3. trick

Instead, take: \[ y = \arg\max_{i\in\{1,\dots,k\}} \left( x_i + z_i \right) \] where \(z_1,\dots,z_k\) are drawn i.i.d. from the standard Gumbel distribution, \(z_i \sim \text{Gumbel}(0,1)\).
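A sketch of the trick in NumPy, assuming `Generator.gumbel` as the noise source (the function name `sample_gumbel_max` is hypothetical, and indices are 0-based). Note that no exponentiation or normalization appears anywhere:

```python
import numpy as np

def sample_gumbel_max(x, rng, n=1):
    """Draw n category indices (0-based) via y = argmax_i (x_i + z_i).

    z_i ~ Gumbel(0, 1) i.i.d.; no exp() or normalization is needed.
    """
    x = np.asarray(x, dtype=float)
    # One row of k independent standard Gumbel draws per sample.
    # (Equivalently z = -log(-log(u)) with u ~ Uniform(0, 1).)
    z = rng.gumbel(loc=0.0, scale=1.0, size=(n, x.size))
    # Perturb the log-probabilities and take the index of the largest entry.
    return np.argmax(x + z, axis=1)

rng = np.random.default_rng(0)
x = np.log([0.2, 0.5, 0.3])   # example log-probabilities, chosen for illustration
y = sample_gumbel_max(x, rng, n=10)
```

Because adding the same constant to every \(x_i\) does not change the argmax, \(x\) only needs to be known up to an additive constant, so unnormalized log-probabilities (logits) can be passed in directly.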

It turns out that sampling \(y\) this way gives exactly the same distribution as sampling \(y\) according to the softmax probabilities, i.e. \(P(y = i) = \pi_i\)!
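A quick Monte Carlo sanity check of this claim (evidence, not a proof). The logits `x` and sample count `n` are arbitrary choices for illustration; with \(n = 200{,}000\) the standard error of each empirical frequency is at most about \(0.0011\), so the two printed rows should agree to roughly two or three decimals:

```python
import numpy as np

rng = np.random.default_rng(42)
x = np.array([1.0, 2.0, 0.5, -1.0])  # arbitrary unnormalized log-probabilities (illustrative)
n = 200_000                          # number of Gumbel-max draws

# Target distribution: pi_i = exp(x_i) / sum_j exp(x_j).
pi = np.exp(x - x.max())
pi /= pi.sum()

# Gumbel-max draws: y = argmax_i (x_i + z_i), z_i ~ Gumbel(0, 1) i.i.d.
y = np.argmax(x + rng.gumbel(size=(n, x.size)), axis=1)

# Empirical frequency of each category among the n draws.
freq = np.bincount(y, minlength=x.size) / n

print("softmax pi :", np.round(pi, 3))
print("gumbel-max :", np.round(freq, 3))
```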

Proof here
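A compact version of the standard argument, as a sketch. Write \(g_j = x_j + z_j\); each \(g_j\) is Gumbel with location \(x_j\), so it has CDF \(F_j(t) = \exp(-e^{-(t - x_j)})\) and density \(f_j(t) = e^{-(t - x_j)} F_j(t)\). The \(g_j\) are independent and continuous, so ties have probability zero, and \(y = i\) exactly when \(g_i = t\) for some \(t\) and every other \(g_j < t\):

```latex
\begin{aligned}
P(y = i) &= \int_{-\infty}^{\infty} f_i(t) \prod_{j \ne i} F_j(t)\, dt
          = \int_{-\infty}^{\infty} e^{x_i - t} \prod_{j=1}^{k} F_j(t)\, dt \\
         &= \int_{-\infty}^{\infty} e^{x_i} e^{-t}
            \exp\!\Big(-e^{-t} \sum_{j=1}^{k} e^{x_j}\Big)\, dt
            \qquad (u = e^{-t},\; du = -e^{-t}\, dt) \\
         &= e^{x_i} \int_{0}^{\infty} \exp\!\Big(-u \sum_{j=1}^{k} e^{x_j}\Big)\, du
          = \frac{e^{x_i}}{\sum_{j=1}^{k} e^{x_j}} = \pi_i .
\end{aligned}
```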

4. Source

Created: 2024-07-15 Mon 01:27