Dai et al., 2019 – Transformer-XL: Attentive language models beyond a fixed-length context
1. Problem
The vanilla Transformer can only handle segments of a pre-defined, fixed length, so no information flows across segment boundaries.
2. Two key ideas
2.1. Segment-level recurrence
Basically, break the input into fixed-length segments and feed in one segment at a time. When processing the \(\tau + 1\)-th segment, also attend to the cached hidden states of the \(\tau\)-th segment, but don't track gradients through those cached states (otherwise the backward pass would have to keep activations for the entire history, not just one segment). This way, information about earlier input can propagate arbitrarily far forward through the recurrence, even though gradients stay within a single segment.
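A minimal PyTorch-style sketch of the outer training loop, under my own assumptions about the interface (`model(segment, mems)` returning per-layer hidden states and a language-modeling loss is hypothetical, not the paper's code):

```python
import torch

def train_on_stream(model, optimizer, token_stream, seg_len=128):
    """Sketch of segment-level recurrence: process one fixed-length
    segment at a time, caching (detached) hidden states as memory."""
    mems = None  # hidden states of the previous segment, one tensor per layer
    for start in range(0, token_stream.size(0) - seg_len, seg_len):
        segment = token_stream[start:start + seg_len]
        # hypothetical interface: returns per-layer hidden states and the LM loss
        hiddens, loss = model(segment, mems)
        optimizer.zero_grad()
        loss.backward()            # gradients flow only within this segment
        optimizer.step()
        # cache this segment's hidden states for the next segment,
        # with the gradient stopped -- the SG(.) in the equations below
        mems = [h.detach() for h in hiddens]
```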
Formally, let \(\mathbf{s}_{\tau} = [x_{\tau,1},...,x_{\tau,T}]\) and \(\mathbf{s}_{\tau+1} = [x_{\tau+1, 1},..., x_{\tau+1, T}]\) be two consecutive segments of the input, where \(T\) is the pre-defined segment length.
Then, the \(n\)-th layer hidden state of the Transformer for segment \(\tau + 1\) is computed as:
\[ \tilde{\mathbf{h}}^{n-1}_{\tau + 1} = \left[SG(\mathbf{h}_{\tau}^{n-1}) \circ \mathbf{h}^{n-1}_{\tau+1} \right] \] where \(SG\) means "stop-gradient", and \(\circ\) is concatenation along the length dimension. This is the previous segment's \((n-1)\)-th layer hidden states (cached, gradient stopped) concatenated with the current segment's \((n-1)\)-th layer hidden states.
\[ \mathbf{q}^{n}_{\tau+1}, \mathbf{k}^{n}_{\tau+1}, \mathbf{v}^n_{\tau+1} = \mathbf{h}^{n-1}_{\tau+1}\mathbf{W}_{q}^{\intercal}, \tilde{\mathbf{h}}^{n-1}_{\tau+1}\mathbf{W}_{k}^{\intercal}, \tilde{\mathbf{h}}^{n-1}_{\tau+1}\mathbf{W}_{v}^{\intercal}, \]
and
\[ \mathbf{h}^{n}_{\tau+1} = \text{Transformer-Layer}(\mathbf{q}^{n}_{\tau+1}, \mathbf{k}^{n}_{\tau+1}, \mathbf{v}^{n}_{\tau+1}) \]
where the \(\mathbf{W}\)'s are learnable weight matrices. Note that the queries \(\mathbf{q}^{n}_{\tau+1}\) are computed from the current segment only, while the keys and values also cover the cached previous segment.
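A single-head PyTorch sketch of these per-layer equations, assuming shapes `[length, d_model]` and omitting multi-head splitting, causal masking, and the relative-position terms of Section 2.2 (all names here are mine, not the paper's):

```python
import torch
import torch.nn.functional as F

def xl_attention(h_cur, h_mem, W_q, W_k, W_v):
    """h_cur: [T, d_model] current-segment hidden states (layer n-1).
    h_mem:  [M, d_model] cached previous-segment hidden states.
    W_q, W_k, W_v: [d_head, d_model] projection matrices."""
    # concatenate the stop-gradient memory with the current segment
    # along the length dimension (the h-tilde above)
    h_tilde = torch.cat([h_mem.detach(), h_cur], dim=0)   # [M + T, d_model]

    q = h_cur   @ W_q.T   # queries: current segment only      [T, d_head]
    k = h_tilde @ W_k.T   # keys:    memory + current segment  [M + T, d_head]
    v = h_tilde @ W_v.T   # values:  memory + current segment  [M + T, d_head]

    scores = (q @ k.T) / k.size(-1) ** 0.5   # [T, M + T]
    attn = F.softmax(scores, dim=-1)
    return attn @ v                          # [T, d_head]
```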
2.2. Relative positional embeddings
In the original Transformer, the position embeddings are absolute. This clashes with segment-level recurrence: the same absolute positions would be reused in the cached segment and in the current one, so the model could not tell them apart. Transformer-XL therefore uses relative positional embeddings: when the attention score is computed for query \(i\) and key \(j\), a relative encoding \(\mathbf{R}_{i-j}\) that depends only on the offset \(i - j\) is used for the key. Note that this injects position information into the attention score computation at every layer, instead of adding it to the input at the bottom layer.
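Concretely, the paper expands the attention score between query \(i\) and key \(j\) into four terms, replacing the two absolute-position query terms with learnable vectors \(\mathbf{u}\) and \(\mathbf{v}\), and using separate key projections \(\mathbf{W}_{k,E}\) (for content) and \(\mathbf{W}_{k,R}\) (for position):
\[ \mathbf{A}^{\text{rel}}_{i,j} = \underbrace{\mathbf{E}_{x_i}^{\intercal}\mathbf{W}_{q}^{\intercal}\mathbf{W}_{k,E}\mathbf{E}_{x_j}}_{\text{content}} + \underbrace{\mathbf{E}_{x_i}^{\intercal}\mathbf{W}_{q}^{\intercal}\mathbf{W}_{k,R}\mathbf{R}_{i-j}}_{\text{content-dependent position}} + \underbrace{\mathbf{u}^{\intercal}\mathbf{W}_{k,E}\mathbf{E}_{x_j}}_{\text{global content bias}} + \underbrace{\mathbf{v}^{\intercal}\mathbf{W}_{k,R}\mathbf{R}_{i-j}}_{\text{global position bias}} \]
where \(\mathbf{E}_{x_i}\) is the embedding of token \(x_i\) and \(\mathbf{R}\) is the (fixed) sinusoidal encoding matrix indexed by the offset.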