Gumbel softmax
Say that we have a neural network that defines a categorical distribution \(p(x)\), and somewhere in computing the loss for that network we want to draw a sample from \(p(x)\). Now we run into a problem: we can't backpropagate through this stochastic choice.
To overcome this, we will use two tricks: the Gumbel-max trick (together with reparameterization) and a softmax relaxation of the \(\arg\max\).
Our setup is this:
- Our network outputs log probabilities \(\{x_1,...,x_k\}\), i.e. \(x_i = \log \pi_i\)
- We would like to sample from \(\{1,...,k\}\) according to the probabilities \(\{\pi_1,...,\pi_k\}\)
We use the Gumbel-max trick and take: \[z = \arg\max_{i} \, [x_1,...,x_k] \oplus [g_1,...,g_k]\] where the \(g_i\)'s are drawn i.i.d. from the Gumbel(0, 1) distribution and \(\oplus\) is element-wise addition.
In the above, we are also making use of the reparameterization trick: all of the randomness lives in the Gumbel noise \(g_i\), which does not depend on the network's parameters, so the gradient can pass through to the \(x_i\)'s.
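Putting this together, here is a minimal NumPy sketch of the Gumbel-max sampler (the function names, the seeded `rng`, and the example `log_probs` are illustrative, not from any particular library). The Gumbel(0, 1) noise comes from inverse-CDF sampling: if \(u \sim \mathrm{Uniform}(0, 1)\), then \(-\log(-\log u) \sim \mathrm{Gumbel}(0, 1)\).

```python
import numpy as np

def sample_gumbel(shape, rng):
    # Inverse-CDF sampling: if u ~ Uniform(0, 1), then -log(-log(u)) ~ Gumbel(0, 1)
    u = rng.uniform(size=shape)
    return -np.log(-np.log(u))

def gumbel_max_sample(log_probs, rng):
    # Gumbel-max trick: the argmax of (log probability + Gumbel noise)
    # is an exact sample from the categorical distribution
    return np.argmax(log_probs + sample_gumbel(log_probs.shape, rng))

rng = np.random.default_rng(0)
log_probs = np.log(np.array([0.2, 0.5, 0.3]))  # stand-in for the network's outputs
print(gumbel_max_sample(log_probs, rng))
```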
But there is still one thing wrong: we can't differentiate through the \(\arg\max\), so instead we replace it with a softmax!
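Concretely, the relaxed sample is \[y_i = \frac{\exp((x_i + g_i)/\tau)}{\sum_{j=1}^{k} \exp((x_j + g_j)/\tau)}\] where \(\tau > 0\) is a temperature: as \(\tau \to 0\) the \(y\) vectors approach one-hot \(\arg\max\) samples, and as \(\tau\) grows they become increasingly uniform. A minimal NumPy sketch of this relaxation (again with illustrative names, not from any particular library):

```python
import numpy as np

def gumbel_softmax_sample(log_probs, tau, rng):
    # Gumbel(0, 1) noise, as in the Gumbel-max sketch above
    u = rng.uniform(size=log_probs.shape)
    g = -np.log(-np.log(u))
    # Softmax with temperature tau in place of the hard argmax:
    # low tau -> nearly one-hot, high tau -> nearly uniform
    logits = (log_probs + g) / tau
    logits = logits - logits.max()  # subtract the max for numerical stability
    y = np.exp(logits)
    return y / y.sum()

rng = np.random.default_rng(0)
log_probs = np.log(np.array([0.2, 0.5, 0.3]))  # stand-in for the network's outputs
print(gumbel_softmax_sample(log_probs, tau=0.5, rng=rng))
```

In practice, \(\tau\) is often treated as a hyperparameter or annealed during training, trading off bias (large \(\tau\)) against gradient variance (small \(\tau\)).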