
Gumbel softmax

Say we have a neural network that defines a categorical distribution \(p(x)\), and somewhere in computing the loss for that network we want to draw a sample from \(p(x)\). Now we run into a problem: backprop can't flow through this stochastic choice.

To overcome this, we will use two tricks:

1. the Gumbel-max trick, combined with the reparameterization trick, to rewrite the sample as a deterministic function of the network's outputs plus independent noise, and
2. a softmax relaxation of the resulting \(\arg\max\), so the whole computation is differentiable.

Our setup is this: the network outputs logits \(x_1, \dots, x_k\), which define the categorical distribution via \(p(x = i) \propto \exp(x_i)\).

We use the Gumbel-max trick and take: \[z = \arg\max_{i} \, (x_i + g_i)\] where the \(g_i\)'s are drawn i.i.d. from the Gumbel(0, 1) distribution and added element-wise to the logits. A standard fact is that \(z\) defined this way is an exact sample from \(p(x)\).
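
A minimal sketch of this in PyTorch (the helper name `sample_gumbel_max` is mine, not from any library): Gumbel(0, 1) noise is generated by inverse-CDF sampling, \(g = -\log(-\log u)\) with \(u \sim \mathrm{Uniform}(0, 1)\), and the empirical frequencies of the resulting \(\arg\max\) should match \(\mathrm{softmax}(x)\).

```python
import torch

def sample_gumbel_max(logits: torch.Tensor) -> torch.Tensor:
    """Draw one categorical sample per row via the Gumbel-max trick."""
    # Gumbel(0, 1) noise via inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1)
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u))
    # argmax over the perturbed logits is an exact sample from softmax(logits)
    return torch.argmax(logits + g, dim=-1)

# sanity check: empirical frequencies should match softmax(logits)
logits = torch.tensor([1.0, 0.5, -1.0])
samples = sample_gumbel_max(logits.expand(100_000, -1))
print(torch.bincount(samples, minlength=3) / samples.numel())
print(torch.softmax(logits, dim=-1))
```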

In the above, we are also making use of the reparameterization trick: the randomness now lives entirely in the Gumbel noise \(g_i\), which doesn't depend on the network's parameters, so the \(x_i\)'s enter the computation deterministically and the gradient can pass through to them.

But there is still one thing wrong: we can't differentiate through the \(\arg\max\), so instead we use softmax! With a temperature \(\tau > 0\), the relaxed sample is \[y_i = \frac{\exp((x_i + g_i)/\tau)}{\sum_{j=1}^{k} \exp((x_j + g_j)/\tau)}\] As \(\tau \to 0\), \(y\) approaches the one-hot \(\arg\max\) vector; larger \(\tau\) gives smoother samples and lower-variance (but more biased) gradients.
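
A minimal differentiable sketch under the same assumptions (`gumbel_softmax_sample` is my own name; PyTorch ships an equivalent as `torch.nn.functional.gumbel_softmax`):

```python
import torch

def gumbel_softmax_sample(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor:
    """Differentiable relaxed sample: softmax of Gumbel-perturbed logits."""
    u = torch.rand_like(logits)
    g = -torch.log(-torch.log(u))  # Gumbel(0, 1) noise
    return torch.softmax((logits + g) / tau, dim=-1)

logits = torch.tensor([1.0, 0.5, -1.0], requires_grad=True)
y = gumbel_softmax_sample(logits, tau=0.5)  # soft, nearly one-hot vector
y[0].backward()                             # gradients flow back to the logits
print(y, logits.grad)
```

Passing `hard=True` to `torch.nn.functional.gumbel_softmax` gives the straight-through variant: the forward pass returns a discrete one-hot sample, while the backward pass uses the soft gradients above.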


Created: 2024-07-15 Mon 01:28