reparameterization trick

1. VAE example

Let's say that our loss involves a stochastic term: computing it requires sampling from a Gaussian distribution whose parameters are produced by our network. We can't compute a gradient through that sampling step.

In this situation, the computation graph might look like:

    f
    ^
    |
    |
    z <--- stochastic node
    ^
   / \
  /   \
sig    mu

Then, the computation of \(f\) proceeds as:

  1. compute \(\mu\) according to the model parameters
  2. compute \(\sigma\) according to the model parameters
  3. sample \(z\) from \(\mathcal{N}(\mu, \sigma^2)\) (see the sketch after this list)
  4. compute \(f(z)\)
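
To make this concrete, here is a minimal PyTorch sketch of the procedure above (the scalar theta is an illustrative stand-in for real model parameters, not anything from the original setup). It exposes the problem discussed next:

    import torch

    # illustrative stand-ins for the model parameters and their outputs
    theta = torch.tensor([0.0, 0.0], requires_grad=True)
    mu = theta[0]
    sigma = theta[1].exp()  # exp() keeps sigma positive

    # step 3: sample z directly from N(mu, sigma)
    z = torch.distributions.Normal(mu, sigma).sample()

    print(z.requires_grad)  # False: sample() cut the graph, so no
                            # gradient from f(z) can reach mu or sigma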

Now we are in trouble, because we can't backpropagate through the sampling step: no gradient reaches \(\mu\), \(\sigma\), or the model parameters \(\phi\) behind them. With the reparameterization trick, we move the sampling step off the critical path of backpropagation:

       f
       ^
       |
       |
       z
    __ ^ __
   /   |   \
  /    |    \
sig   mu   epsilon <--- stochastic node

Here, \(\epsilon \sim \mathcal{N}(0,1)\).

Then, the procedure for computing the output at \(f\) is:

  1. compute \(\sigma\) according to the model parameters
  2. compute \(\mu\) according to the model parameters
  3. sample \(\epsilon \sim \mathcal{N}(0,1)\)
  4. compute \(z = \mu + \sigma \odot \epsilon\). Note that this is equivalent to sampling \(z\) from \(\mathcal{N}(\mu, \sigma^2)\) (see the check after this list)
  5. compute \(f(z)\)
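
To see the equivalence claimed in step 4: an affine function of a Gaussian variable is again Gaussian, with

\[
\mathbb{E}[z] = \mu + \sigma\,\mathbb{E}[\epsilon] = \mu,
\qquad
\mathrm{Var}[z] = \sigma^2\,\mathrm{Var}[\epsilon] = \sigma^2,
\]

so \(z \sim \mathcal{N}(\mu, \sigma^2)\), and the randomness now enters only through \(\epsilon\), which depends on no model parameters.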

Now the gradient can flow from \(f\) through \(z\) to \(\sigma\) and \(\mu\); the stochasticity is absorbed by \(\epsilon\), which needs no gradient.
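
The same sketch as before, with the sampling step reparameterized (again, the scalar theta is just an illustrative stand-in):

    import torch

    # illustrative stands-in for the model parameters and their outputs
    theta = torch.tensor([0.0, 0.0], requires_grad=True)
    mu = theta[0]
    sigma = theta[1].exp()  # exp() keeps sigma positive

    # steps 3-4: the randomness lives in epsilon, off the gradient path
    epsilon = torch.randn(())  # epsilon ~ N(0, 1), needs no gradient
    z = mu + sigma * epsilon   # z ~ N(mu, sigma^2), differentiable in mu and sigma

    f = z ** 2  # stand-in for the downstream loss f(z)
    f.backward()
    print(theta.grad)  # nonzero (almost surely): gradients reach mu and sigma

PyTorch packages exactly this construction as torch.distributions.Normal(mu, sigma).rsample().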

Created: 2024-07-15 Mon 01:26