Policy Gradient
1. Terminology
- The framework which the agent operates in is a Markov Decision Process, which has an action space \(\mathcal{A}\), a state space \(\mathcal{S}\), a reward function \(R\), and the dynamics \(p\) of the environment
- \(p(s_{t+1} \mid s_{t}, a_{t})\) models the dynamics of the environment: "Given that I am in state \(s_t\) and take action \(a_{t}\) at time \(t\), what is the probability that I end up in state \(s_{t+1}\) at time \(t+1\)?"
- \(\tau\) is a trajectory which consists of states and actions for every time step: \(\{s_1, a_1, s_2, a_2, \ldots\}\)
- \(r(\tau)\) is the reward given to a trajectory
- \(\pi_{\theta}(a \mid s)\) is the policy we are trying to learn. It gives a distribution over actions \(a\), given a state \(s\); \(\theta\) are the parameters of the model (a minimal sketch follows below)
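As a concrete example of the last point, here is a minimal sketch of a discrete softmax policy with linear parameters. The state dimension, number of actions, and the linear form of the scores are assumptions made purely for illustration; they are not part of the framework above.

```python
# A minimal sketch of a discrete softmax policy pi_theta(a | s).
# State dimension, action count, and the linear score function are illustrative.
import numpy as np

def softmax_policy(theta, s):
    """Return a probability distribution over actions for state s.

    theta: (state_dim, num_actions) parameter matrix
    s:     (state_dim,) state vector
    """
    logits = s @ theta                      # one score per action
    logits -= logits.max()                  # numerical stability
    probs = np.exp(logits) / np.exp(logits).sum()
    return probs

theta = np.zeros((4, 2))                    # 4-dim state, 2 actions
s = np.array([0.1, -0.3, 0.5, 1.0])
print(softmax_policy(theta, s))             # uniform for theta = 0
```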
2. Gradient
The objective function is: \(J(\theta) = \mathbb{E}_{\tau \sim \pi_{\theta}}\left[r(\tau)\right]\)
The gradient is given by: \(\nabla_{\theta} J(\theta) = \mathbb{E}_{\tau\sim\pi_{\theta}}\left[ r(\tau) \sum_{t=1}^{|\tau|} \nabla_{\theta} \log \pi_{\theta} (a_{t} \mid s_{t}) \right]\)
2.1. Comments on the gradient
Note that the above gradient does not require us to know anything about the environment dynamics \(p(s_{t+1} \mid s_t, a_t)\) or the probability \(p(s_1)\) that we start a trajectory in any given state. Algorithms that use this result are called model-free.
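To make the shape of this estimator concrete, here is a rough sketch of the single-trajectory quantity \(r(\tau)\sum_t \nabla_{\theta} \log \pi_{\theta}(a_t \mid s_t)\) for the linear softmax policy sketched in the terminology section. The trajectory, reward, and dimensions below are made up; only the structure of the estimator follows the formula above.

```python
# Single-sample term r(tau) * sum_t grad_theta log pi_theta(a_t | s_t)
# for a linear softmax policy; trajectory and reward are made up.
import numpy as np

def grad_log_pi(theta, s, a):
    """grad_theta log pi_theta(a | s) for logits = s @ theta."""
    logits = s @ theta
    logits -= logits.max()                    # numerical stability
    probs = np.exp(logits) / np.exp(logits).sum()
    onehot = np.zeros_like(probs)
    onehot[a] = 1.0
    return np.outer(s, onehot - probs)        # same shape as theta

rng = np.random.default_rng(0)
theta = 0.1 * rng.normal(size=(4, 2))         # 4-dim states, 2 actions
trajectory = [(rng.normal(size=4), rng.integers(2)) for _ in range(5)]
r_tau = 1.0                                   # pretend total reward of the trajectory

grad_estimate = r_tau * sum(grad_log_pi(theta, s, a) for s, a in trajectory)
print(grad_estimate.shape)                    # (4, 2), matches theta
```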
3. Derivation of gradient
\[\begin{align*} \nabla_{\theta} J(\theta) &= \nabla_{\theta} \mathbb{E}_{\tau\sim\pi_{\theta}}\left[ r(\tau) \right]\\ &= \int \nabla_{\theta} \pi_{\theta}(\tau) r(\tau) d\tau\\ &= \int \pi_{\theta}(\tau) \nabla_{\theta} \log \pi_{\theta}(\tau) r(\tau) d\tau\\ &= \mathbb{E}_{\tau \sim \pi_{\theta}} \left[ \nabla_{\theta} \log \pi_{\theta}(\tau) r(\tau)\right] \end{align*}\]
The third line comes from the fact that \(\pi_{\theta}(\tau)\nabla_{\theta} \log \pi_{\theta}(\tau) = \pi_{\theta}(\tau) \frac{1}{\pi_{\theta}(\tau)}\nabla_{\theta} \pi_{\theta}(\tau) = \nabla_{\theta} \pi_{\theta}(\tau)\)
But what is \(\pi_{\theta}(\tau)\)? It is given by: \[\begin{align*} \pi_{\theta}(\tau) &= \pi_{\theta}(s_1,a_1,s_2,a_2,\ldots) \\ &= p(s_1)\prod_{t=1}^{|\tau|} \pi_{\theta}(a_{t} \mid s_{t}) p(s_{t+1} \mid s_t, a_t) \end{align*}\] where \(p(s_1)\) is the probability of starting in state \(s_1\).
We can take the log of the above to get: \[\log \pi_{\theta}(\tau) = \log p(s_1) + \sum_{t=1}^{|\tau|} \log \pi_{\theta}(a_{t} \mid s_{t}) + \sum_{t=1}^{|\tau|} \log p(s_{t+1} \mid s_t, a_t)\]
Let's substitute this into the gradient: \[\begin{align*} \mathbb{E}_{\tau\sim\pi_{\theta}} \left[\nabla_{\theta}\log \pi_{\theta}(\tau) r(\tau) \right] & = \mathbb{E}_{\tau\sim\pi_{\theta}} \left[ r(\tau)\nabla_{\theta}\left(\log p(s_1) + \sum_{t=1}^{|\tau|} \log \pi_{\theta}(a_{t} \mid s_{t}) + \sum_{t=1}^{|\tau|} \log p(s_{t+1} \mid s_t, a_t) \right) \right]\\ & = \mathbb{E}_{\tau\sim\pi_{\theta}} \left[r(\tau)\sum_{t=1}^{|\tau|} \nabla_{\theta}\log \pi_{\theta}(a_{t} \mid s_{t}) \right] \end{align*}\]
And we are done! The last line comes from the fact that any term that doesn't depend on \(\theta\) disappears when we take the gradient \(\nabla_{\theta}\).
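A quick way to sanity-check the identity \(\nabla_{\theta}\mathbb{E}[r(\tau)] = \mathbb{E}[r(\tau)\nabla_{\theta}\log\pi_{\theta}(\tau)]\) is to compare both sides numerically on a toy distribution. The three-"trajectory" categorical policy and rewards below are made up for that purpose only.

```python
# Numerical check of grad E[r] = E[r * grad log pi] on a toy categorical policy
# over three one-step "trajectories"; logits and rewards are made up.
import numpy as np

theta = np.array([0.2, -0.5, 1.0])            # logits for 3 trajectories
r = np.array([1.0, 0.0, 2.0])                 # reward of each trajectory

def probs(theta):
    e = np.exp(theta - theta.max())
    return e / e.sum()

def J(theta):
    # J(theta) = sum_tau pi_theta(tau) r(tau)
    return probs(theta) @ r

# Left-hand side: grad_theta J(theta) by central finite differences.
eps = 1e-6
fd_grad = np.array([
    (J(theta + eps * np.eye(3)[i]) - J(theta - eps * np.eye(3)[i])) / (2 * eps)
    for i in range(3)
])

# Right-hand side: E_{tau ~ pi}[ r(tau) * grad_theta log pi_theta(tau) ].
p = probs(theta)
grad_log_pi = np.eye(3) - p                   # row i: grad_theta log pi(tau_i)
score_grad = p @ (r[:, None] * grad_log_pi)

print(np.allclose(fd_grad, score_grad, atol=1e-5))   # True
```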
4. Algorithms
The expectation for the gradient can be estimated using Monte-Carlo methods. This is the approach taken by the REINFORCE algorithm.
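Here is a minimal REINFORCE-style sketch: sample a batch of trajectories with the current policy, average \(r(\tau)\sum_t \nabla_{\theta}\log\pi_{\theta}(a_t \mid s_t)\) over them, and take a gradient-ascent step. The one-step toy environment (reward 1 when the action matches the sign of the first state component) and all hyperparameters are invented for illustration; variance-reduction tricks such as baselines are omitted.

```python
# Minimal REINFORCE sketch: Monte-Carlo gradient estimate + gradient ascent.
# Toy one-step environment and hyperparameters are made up for illustration.
import numpy as np

rng = np.random.default_rng(0)
theta = np.zeros((4, 2))                     # linear softmax policy parameters

def policy_probs(theta, s):
    logits = s @ theta
    logits -= logits.max()
    e = np.exp(logits)
    return e / e.sum()

def grad_log_pi(theta, s, a, probs):
    onehot = np.zeros_like(probs)
    onehot[a] = 1.0
    return np.outer(s, onehot - probs)

lr, num_iters, batch = 0.5, 200, 32
for it in range(num_iters):
    grad = np.zeros_like(theta)
    total_reward = 0.0
    for _ in range(batch):                   # sample N trajectories
        s = rng.normal(size=4)
        probs = policy_probs(theta, s)
        a = rng.choice(2, p=probs)
        r = 1.0 if (a == 1) == (s[0] > 0) else 0.0   # r(tau)
        grad += r * grad_log_pi(theta, s, a, probs)
        total_reward += r
    theta += lr * grad / batch               # gradient ascent on J(theta)
    if it % 50 == 0:
        print(f"iter {it:3d}  avg reward {total_reward / batch:.2f}")
```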
5. Actor-Critic
- See Actor Critic
6. Off-policy
Instead of sampling trajectories according to the policy being learned, you can sample according to a separate exploration policy. See off-policy policy gradient.
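As a rough sketch of the reweighting idea behind this, importance sampling lets us estimate an expectation under \(\pi_{\theta}\) from samples drawn under a behaviour policy \(\beta\), by weighting each sample by \(\pi_{\theta}/\beta\). The one-step example below is made up; the full off-policy policy gradient applies the same kind of weight per trajectory.

```python
# Importance-sampling sketch: estimate an expectation under pi_theta using
# samples from a behaviour policy beta. One-step case, made-up numbers.
import numpy as np

rng = np.random.default_rng(1)

pi_probs = np.array([0.7, 0.3])       # target policy pi_theta(a), held fixed here
beta_probs = np.array([0.5, 0.5])     # exploration/behaviour policy beta(a)
rewards = np.array([1.0, 0.0])        # reward of each action

samples = rng.choice(2, size=10_000, p=beta_probs)   # drawn from beta, not pi
weights = pi_probs[samples] / beta_probs[samples]    # importance weights

# Importance-weighted estimate of E_{a ~ pi_theta}[r(a)] from beta's samples.
print(np.mean(weights * rewards[samples]))           # ~= 0.7 * 1.0 = 0.7
```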
7. Comparison to maximum likelihood for sequence learning
For sequences:
- All of the input is consumed by an encoder, before the decoder begins producing output
- Let \(x\) be the source sequence, and \(y\) be the target sequence
For sequences, the goal of Maximum Likelihood Estimation is to find parameters \(\theta^*\): \[\theta^* = \arg\max_{\theta} p_{\theta}(y \mid x)\] where \[p_{\theta}(y\mid x) = \prod_t p_{\theta}(y_t \mid x)\] When maximizing the log likelihood of a dataset \(\mathcal{D}\), the gradient of our objective is then: \[\frac{1}{|\mathcal{D}|}\sum_{(x,y)\in\mathcal{D}} \nabla_{\theta} \log p_{\theta}(y\mid x)\]
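As a sketch of what this gradient looks like in code, the snippet below accumulates \(\nabla_{\theta}\log p_{\theta}(y \mid x)\) over a toy dataset for the simplified per-token factorization written above, with each token scored by a softmax over a small vocabulary from a fixed feature of \(x\). The vocabulary size, the feature map, and the data are invented for illustration.

```python
# MLE gradient sketch for a toy factorized model p_theta(y | x) = prod_t p_theta(y_t | x).
# Vocabulary, "encoder" feature, and data are made up for illustration.
import numpy as np

rng = np.random.default_rng(2)
vocab, feat_dim = 5, 4
theta = np.zeros((feat_dim, vocab))

def log_probs(theta, x_feat):
    logits = x_feat @ theta
    logits -= logits.max()
    return logits - np.log(np.exp(logits).sum())

# Toy dataset of (encoder feature of x, target token sequence y) pairs.
dataset = [(rng.normal(size=feat_dim), rng.integers(vocab, size=3))
           for _ in range(8)]

grad = np.zeros_like(theta)
for x_feat, y in dataset:
    p = np.exp(log_probs(theta, x_feat))       # p_theta(. | x)
    for y_t in y:                              # sum over target tokens
        onehot = np.zeros(vocab)
        onehot[y_t] = 1.0
        grad += np.outer(x_feat, onehot - p)   # grad_theta log p_theta(y_t | x)
grad /= len(dataset)                           # (1/|D|) sum over the dataset
print(grad.shape)                              # (4, 5), matches theta
```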
Compare this to a Monte-Carlo estimator of the RL gradient, which samples \(N\) trajectories: \[\frac{1}{N}\sum_{i=1}^{N} r(\tau_i) \nabla_{\theta} \log p_{\theta}(a_{\tau_i} \mid x)\] Note that in the above, the actions \(a_{\tau}\) are the output, which lies in the target domain. In contrast to the RL framework discussed above, the state \(s_{t}\) is not conditioned on at every time-step. Instead, the input sequence \(x\) is fed to an encoder, which passes an embedding of the input to the decoder. The decoder produces the actions/output tokens.
The two objectives are identical in the case where \(r(\cdot)\) is 1 for the target output and 0 otherwise. When this is not the case, I think of RL as a fuzzy, trial-and-error approximation of MLE.
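To see the first claim concretely: with \(r(\tau) = \mathbb{1}[\tau = y]\), the RL objective becomes \[J(\theta) = \mathbb{E}_{\tau\sim\pi_{\theta}}\left[\mathbb{1}[\tau = y]\right] = \sum_{\tau} \pi_{\theta}(\tau \mid x)\,\mathbb{1}[\tau = y] = p_{\theta}(y \mid x),\] which is exactly the likelihood that MLE maximizes; sampling trajectories and scoring them with \(r\) is then a noisy way of climbing the same objective.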