Recurrent Neural Network Implementation from Scratch
We are now ready to implement an RNN from scratch. In particular, we will train this RNN to function as a character-level language model (see :numref:sec_rnn) and train it on a corpus consisting of the entire text of H. G. Wells' The Time Machine, following the data processing steps outlined in :numref:sec_text-sequence. We start by loading the dataset.
using Pkg; Pkg.activate("../../d2lai")
using d2lai
using Flux
using Downloads
using StatsBase
using Plots
using CUDA, cuDNN
RNN Model
We begin by defining a struct to implement the RNN model (:numref:subsec_rnn_w_hidden_states). Note that the number of hidden units num_hiddens is a tunable hyperparameter.
struct RNNScratch{Wx, Wh, Bh, A} <: AbstractModel
Whx::Wx
Whh::Wh
b_h::Bh
args::A
end
Flux.@layer RNNScratch trainable = (Whx, Whh, b_h)
function RNNScratch(num_inputs::Int, num_hiddens::Int; sigma = 0.01)
Whx = randn(num_hiddens, num_inputs).*sigma
Whh = randn(num_hiddens, num_hiddens).*sigma
b_h = zeros(num_hiddens)
RNNScratch(Whx, Whh, b_h, (num_inputs = num_inputs, num_hiddens = num_hiddens, sigma = sigma))
end
The RNNScratch method below defines how to compute the output and hidden state at any time step, given the current input and the state of the model at the previous time step. Note that the RNN model loops through the second dimension of inputs (the time-step dimension), updating the hidden state one time step at a time. The model here uses a $\tanh$ activation function (:numref:subsec_tanh).
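Concretely, with the column-major layout used in the code (inputs and hidden states stored as columns), the update applied at each time step is

$$\mathbf{h}_t = \tanh\left(\mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1} + \mathbf{b}_h\right),$$

which is exactly the expression inside the map over time steps below.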
function (rnn::RNNScratch)(x::AbstractArray, state = nothing)
batchsize = size(x, 3)
device = isa(x, CuArray) ? gpu : cpu
state = if isnothing(state)
zeros(rnn.args.num_hiddens, size(x, 3))
else
state
end |> device
outputs = map(eachslice(x, dims = 2)) do x_
state = tanh.(rnn.Whx*x_ + rnn.Whh*state .+ rnn.b_h)
return state
end
outputs_cat = stack(outputs)
return permutedims(outputs_cat, [1, 3, 2]), state # num_hiddens x num_steps x batchsize, num_hiddens x batchsize
end
We can feed a minibatch of input sequences into an RNN model as follows.
batch_size, num_inputs, num_hiddens, num_steps = 2, 16, 32, 100
rnn = RNNScratch(num_inputs, num_hiddens)
X = ones((num_inputs, num_steps, batch_size))
outputs, state = rnn(X)
Let's check whether the RNN model produces results of the correct shapes to ensure that the dimensionality of the hidden state remains unchanged.
@assert size(outputs, 2) == num_steps
@assert size(outputs[:, 1, :]) == (num_hiddens, batch_size)
@assert size(outputs) == (num_hiddens, num_steps, batch_size)
RNN-Based Language Model
The following RNNLMScratch struct defines an RNN-based language model, where we pass in our RNN via the rnn argument of the constructor method. When training language models, the inputs and outputs come from the same vocabulary, so they have the same dimension, which is equal to the vocabulary size. Note that we use perplexity to evaluate the model. As discussed in :numref:subsec_perplexity, this ensures that sequences of different lengths are comparable.
abstract type AbstractRNNClassifier <: AbstractClassifier end
struct RNNLMScratch{R, W, B, A} <: AbstractRNNClassifier
rnn::R
Wq::W
bq::B
args::A
end
Flux.@layer RNNLMScratch trainable = (rnn, Wq, bq)
function RNNLMScratch(rnn, vocab_size)
Wq = randn(vocab_size, rnn.args.num_hiddens)*rnn.args.sigma
bq = zeros(vocab_size)
RNNLMScratch(rnn, Wq, bq, (vocab_size=vocab_size,))
end
function d2lai.loss(m::AbstractRNNClassifier, y_pred, y)
Flux.logitcrossentropy(y_pred, Flux.onehotbatch(y, 1:m.args.vocab_size))
end
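# Note (illustrative, not part of d2lai): the loss above is the mean cross-entropy
# per token in nats, so perplexity is simply its exponential. For instance, a loss
# of about 2.0 corresponds to exp(2.0) ≈ 7.4, i.e., the model is roughly as
# uncertain as a uniform choice among ~7 characters at each step.
perplexity(mean_ce) = exp(mean_ce)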
function d2lai.training_step(m::AbstractRNNClassifier, batch)
y_pred = d2lai.forward(m, batch[1])
loss_ = d2lai.loss(m, y_pred, batch[end])
return loss_
end
function d2lai.validation_step(m::AbstractRNNClassifier, batch)
y_pred = d2lai.forward(m, batch[1])
loss_ = d2lai.loss(m, y_pred, batch[end])
return loss_ , nothing
end
Transforming RNN Outputs
The language model uses a fully connected output layer to transform RNN outputs into token predictions at each time step.
function output_layer(m::RNNLMScratch, x)
outs = map(eachslice(x, dims =2)) do x_
m.Wq*x_ .+ m.bq
end
outs = stack(outs)
return permutedims(outs, [1, 3, 2])
end
function (rnnlm::RNNLMScratch)(x, state = nothing)
output, _ = rnnlm.rnn(x, state)
output_layer(rnnlm, output)
end
Let's check whether the forward computation produces outputs with the correct shape.
model = RNNLMScratch(rnn, num_inputs)
output = model(ones(num_inputs, num_steps, batch_size))
@assert size(output) == (num_inputs, num_steps, batch_size)
Gradient Clipping
While you are already used to thinking of neural networks as "deep" in the sense that many layers separate the input and output even within a single time step, the length of the sequence introduces a new notion of depth. In addition to passing through the network in the input-to-output direction, inputs at the first time step must pass through a long chain of adjacent time steps in order to influence the output of the model at the final time step, and gradients must flow back through that same chain during backpropagation through time. As discussed in :numref:sec_numerical_stability, this can result in numerical instability, causing the gradients either to explode or vanish, depending on the properties of the weight matrices.
Dealing with vanishing and exploding gradients is a fundamental problem when designing RNNs and has inspired some of the biggest advances in modern neural network architectures. In the next chapter, we will talk about specialized architectures that were designed in hopes of mitigating the vanishing gradient problem. However, even modern RNNs often suffer from exploding gradients. One inelegant but ubiquitous solution is to simply clip the gradients, forcing the resulting "clipped" gradients to take smaller values.
Generally speaking, when optimizing some objective by gradient descent, we iteratively update the parameter of interest, say a vector $\mathbf{x}$, by pushing it in the direction of the negative gradient $\mathbf{g}$. With learning rate $\eta > 0$, each update takes the form $\mathbf{x} \gets \mathbf{x} - \eta \mathbf{g}$. Let's further assume that the objective $f$ is sufficiently smooth; formally, that it is Lipschitz continuous with constant $L$, meaning that for any $\mathbf{x}$ and $\mathbf{y}$,

$$|f(\mathbf{x}) - f(\mathbf{y})| \leq L \|\mathbf{x} - \mathbf{y}\|.$$

As you can see, when we update the parameter vector by subtracting $\eta \mathbf{g}$, the change in the value of the objective is bounded:

$$|f(\mathbf{x}) - f(\mathbf{x} - \eta \mathbf{g})| \leq L \eta \|\mathbf{g}\|.$$

In other words, the objective cannot change by more than $L \eta \|\mathbf{g}\|$. On the downside, this limits the speed at which we can reduce the objective; on the bright side, it limits how much we can go wrong in any one gradient step.

When we say that gradients explode, we mean that $\|\mathbf{g}\|$ becomes excessively large. In that case, a single gradient step can undo the progress made over many training iterations, and training often diverges or becomes unstable owing to massive spikes in the loss.

One way to limit the size of $L \eta \|\mathbf{g}\|$ is to shrink the learning rate $\eta$ to tiny values, but this slows down progress at every step just to guard against rare exploding-gradient events. A popular alternative is the gradient clipping heuristic, which projects the gradient $\mathbf{g}$ onto a ball of a given radius $\theta$:

$$\mathbf{g} \gets \min\left(1, \frac{\theta}{\|\mathbf{g}\|}\right) \mathbf{g}.$$

This ensures that the gradient norm never exceeds $\theta$ and that the clipped gradient remains aligned with the original direction of $\mathbf{g}$.
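As a quick numerical illustration of this rule (a minimal standalone sketch with a hypothetical helper clip_by_norm, not the method used by the trainer below), clipping a gradient of norm 5 to radius 1 rescales it to unit norm while preserving its direction:

using LinearAlgebra

# Clip a single gradient vector g to the ball of radius θ: g ← min(1, θ/‖g‖) g
function clip_by_norm(g::AbstractVector, θ::Real)
    gnorm = norm(g)                      # Euclidean norm ‖g‖
    return gnorm > θ ? g .* (θ / gnorm) : g
end

g = [3.0, 4.0]                           # ‖g‖ = 5
clip_by_norm(g, 1.0)                     # [0.6, 0.8]: norm 1, same direction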
Below we define a method to clip gradients, which is invoked by the fit_epoch method of the d2lai Trainer (see :numref:sec_linear_scratch). Note that when computing the gradient norm, we concatenate all model parameters, treating them as a single giant parameter vector.
function d2lai.clip_gradients!(gs, gradient_clip_val, model)
    # Sum of squared entries over all gradient leaves, treating all parameters
    # as one giant vector (leaves with no gradient contribute zero)
    sums = fmap(gs, walk = Flux.Functors.IterateWalk()) do g
        isnothing(g) ? 0.0 : sum(abs2, g)
    end
    norm = sqrt(sum(collect(sums)))
    # Rescale only when the global norm exceeds the threshold, so the clipped
    # gradients have norm gradient_clip_val and keep their direction
    norm <= gradient_clip_val && return gs
    fmap(gs) do g
        isnothing(g) ? g : g .* (gradient_clip_val / norm)
    end
end
Training
Using The Time Machine dataset (data), we train a character-level language model (model) based on the RNN (rnn) implemented from scratch. Note that we first calculate the gradients, then clip them, and finally update the model parameters using the clipped gradients.
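The Trainer handles this sequence internally; as a rough sketch of a single update step (assuming Flux's explicit-gradient API with opt_state = Flux.setup(Descent(1.0), model) and a batch from the data loader, not the actual d2lai implementation), it would look like the following:

grads = Flux.gradient(m -> d2lai.training_step(m, batch), model)[1]   # compute gradients
grads = d2lai.clip_gradients!(grads, 1.0, model)                      # clip to norm at most 1
Flux.update!(opt_state, model, grads)                                 # apply the update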
data = d2lai.TimeMachine(1024, 32) |> f64
num_hiddens = 32
rnn = RNNScratch(length(data.vocab), num_hiddens)
model = RNNLMScratch(rnn, length(data.vocab)) |> f64
opt = OptimiserChain(Descent(1.))
trainer = Trainer(model, data, opt; max_epochs = 100, gpu = true, gradient_clip_val = 1., board_yscale = :identity)
m, _ = d2lai.fit(trainer);
[ Info: Train Loss: 2.7970488, Val Loss: 2.7806466
[ Info: Train Loss: 2.5299044, Val Loss: 2.5298347
... (intermediate epochs omitted) ...
[ Info: Train Loss: 1.6780833, Val Loss: 2.0292208
[ Info: Train Loss: 1.6620464, Val Loss: 2.0444481
Decoding
Once a language model has been learned, we can use it not only to predict the next token but to continue predicting each subsequent one, treating the previously predicted token as though it were the next in the input. Sometimes we will just want to generate text as though we were starting at the beginning of a document. However, it is often useful to condition the language model on a user-supplied prefix. For example, if we were developing an autocomplete feature for a search engine or to assist users in writing emails, we would want to feed in what they had written so far (the prefix), and then generate a likely continuation.
The following predict method generates a continuation, one character at a time, after ingesting a user-provided prefix. When looping through the characters in prefix, we keep passing the hidden state to the next time step but do not generate any output. This is called the warm-up period. After ingesting the prefix, we are now ready to begin emitting the subsequent characters, each of which will be fed back into the model as the input at the next time step.
function prediction(prefix, model, vocab, num_preds)
    outputs = [vocab.token_to_idx[string(prefix[1])]]
    state = zeros(model.rnn.args.num_hiddens)
    # Warm-up period: run the prefix through the RNN to build up the hidden
    # state, but do not generate any output
    for i in 2:length(prefix)
        x = outputs[end]
        x = reshape(Flux.onehotbatch(x, 1:length(vocab)), :, 1, 1)
        _, state = model.rnn(x, state)
        push!(outputs, vocab.token_to_idx[string(prefix[i])])
    end
    # Prediction: feed the latest token back in and emit the most likely
    # next character at each step
    for i in 1:num_preds
        x = outputs[end]
        x = reshape(Flux.onehotbatch(x, 1:length(vocab)), :, 1, 1)
        out, state = model.rnn(x, state)
        out = output_layer(model, out)
        idx = argmax(softmax(out), dims = 1)[1][1]
        push!(outputs, idx)
    end
    out_chars = map(outputs) do o
        vocab.idx_to_token[o]
    end
    join(out_chars)
end
In the following, we specify the prefix and have it generate 20 additional characters.
prefix = "it has"
prediction(prefix, m, data.vocab, 20)
"it has the time traveller "
While implementing the above RNN model from scratch is instructive, it is not convenient. In the next section, we will see how to leverage deep learning frameworks to whip up RNNs using standard architectures, and to reap performance gains by relying on highly optimized library functions.
Summary
We can train RNN-based language models to generate text following the user-provided text prefix. A simple RNN language model consists of input encoding, RNN modeling, and output generation. During training, gradient clipping can mitigate the problem of exploding gradients but does not address the problem of vanishing gradients. In the experiment, we implemented a simple RNN language model and trained it with gradient clipping on sequences of text, tokenized at the character level. By conditioning on a prefix, we can use a language model to generate likely continuations, which proves useful in many applications, e.g., autocomplete features.
Exercises
Does the implemented language model predict the next token based on all the past tokens up to the very first token in The Time Machine?
Which hyperparameter controls the length of history used for prediction?
Show that one-hot encoding is equivalent to picking a different embedding for each object.
Adjust the hyperparameters (e.g., number of epochs, number of hidden units, number of time steps in a minibatch, and learning rate) to improve the perplexity. How low can you go while sticking with this simple architecture?
Replace one-hot encoding with learnable embeddings. Does this lead to better performance?
Conduct an experiment to determine how well this language model trained on The Time Machine works on other books by H. G. Wells, e.g., The War of the Worlds.
Conduct another experiment to evaluate the perplexity of this model on books written by other authors.
Modify the prediction method so as to use sampling rather than picking the most likely next character. What happens?
Bias the model towards more likely outputs, e.g., by sampling from $q(x_t \mid x_{t-1}, \ldots, x_1) \propto P(x_t \mid x_{t-1}, \ldots, x_1)^\alpha$ for $\alpha > 1$.
Run the code in this section without clipping the gradient. What happens?
Replace the activation function used in this section with ReLU and repeat the experiments in this section. Do we still need gradient clipping? Why?