Backpropagation Through Time
If you completed the exercises in :numref:sec_rnn-scratch, you would have seen that gradient clipping is vital for preventing the occasional massive gradients from destabilizing training. We hinted that the exploding gradients stem from backpropagating across long sequences. Before introducing a slew of modern RNN architectures, let's take a closer look at how backpropagation works in sequence models in mathematical detail. Hopefully, this discussion will bring some precision to the notion of vanishing and exploding gradients.
If you recall our discussion of forward and backward propagation through computational graphs when we introduced MLPs in :numref:sec_backprop, then forward propagation in RNNs should be relatively straightforward. Applying backpropagation in RNNs is called backpropagation through time [160]. This procedure requires us to expand (or unroll) the computational graph of an RNN one time step at a time. The unrolled RNN is essentially a feedforward neural network with the special property that the same parameters are repeated throughout the unrolled network, appearing at each time step. Then, just as in any feedforward neural network, we can apply the chain rule, backpropagating gradients through the unrolled net. The gradient with respect to each parameter must be summed across all places that the parameter occurs in the unrolled net. Handling such weight tying should be familiar from our chapters on convolutional neural networks.
Complications arise because sequences can be rather long. It is not unusual to work with text sequences consisting of over a thousand tokens. Note that this poses problems both from a computational (too much memory) and optimization (numerical instability) standpoint. Input from the first step passes through over 1000 matrix products before arriving at the output, and another 1000 matrix products are required to compute the gradient. We now analyze what can go wrong and how to address it in practice.
Analysis of Gradients in RNNs
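We start with a simplified model of how an RNN works. This model ignores details about the specifics of the hidden state and how it is updated. The mathematical notation here does not explicitly distinguish scalars, vectors, and matrices. We are just trying to develop some intuition. In this simplified model, we denote $h_t$ as the hidden state, $x_t$ as the input, and $o_t$ as the output at time step $t$. Recall our discussion in :numref:subsec_rnn_w_hidden_states that the input and the hidden state can be concatenated before being multiplied by one weight variable in the hidden layer. Thus, we use $w_\textrm{h}$ and $w_\textrm{o}$ to indicate the weights of the hidden layer and the output layer, respectively. As a result, the hidden states and outputs at each time step are
$$\begin{aligned}h_t &= f(x_t, h_{t-1}, w_\textrm{h}),\\o_t &= g(h_t, w_\textrm{o}),\end{aligned}$$
:eqlabel:eq_bptt_ht_ot
where $f$ and $g$ are transformations of the hidden layer and the output layer, respectively. Hence, we have a chain of values $\{\ldots, (x_{t-1}, h_{t-1}, o_{t-1}), (x_{t}, h_{t}, o_t), \ldots\}$ that depend on each other via recurrent computation. The forward propagation is fairly straightforward: all we need is to loop through the $(x_t, h_t, o_t)$ triples one time step at a time. The discrepancy between the output $o_t$ and the desired target $y_t$ is then evaluated by an objective function across all the $T$ time steps as
$$L(x_1, \ldots, x_T, y_1, \ldots, y_T, w_\textrm{h}, w_\textrm{o}) = \frac{1}{T}\sum_{t=1}^T l(y_t, o_t).$$
For backpropagation, matters are a bit trickier, especially when we compute the gradients with regard to the parameters $w_\textrm{h}$ of the objective function $L$. To be specific, by the chain rule,
$$\begin{aligned}\frac{\partial L}{\partial w_\textrm{h}} & = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial w_\textrm{h}} \\& = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_\textrm{o})}{\partial h_t} \frac{\partial h_t}{\partial w_\textrm{h}}.\end{aligned}$$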
:eqlabel:eq_bptt_partial_L_wh
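The first and the second factors of the product in :eqref:eq_bptt_partial_L_wh are easy to compute. The third factor $\partial h_t/\partial w_\textrm{h}$ is where things get tricky, since we need to recurrently compute the effect of the parameter $w_\textrm{h}$ on $h_t$. According to the recurrent computation in :eqref:eq_bptt_ht_ot, $h_t$ depends on both $h_{t-1}$ and $w_\textrm{h}$, where the computation of $h_{t-1}$ also depends on $w_\textrm{h}$. Thus, evaluating the total derivative of $h_t$ with respect to $w_\textrm{h}$ using the chain rule yields
$$\frac{\partial h_t}{\partial w_\textrm{h}}= \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial w_\textrm{h}} +\frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_\textrm{h}}.$$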
:eqlabel:eq_bptt_partial_ht_wh_recur
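To derive the above gradient, assume that we have three sequences $\{a_{t}\},\{b_{t}\},\{c_{t}\}$ satisfying $a_{0}=0$ and $a_{t}=b_{t}+c_{t}a_{t-1}$ for $t=1, 2,\ldots$. Then for $t\geq 1$, it is easy to show
$$a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.$$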
:eqlabel:eq_bptt_at
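The closed form in :eqref:eq_bptt_at can be checked numerically. The following is a minimal sketch in plain Python: the function names and the sample values of $b_t$ and $c_t$ are made up for illustration, and the code simply compares the recurrence $a_t = b_t + c_t a_{t-1}$ (with $a_0 = 0$) against the sum-of-products expression.

```python
import math

def a_recursive(b, c):
    """Return [a_1, ..., a_T] from a_0 = 0 and a_t = b_t + c_t * a_{t-1}."""
    a_prev, out = 0.0, []
    for b_t, c_t in zip(b, c):
        a_prev = b_t + c_t * a_prev
        out.append(a_prev)
    return out

def a_closed_form(b, c, t):
    """Evaluate a_t = b_t + sum_{i=1}^{t-1} (prod_{j=i+1}^{t} c_j) * b_i, with 1-based t."""
    total = b[t - 1]
    for i in range(1, t):
        # c[i:t] holds c_{i+1}, ..., c_t because list index j-1 stores c_j
        total += math.prod(c[i:t]) * b[i - 1]
    return total

# Arbitrary illustrative values (not taken from any model)
b = [0.5, -1.0, 2.0, 0.25, 1.5]
c = [0.9, 1.1, -0.3, 0.7, 1.2]

recursive = a_recursive(b, c)
closed = [a_closed_form(b, c, t) for t in range(1, len(b) + 1)]
print(recursive)
print(closed)
```

Both lists should agree up to floating-point rounding, which is what the induction argument behind :eqref:eq_bptt_at predicts.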
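By substituting $a_t$, $b_t$, and $c_t$ according to
$$\begin{aligned}a_t &= \frac{\partial h_t}{\partial w_\textrm{h}},\\b_t &= \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial w_\textrm{h}}, \\c_t &= \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial h_{t-1}},\end{aligned}$$
the gradient computation in :eqref:eq_bptt_partial_ht_wh_recur satisfies $a_{t}=b_{t}+c_{t}a_{t-1}$. Thus, per :eqref:eq_bptt_at, we can remove the recurrent computation in :eqref:eq_bptt_partial_ht_wh_recur with
$$\frac{\partial h_t}{\partial w_\textrm{h}}=\frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial w_\textrm{h}}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f(x_{j},h_{j-1},w_\textrm{h})}{\partial h_{j-1}} \right) \frac{\partial f(x_{i},h_{i-1},w_\textrm{h})}{\partial w_\textrm{h}}.$$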
:eqlabel:eq_bptt_partial_ht_wh_gen
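While we can use the chain rule to compute $\partial h_t/\partial w_\textrm{h}$ recursively, this chain can get very long whenever $t$ is large. Let's discuss a number of strategies for dealing with this problem.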
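To make the recursion in :eqref:eq_bptt_partial_ht_wh_recur concrete, here is a small sketch on a toy scalar model. The choice $f(x, h, w) = \tanh(w (x + h))$ is an assumption made purely for illustration (the text leaves $f$ unspecified), as are the function names and input values. The code accumulates the derivative of the hidden state with respect to the weight step by step and compares the result with a central finite difference.

```python
import math

def f(x, h_prev, w):
    # Toy hidden-layer transformation (an illustrative assumption)
    return math.tanh(w * (x + h_prev))

def hidden_states(xs, w, h0=0.0):
    """Forward pass: return [h_1, ..., h_T]."""
    hs, h = [], h0
    for x in xs:
        h = f(x, h, w)
        hs.append(h)
    return hs

def dh_dw(xs, w, h0=0.0):
    """Return [dh_1/dw, ..., dh_T/dw] via dh_t/dw = df/dw + df/dh_{t-1} * dh_{t-1}/dw."""
    grads, h_prev, g_prev = [], h0, 0.0   # h_0 is a constant, so dh_0/dw = 0
    for x in xs:
        h = f(x, h_prev, w)
        local = 1.0 - h * h               # derivative of tanh at its pre-activation
        df_dw = local * (x + h_prev)      # direct dependence on w
        df_dh = local * w                 # dependence through h_{t-1}
        g_prev = df_dw + df_dh * g_prev
        grads.append(g_prev)
        h_prev = h
    return grads

xs = [0.5, -0.2, 0.1, 0.7, -0.4]
w, eps = 0.8, 1e-6
analytic = dh_dw(xs, w)[-1]
numeric = (hidden_states(xs, w + eps)[-1] - hidden_states(xs, w - eps)[-1]) / (2 * eps)
print(analytic, numeric)
```

The two numbers should match to several decimal places. Note that the loop needs one multiplication by the step-to-step derivative per time step, which is exactly the long chain that becomes problematic for large $t$.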
Full Computation
One idea might be to compute the full sum in :eqref:eq_bptt_partial_ht_wh_gen. However, this is very slow and gradients can blow up, since subtle changes in the initial conditions can potentially affect the outcome a lot. That is, we could see things similar to the butterfly effect, where minimal changes in the initial conditions lead to disproportionate changes in the outcome. This is generally undesirable. After all, we are looking for robust estimators that generalize well. Hence this strategy is almost never used in practice.
Truncating Time Steps
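Alternatively, we can truncate the sum in :eqref:eq_bptt_partial_ht_wh_gen after $\tau$ steps. This leads to an approximation of the true gradient, obtained simply by terminating the sum at $\partial h_{t-\tau}/\partial w_\textrm{h}$. In practice this works quite well. It is what is commonly referred to as truncated backpropagation through time. One consequence is that the model focuses primarily on short-term influence rather than long-term consequences, which biases the estimate towards simpler and more stable models.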
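The effect of truncation can be sketched with the abstract recurrence $a_t = b_t + c_t a_{t-1}$, where $b_t$ plays the role of the direct derivative with respect to the weight and $c_t$ the derivative with respect to the previous hidden state. The constants below are illustrative assumptions rather than values from a trained model, and `truncated_grad` is a hypothetical helper name.

```python
def truncated_grad(b, c, tau):
    """Approximate a_T from a_t = b_t + c_t * a_{t-1}, treating a_{T-tau} as 0."""
    T = len(b)
    a = 0.0
    for t in range(max(T - tau, 0), T):   # only the last tau steps contribute
        a = b[t] + c[t] * a
    return a

T = 50
b = [1.0] * T    # stand-in for the direct derivative at each step
c = [0.8] * T    # stand-in for the step-to-step derivative; |c| < 1 so old terms decay

full = truncated_grad(b, c, T)            # tau = T recovers the full sum
errors = {tau: abs(full - truncated_grad(b, c, tau)) for tau in (1, 5, 10, 20, 50)}
for tau, err in errors.items():
    print(tau, err)
```

With $|c_t| < 1$ the error shrinks geometrically as $\tau$ grows, so a modest window already captures most of the gradient. If $|c_t| > 1$, the discarded terms would be the largest ones, and the truncated value would be a poor approximation of the (exploding) full sum.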
Randomized Truncation
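Last, we can replace $\partial h_t/\partial w_\textrm{h}$ by a random variable which is correct in expectation but truncates the sequence. This is achieved by using a sequence of $\xi_t$ with predefined $0 < \pi_t \leq 1$, where $P(\xi_t = 0) = 1-\pi_t$ and $P(\xi_t = \pi_t^{-1}) = \pi_t$, thus $E[\xi_t] = 1$. We use this to replace the gradient $\partial h_t/\partial w_\textrm{h}$ in :eqref:eq_bptt_partial_ht_wh_recur with
$$z_t= \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial w_\textrm{h}} +\xi_t \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_\textrm{h}}.$$
It follows from the definition of $\xi_t$ that $E[z_t] = \partial h_t/\partial w_\textrm{h}$. Whenever $\xi_t = 0$ the recurrent computation terminates at that time step $t$. This leads to a weighted sum of sequences of varying lengths, where long sequences are rare but appropriately overweighted.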
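The unbiasedness claim can be illustrated with a small Monte Carlo sketch. It again uses the abstract recurrence $a_t = b_t + c_t a_{t-1}$ as a stand-in for the gradient recursion, and, as an assumption of this sketch, applies an independent random factor at every step of the recursion rather than only at the last one. All constants and function names are illustrative.

```python
import random

def exact_grad(b, c):
    """a_T from a_0 = 0 and a_t = b_t + c_t * a_{t-1}."""
    a = 0.0
    for b_t, c_t in zip(b, c):
        a = b_t + c_t * a
    return a

def randomized_grad(b, c, pi, rng):
    """One sample of z_T, with z_t = b_t + xi_t * c_t * z_{t-1} applied at every step."""
    z = 0.0
    for b_t, c_t in zip(b, c):
        # xi_t equals 1/pi with probability pi and 0 otherwise, so E[xi_t] = 1
        xi = 1.0 / pi if rng.random() < pi else 0.0
        z = b_t + xi * c_t * z
    return z

rng = random.Random(0)          # fixed seed so the run is repeatable
T, pi = 10, 0.9
b, c = [1.0] * T, [0.5] * T     # illustrative constants
exact = exact_grad(b, c)
n = 20000
mc_mean = sum(randomized_grad(b, c, pi, rng) for _ in range(n)) / n
print(exact, mc_mean)
```

The sample mean should land close to the exact value, while individual samples vary because each draw of zero cuts the history short and each surviving term is scaled up by the inverse probability. That extra variance is the price paid for an unbiased estimate.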
Comparing Strategies
Comparing strategies for computing gradients in RNNs. From top to bottom: randomized truncation, regular truncation, and full computation.
The figure above illustrates the three strategies when analyzing the first few characters of The Time Machine using backpropagation through time for RNNs:
The first row is the randomized truncation that partitions the text into segments of varying lengths.
The second row is the regular truncation that breaks the text into subsequences of the same length. This is what we have been doing in RNN experiments.
The third row is the full backpropagation through time that leads to a computationally infeasible expression.
Unfortunately, while appealing in theory, randomized truncation does not work much better than regular truncation, most likely due to a number of factors. First, the effect of an observation after a number of backpropagation steps into the past is quite sufficient to capture dependencies in practice. Second, the increased variance counteracts the fact that the gradient is more accurate with more steps. Third, we actually want models that have only a short range of interactions. Hence, regularly truncated backpropagation through time has a slight regularizing effect that can be desirable.
Backpropagation Through Time in Detail
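After discussing the general principle, let's discuss backpropagation through time in detail. In contrast to the analysis in :numref:subsec_bptt_analysis, in the following we will show how to compute the gradients of the objective function with respect to all the decomposed model parameters. To keep things simple, we consider an RNN without bias parameters, whose activation function in the hidden layer uses the identity mapping ($\phi(x)=x$). For time step $t$, let the single example input and the target be $\mathbf{x}_t \in \mathbb{R}^d$ and $y_t$, respectively. The hidden state $\mathbf{h}_t \in \mathbb{R}^h$ and the output $\mathbf{o}_t \in \mathbb{R}^q$ are computed as
$$\begin{aligned}\mathbf{h}_t &= \mathbf{W}_\textrm{hx} \mathbf{x}_t + \mathbf{W}_\textrm{hh} \mathbf{h}_{t-1},\\\mathbf{o}_t &= \mathbf{W}_\textrm{qh} \mathbf{h}_{t},\end{aligned}$$
where $\mathbf{W}_\textrm{hx} \in \mathbb{R}^{h \times d}$, $\mathbf{W}_\textrm{hh} \in \mathbb{R}^{h \times h}$, and $\mathbf{W}_\textrm{qh} \in \mathbb{R}^{q \times h}$ are the weight parameters. Denote by $l(\mathbf{o}_t, y_t)$ the loss at time step $t$. Our objective function, the loss over $T$ time steps from the beginning of the sequence, is thus
$$L = \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t).$$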
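The forward computation of this bias-free linear RNN can be sketched in a few lines of NumPy. The weight names $\mathbf{W}_\textrm{hx}$, $\mathbf{W}_\textrm{hh}$, and an output-layer weight (called `W_qh` here) follow the notation of this section. The squared-error loss, the vector-valued targets, the zero initial state, and all sizes are assumptions chosen only so the example runs.

```python
import numpy as np

def rnn_forward(X, Y, W_hx, W_hh, W_qh):
    """Linear RNN forward pass.

    X: (T, d) inputs, Y: (T, q) targets.
    Returns hidden states H of shape (T+1, h) with H[0] = h_0 = 0,
    outputs O of shape (T, q), and the scalar objective L.
    """
    T = X.shape[0]
    H = np.zeros((T + 1, W_hh.shape[0]))
    O = np.zeros((T, W_qh.shape[0]))
    for t in range(1, T + 1):
        H[t] = W_hx @ X[t - 1] + W_hh @ H[t - 1]   # identity activation, no bias
        O[t - 1] = W_qh @ H[t]
    # Assumed per-step loss: l(o_t, y_t) = 0.5 * ||o_t - y_t||^2, averaged over T
    L = 0.5 * np.sum((O - Y) ** 2) / T
    return H, O, L

rng = np.random.default_rng(0)
T, d, h, q = 5, 3, 4, 2                     # illustrative sizes
X, Y = rng.standard_normal((T, d)), rng.standard_normal((T, q))
W_hx = 0.5 * rng.standard_normal((h, d))
W_hh = 0.5 * rng.standard_normal((h, h))
W_qh = 0.5 * rng.standard_normal((q, h))
H, O, L = rnn_forward(X, Y, W_hx, W_hh, W_qh)
print(H.shape, O.shape, L)
```

Keeping every hidden state in `H` is deliberate: the backward pass needs these stored values, which is where the memory cost of long sequences comes from.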
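In order to visualize the dependencies among model variables and parameters during computation of the RNN, we can draw a computational graph for the model, as shown in the figure below. For example, the computation of the hidden state of time step 3, $\mathbf{h}_3$, depends on the model parameters $\mathbf{W}_\textrm{hx}$ and $\mathbf{W}_\textrm{hh}$, the hidden state of the previous time step $\mathbf{h}_2$, and the input of the current time step $\mathbf{x}_3$.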
Computational graph showing dependencies for an RNN model with three time steps. Boxes represent variables (not shaded) or parameters (shaded) and circles represent operators.
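As just mentioned, the model parameters in the computational graph are $\mathbf{W}_\textrm{hx}$, $\mathbf{W}_\textrm{hh}$, and $\mathbf{W}_\textrm{qh}$. Generally, training this model requires gradient computation with respect to these parameters: $\partial L/\partial \mathbf{W}_\textrm{hx}$, $\partial L/\partial \mathbf{W}_\textrm{hh}$, and $\partial L/\partial \mathbf{W}_\textrm{qh}$. According to the dependencies in the graph, we can traverse in the opposite direction of the arrows to calculate and store the gradients in turn. To flexibly express the multiplication of matrices, vectors, and scalars of different shapes in the chain rule, we continue to use the $\textrm{prod}$ operator as described in :numref:sec_backprop.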
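First of all, differentiating the objective function with respect to the model output at any time step $t$ is fairly straightforward:
$$\frac{\partial L}{\partial \mathbf{o}_t} = \frac{\partial l (\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \in \mathbb{R}^q.$$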
:eqlabel:eq_bptt_partial_L_ot
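Now we can calculate the gradient of the objective with respect to the parameter $\mathbf{W}_\textrm{qh}$ in the output layer: $\partial L/\partial \mathbf{W}_\textrm{qh} \in \mathbb{R}^{q \times h}$. Based on the computational graph, the objective $L$ depends on $\mathbf{W}_\textrm{qh}$ via $\mathbf{o}_1, \ldots, \mathbf{o}_T$. Using the chain rule yields
$$\frac{\partial L}{\partial \mathbf{W}_\textrm{qh}} = \sum_{t=1}^T \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_\textrm{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top,$$
where $\partial L/\partial \mathbf{o}_t$ is given by :eqref:eq_bptt_partial_L_ot.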
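Next, as shown in the computational graph, at the final time step $T$, the objective function $L$ depends on the hidden state $\mathbf{h}_T$ only via $\mathbf{o}_T$. Therefore, we can easily find the gradient $\partial L/\partial \mathbf{h}_T \in \mathbb{R}^h$ using the chain rule:
$$\frac{\partial L}{\partial \mathbf{h}_T} = \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{o}_T}, \frac{\partial \mathbf{o}_T}{\partial \mathbf{h}_T} \right) = \mathbf{W}_\textrm{qh}^\top \frac{\partial L}{\partial \mathbf{o}_T}.$$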
:eqlabel:eq_bptt_partial_L_hT_final_step
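It gets trickier for any time step $t < T$, where the objective function $L$ depends on $\mathbf{h}_t$ via $\mathbf{h}_{t+1}$ and $\mathbf{o}_t$. According to the chain rule, the gradient of the hidden state $\partial L/\partial \mathbf{h}_t \in \mathbb{R}^h$ at any time step $t < T$ can be recurrently computed as:
$$\frac{\partial L}{\partial \mathbf{h}_t} = \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right) + \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} \right) = \mathbf{W}_\textrm{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_\textrm{qh}^\top \frac{\partial L}{\partial \mathbf{o}_t}.$$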
:eqlabel:eq_bptt_partial_L_ht_recur
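For analysis, expanding the recurrent computation for any time step $1 \leq t \leq T$ gives
$$\frac{\partial L}{\partial \mathbf{h}_t}= \sum_{i=t}^T {\left(\mathbf{W}_\textrm{hh}^\top\right)}^{T-i} \mathbf{W}_\textrm{qh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}}.$$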
:eqlabel:eq_bptt_partial_L_ht
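As a sanity check, the expanded expression in :eqref:eq_bptt_partial_L_ht can be compared numerically with the backward recurrence given by :eqref:eq_bptt_partial_L_hT_final_step and :eqref:eq_bptt_partial_L_ht_recur. In the sketch below, the output gradients are random placeholders rather than gradients of an actual loss, the sizes are arbitrary, and `W_qh` denotes the output-layer weight matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
T, h, q = 6, 4, 3                          # illustrative sizes
W_hh = 0.5 * rng.standard_normal((h, h))
W_qh = rng.standard_normal((q, h))
dL_dO = rng.standard_normal((T + 1, q))    # row t stands in for dL/do_t; row 0 is unused

# Backward recurrence: start at the final step T, then go t = T-1, ..., 1
dL_dH = np.zeros((T + 1, h))
dL_dH[T] = W_qh.T @ dL_dO[T]
for t in range(T - 1, 0, -1):
    dL_dH[t] = W_hh.T @ dL_dH[t + 1] + W_qh.T @ dL_dO[t]

def expanded(t):
    """Sum over i = t..T of (W_hh^T)^(T-i) W_qh^T dL/do_{T+t-i}."""
    total = np.zeros(h)
    for i in range(t, T + 1):
        total += np.linalg.matrix_power(W_hh.T, T - i) @ W_qh.T @ dL_dO[T + t - i]
    return total

max_gap = max(np.max(np.abs(expanded(t) - dL_dH[t])) for t in range(1, T + 1))
print(max_gap)
```

The printed gap should be at the level of floating-point rounding. The recurrence reuses the gradient of the next hidden state, whereas the expanded form recomputes matrix powers for every term, which is why implementations use the recurrence.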
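We can see from :eqref:eq_bptt_partial_L_ht that this simple linear example already exhibits some key problems of long sequence models: it involves potentially very large powers of $\mathbf{W}_\textrm{hh}^\top$. In these powers, eigenvalues with magnitude smaller than 1 vanish and eigenvalues with magnitude larger than 1 diverge. This is numerically unstable, which manifests itself in the form of vanishing and exploding gradients. One way to address this is to truncate the time steps at a computationally convenient size, as discussed in :numref:subsec_bptt_analysis. In practice, this truncation can also be effected by detaching the gradient after a given number of time steps. Later on, we will see how more sophisticated sequence models such as long short-term memory can alleviate this further.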
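The role of the eigenvalues can be illustrated by rescaling a random matrix so that its largest eigenvalue magnitude (its spectral radius) sits just below or just above 1 and then measuring the size of its powers. This is a sketch under made-up settings: the matrix, its size, the two scales, and the helper name `power_norms` are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
rho = np.max(np.abs(np.linalg.eigvals(A)))   # spectral radius of the random matrix

def power_norms(scale, steps=(1, 10, 50, 100)):
    """Spectral norm of (W^T)^k for W rescaled to have spectral radius `scale`."""
    W = (scale / rho) * A
    return [np.linalg.norm(np.linalg.matrix_power(W.T, k), 2) for k in steps]

shrinking = power_norms(0.9)   # every eigenvalue magnitude is below 1
growing = power_norms(1.1)     # the largest eigenvalue magnitude is above 1
print(shrinking)
print(growing)
```

Over 100 steps the first list should decay towards zero while the second grows by several orders of magnitude. A gradient multiplied by such powers behaves the same way, which is the vanishing and exploding behavior described above.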
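Finally, the computational graph shows that the objective function $L$ depends on the model parameters $\mathbf{W}_\textrm{hx}$ and $\mathbf{W}_\textrm{hh}$ in the hidden layer via the hidden states $\mathbf{h}_1, \ldots, \mathbf{h}_T$. To compute the gradients with respect to such parameters, $\partial L / \partial \mathbf{W}_\textrm{hx} \in \mathbb{R}^{h \times d}$ and $\partial L / \partial \mathbf{W}_\textrm{hh} \in \mathbb{R}^{h \times h}$, we apply the chain rule, giving
$$
\begin{aligned} \frac{\partial L}{\partial \mathbf{W}_\textrm{hx}} &= \sum_{t=1}^T \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_\textrm{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top,\\
\frac{\partial L}{\partial \mathbf{W}_\textrm{hh}} &= \sum_{t=1}^T \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_\textrm{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top, \end{aligned}
$$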
where $\partial L/\partial \mathbf{h}_t$, which is recurrently computed by :eqref:eq_bptt_partial_L_hT_final_step and :eqref:eq_bptt_partial_L_ht_recur, is the key quantity that affects the numerical stability.
Since backpropagation through time is the application of backpropagation in RNNs, as we have explained in :numref:sec_backprop, training RNNs alternates forward propagation with backpropagation through time. Moreover, backpropagation through time computes and stores the above gradients in turn. Specifically, stored intermediate values are reused to avoid duplicate calculations, such as storing $\partial L/\partial \mathbf{h}_t$ to be used in the computation of both $\partial L / \partial \mathbf{W}_\textrm{hx}$ and $\partial L / \partial \mathbf{W}_\textrm{hh}$.
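Putting the pieces together, the following NumPy sketch runs a forward pass, then walks backward through time while reusing the stored hidden states and the running hidden-state gradient, and finally checks one gradient entry against a finite difference. It is an illustration under assumptions, not a reference implementation: the squared-error loss, vector targets, zero initial state, sizes, and function names are all chosen for the example, and `W_qh` denotes the output-layer weight matrix.

```python
import numpy as np

def forward(X, Y, W_hx, W_hh, W_qh):
    T = X.shape[0]
    H = np.zeros((T + 1, W_hh.shape[0]))          # H[0] is h_0 = 0
    for t in range(1, T + 1):
        H[t] = W_hx @ X[t - 1] + W_hh @ H[t - 1]
    O = H[1:] @ W_qh.T                            # row t-1 holds o_t
    L = 0.5 * np.sum((O - Y) ** 2) / T            # assumed squared-error loss
    return H, O, L

def bptt(X, Y, W_hx, W_hh, W_qh):
    T = X.shape[0]
    H, O, _ = forward(X, Y, W_hx, W_hh, W_qh)     # stored values reused below
    dO = (O - Y) / T                              # dL/do_t for the assumed loss
    dW_qh = dO.T @ H[1:]                          # sum over t of dL/do_t h_t^T
    dW_hx, dW_hh = np.zeros_like(W_hx), np.zeros_like(W_hh)
    dh = np.zeros(W_hh.shape[0])                  # W_hh^T dL/dh_{t+1}; zero beyond step T
    for t in range(T, 0, -1):
        dh = dh + W_qh.T @ dO[t - 1]              # now equals dL/dh_t
        dW_hx += np.outer(dh, X[t - 1])           # dL/dh_t x_t^T
        dW_hh += np.outer(dh, H[t - 1])           # dL/dh_t h_{t-1}^T
        dh = W_hh.T @ dh                          # part passed on to step t-1
    return dW_hx, dW_hh, dW_qh

rng = np.random.default_rng(0)
T, d, h, q = 5, 3, 4, 2
X, Y = rng.standard_normal((T, d)), rng.standard_normal((T, q))
W_hx, W_hh, W_qh = (0.5 * rng.standard_normal(s) for s in [(h, d), (h, h), (q, h)])
dW_hx, dW_hh, dW_qh = bptt(X, Y, W_hx, W_hh, W_qh)

# Central finite difference on a single entry of W_hh as a sanity check
eps = 1e-6
E = np.zeros_like(W_hh)
E[1, 2] = eps
numeric = (forward(X, Y, W_hx, W_hh + E, W_qh)[2]
           - forward(X, Y, W_hx, W_hh - E, W_qh)[2]) / (2 * eps)
print(dW_hh[1, 2], numeric)
```

The single vector `dh` is computed once per time step and used for both hidden-layer weight gradients, which is the reuse of intermediate values that the paragraph above describes.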
Summary
Backpropagation through time is merely an application of backpropagation to sequence models with a hidden state. Truncation, such as regular or randomized, is needed for computational convenience and numerical stability. High powers of matrices can lead to divergent or vanishing eigenvalues. This manifests itself in the form of exploding or vanishing gradients. For efficient computation, intermediate values are cached during backpropagation through time.
Exercises
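Assume that we have a symmetric matrix $\mathbf{M} \in \mathbb{R}^{n \times n}$ with eigenvalues $\lambda_i$ whose corresponding eigenvectors are $\mathbf{v}_i$ ($i = 1, \ldots, n$). Without loss of generality, assume that they are ordered in the order $|\lambda_i| \geq |\lambda_{i+1}|$.
Show that $\mathbf{M}^k$ has eigenvalues $\lambda_i^k$.
Prove that for a random vector $\mathbf{x} \in \mathbb{R}^n$, with high probability $\mathbf{M}^k \mathbf{x}$ will be very much aligned with the eigenvector $\mathbf{v}_1$ of $\mathbf{M}$.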
What does the above result mean for gradients in RNNs?
Besides gradient clipping, can you think of any other methods to cope with gradient explosion in recurrent neural networks?