
Concise Implementation of Recurrent Neural Networks

Like most of our from-scratch implementations, :numref:sec_rnn-scratch was designed to provide insight into how each component works. But when you are using RNNs every day or writing production code, you will want to rely more on libraries that cut down on both implementation time (by supplying library code for common models and functions) and computation time (by optimizing the heck out of these library implementations). This section shows you how to implement the same language model more efficiently using the high-level API provided by Flux, our deep learning framework. We begin by loading the required packages; as before, we will train on The Time Machine dataset.

julia
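using Pkg; Pkg.activate("../../d2lai")  # activate the project environment that provides d2lai
using d2lai
using Flux
using Downloads
using StatsBase
using Plots
using CUDA, cuDNN  # GPU support, used when training with gpu = true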
  Activating project at `/workspace/d2l-julia/d2lai`
    [ Info: Precompiling d2lai [749b8817-cd67-416c-8a57-830ea19f3cc4] (cache misses: include_dependency fsize change (2))

Defining the Model

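We first construct a recurrent layer using the RNN provided by Flux's high-level API. The input size 28 is the size of the character vocabulary used below, so each input vector has one entry per vocabulary token, and return_state = true makes the layer return its final hidden state together with its outputs.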

julia
num_hiddens = 32 
input_size = 28
rnn = Flux.RNN(input_size => num_hiddens, return_state = true)
RNN(28 => 32, tanh)  # 1_952 parameters
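To make concrete what the Flux.RNN layer above computes, here is a plain-Julia sketch of the standard tanh recurrence h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h), which, as far as we know, is what RNN(28 => 32, tanh) implements. The names rnn_forward, W_xh, W_hh and b_h and the small random initialization are our own assumptions; Flux's initializers and internals may differ. Counting the entries of the three arrays reproduces the 1_952 parameters reported above.

```julia
using Random

# Hypothetical plain-Julia sketch of a single-layer RNN (not Flux's source code).
# Arrays follow the Flux layout: features × time steps × batch.
function rnn_forward(W_xh, W_hh, b_h, X, h0)
    num_hiddens = size(W_hh, 1)
    T, B = size(X, 2), size(X, 3)
    H = zeros(eltype(W_hh), num_hiddens, T, B)  # hidden state at every time step
    h = h0                                      # num_hiddens × B
    for t in 1:T
        # h_t = tanh(W_xh * x_t + W_hh * h_{t-1} + b_h)
        h = tanh.(W_xh * X[:, t, :] .+ W_hh * h .+ b_h)
        H[:, t, :] = h
    end
    return H, h  # outputs for all steps and the final state
end

Random.seed!(0)
input_size, num_hiddens = 28, 32
W_xh = 0.01f0 .* randn(Float32, num_hiddens, input_size)
W_hh = 0.01f0 .* randn(Float32, num_hiddens, num_hiddens)
b_h = zeros(Float32, num_hiddens)

# 28*32 + 32*32 + 32 = 1952, the count Flux reports for RNN(28 => 32)
nparams = length(W_xh) + length(W_hh) + length(b_h)

X = randn(Float32, input_size, 5, 2)  # 5 time steps, batch of 2
H, h_T = rnn_forward(W_xh, W_hh, b_h, X, zeros(Float32, num_hiddens, 2))
println(nparams, " ", size(H), " ", size(h_T))
```

The last column of outputs equals the final state, which is the pair the layer returns when return_state = true.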

Subtyping the AbstractRNNClassifier type from :numref:sec_rnn-scratch, the following RNNLM struct defines a complete RNN-based language model. Because Flux.RNN only returns hidden states, we need to add a separate fully connected output layer (Dense) that maps them to vocabulary logits.

julia
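# A complete language model: a recurrent layer followed by an output layer
struct RNNLM{N,R,A} <: d2lai.AbstractRNNClassifier
    net::N   # output layer: hidden states => vocabulary logits
    rnn::R   # recurrent layer (here a Flux.RNN)
    args::A  # hyperparameters (num_hiddens, vocab_size)
end

# Only net and rnn hold trainable parameters; args is plain metadata
Flux.@layer RNNLM trainable = (net, rnn)
function RNNLM(rnn::Flux.RNN, num_hiddens::Int, vocab_size::Int)
    net = Dense(num_hiddens => vocab_size)
    return RNNLM(net, rnn, (num_hiddens = num_hiddens, vocab_size = vocab_size))
end

# Map hidden states to vocabulary logits
function d2lai.output_layer(m::RNNLM, out)
    m.net(out)
end

function (m::RNNLM)(x)
    # With return_state = true the RNN returns (outputs, final_state); keep the outputs
    out = m.rnn(x)[1]
    return d2lai.output_layer(m, out)
end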
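Flux.RNN returns one hidden state per time step, arranged as num_hiddens × num_steps × batch_size, so the Dense output layer must turn every one of these columns into vocabulary logits. The following plain-Julia sketch, with hypothetical names output_logits, W_hq and b_q, does this by folding time and batch into a single dimension, applying one affine map, and restoring the shape; we assume Flux's Dense treats trailing dimensions in a similar way, but this is not its source code. The parameter count, 32 × 28 + 28 = 924, matches the Dense(32 => 28) layer that appears in the warning later in this section.

```julia
using Random

# Hypothetical output layer: logits_t = W_hq * h_t .+ b_q for every time step and example
function output_logits(W_hq, b_q, H)
    num_hiddens, T, B = size(H)
    H2 = reshape(H, num_hiddens, T * B)      # one column per (time step, example) pair
    Y2 = W_hq * H2 .+ b_q                    # vocab_size × (T * B)
    return reshape(Y2, size(W_hq, 1), T, B)  # vocab_size × T × B
end

Random.seed!(1)
num_hiddens, vocab_size = 32, 28
W_hq = randn(Float32, vocab_size, num_hiddens)
b_q = zeros(Float32, vocab_size)
H = randn(Float32, num_hiddens, 5, 2)        # stand-in for RNN outputs: 5 steps, batch of 2

Y = output_logits(W_hq, b_q, H)
# Folding and unfolding agrees with applying the layer to a single column
@assert Y[:, 3, 2] ≈ W_hq * H[:, 3, 2] .+ b_q
nparams = length(W_hq) + length(b_q)         # 32 * 28 + 28 = 924
println(size(Y), " ", nparams)
```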

Training and Predicting

Before training the model, let's make a prediction with a model initialized with random weights. Given that we have not trained the network, it will generate nonsensical predictions.

julia
num_hiddens = 32
data = d2lai.TimeMachine(1024, 32) |> f64  # presumably batch_size = 1024, num_steps = 32
rnn = Flux.RNN(length(data.vocab) => num_hiddens, return_state = true)
model = RNNLM(rnn, num_hiddens, length(data.vocab)) |> f64
prefix = "it has"
# Generate 20 tokens after the prefix, starting from an all-zero hidden state
d2lai.prediction(prefix, model, data.vocab, 20, state = zeros(num_hiddens))
"it hasnfjqoqvsoqvszqk<unk>zbgs"
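In the original d2l implementation, prediction first runs the prefix through the network to warm up the hidden state and then repeatedly appends the highest-scoring next token; we assume d2lai.prediction follows the same scheme, but have not checked its source. The sketch below implements that greedy procedure in plain Julia for a toy one-layer RNN with random weights. The 27-character vocabulary and the names toy_vocab, one_hot_vec and greedy_predict are hypothetical stand-ins for data.vocab and the trained model. As with the untrained model above, random weights yield gibberish.

```julia
using Random

# One-hot column vector for token index i in a vocabulary of size n
one_hot_vec(i, n) = (v = zeros(Float32, n); v[i] = 1f0; v)

# Greedy decoding: warm up the hidden state on the prefix, then emit argmax tokens
function greedy_predict(prefix, num_preds, vocab, W_xh, W_hh, b_h, W_hq, b_q)
    char_to_idx = Dict(c => i for (i, c) in enumerate(vocab))
    V = length(vocab)
    h = zeros(Float32, size(W_hh, 1))  # initial hidden state of zeros
    outputs = [char_to_idx[c] for c in prefix]
    # Warm-up: consume every prefix character except the last; no outputs needed
    for i in outputs[1:end-1]
        h = tanh.(W_xh * one_hot_vec(i, V) .+ W_hh * h .+ b_h)
    end
    # Prediction: feed the latest token and append the highest-scoring next token
    for _ in 1:num_preds
        h = tanh.(W_xh * one_hot_vec(outputs[end], V) .+ W_hh * h .+ b_h)
        logits = W_hq * h .+ b_q
        push!(outputs, argmax(logits))
    end
    return String([vocab[i] for i in outputs])
end

# Hypothetical toy vocabulary: lowercase letters plus space
toy_vocab = vcat(collect('a':'z'), [' '])

Random.seed!(2)
V, nh = length(toy_vocab), 32
W_xh, W_hh, b_h = randn(Float32, nh, V), 0.1f0 .* randn(Float32, nh, nh), zeros(Float32, nh)
W_hq, b_q = randn(Float32, V, nh), zeros(Float32, V)

pred = greedy_predict("it has", 20, toy_vocab, W_xh, W_hh, b_h, W_hq, b_q)
println(pred)  # "it has" followed by 20 meaningless characters
```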
julia
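opt = Descent(1.)  # plain gradient descent with learning rate 1
# Train for 100 epochs on the GPU, clipping gradients at 1
trainer = Trainer(model, data, opt; max_epochs = 100, gpu = true, board_yscale = :identity, gradient_clip_val = 1.)
m, _ = d2lai.fit(trainer)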
    [ Info: Train Loss: 2.8086967, Val Loss: 2.8176572
    [ Info: Train Loss: 2.5578668, Val Loss: 2.5525599
    [ Info: Train Loss: 2.4050918, Val Loss: 2.4460049
    [ Info: Train Loss: 2.316194, Val Loss: 2.368135
    [ Info: Train Loss: 2.2381124, Val Loss: 2.332166
    [ Info: Train Loss: 2.1993399, Val Loss: 2.3082745
    [ Info: Train Loss: 2.1650236, Val Loss: 2.2931821
    [ Info: Train Loss: 2.1162374, Val Loss: 2.2517207
    [ Info: Train Loss: 2.1146615, Val Loss: 2.2625127
    [ Info: Train Loss: 2.0888648, Val Loss: 2.2390275
    [ Info: Train Loss: 2.0716052, Val Loss: 2.230474
    [ Info: Train Loss: 2.0240247, Val Loss: 2.2290487
    [ Info: Train Loss: 2.0295258, Val Loss: 2.2020535
    [ Info: Train Loss: 1.9899576, Val Loss: 2.1624563
    [ Info: Train Loss: 2.004797, Val Loss: 2.1773639
    [ Info: Train Loss: 1.9721513, Val Loss: 2.1735885
    [ Info: Train Loss: 1.9659573, Val Loss: 2.1581883
    [ Info: Train Loss: 1.9484462, Val Loss: 2.1393445
    [ Info: Train Loss: 1.930977, Val Loss: 2.1594217
    [ Info: Train Loss: 1.9451939, Val Loss: 2.1503987
    [ Info: Train Loss: 1.9346877, Val Loss: 2.114789
    [ Info: Train Loss: 1.8983876, Val Loss: 2.1232784
    [ Info: Train Loss: 1.8946356, Val Loss: 2.1232986
    [ Info: Train Loss: 1.8931116, Val Loss: 2.0921307
    [ Info: Train Loss: 1.8924788, Val Loss: 2.113529
    [ Info: Train Loss: 1.8815277, Val Loss: 2.0992854
    [ Info: Train Loss: 1.8664699, Val Loss: 2.0718553
    [ Info: Train Loss: 1.8620858, Val Loss: 2.0682604
    [ Info: Train Loss: 1.8730401, Val Loss: 2.0869172
    [ Info: Train Loss: 1.8470776, Val Loss: 2.06626
    [ Info: Train Loss: 1.8517028, Val Loss: 2.082669
    [ Info: Train Loss: 1.839059, Val Loss: 2.0598977
    [ Info: Train Loss: 1.8291104, Val Loss: 2.0723095
    [ Info: Train Loss: 1.8454365, Val Loss: 2.0628722
    [ Info: Train Loss: 1.8226379, Val Loss: 2.0716498
    [ Info: Train Loss: 1.818457, Val Loss: 2.0575447
    [ Info: Train Loss: 1.8064052, Val Loss: 2.0661533
    [ Info: Train Loss: 1.7888063, Val Loss: 2.0482213
    [ Info: Train Loss: 1.8107212, Val Loss: 2.0582194
    [ Info: Train Loss: 1.8009559, Val Loss: 2.0606222
    [ Info: Train Loss: 1.8003632, Val Loss: 2.062232
    [ Info: Train Loss: 1.793633, Val Loss: 2.0593336
    [ Info: Train Loss: 1.7663118, Val Loss: 2.0475314
    [ Info: Train Loss: 1.776362, Val Loss: 2.0510414
    [ Info: Train Loss: 1.7802719, Val Loss: 2.0557463
    [ Info: Train Loss: 1.7893523, Val Loss: 2.061776
    [ Info: Train Loss: 1.7992761, Val Loss: 2.0471537
    [ Info: Train Loss: 1.7724043, Val Loss: 2.0530665
    [ Info: Train Loss: 1.7678832, Val Loss: 2.043026
    [ Info: Train Loss: 1.7517827, Val Loss: 2.0457454
    [ Info: Train Loss: 1.7520797, Val Loss: 2.0534086
    [ Info: Train Loss: 1.735371, Val Loss: 2.0413203
    [ Info: Train Loss: 1.741371, Val Loss: 2.0482664
    [ Info: Train Loss: 1.7486898, Val Loss: 2.0389254
    [ Info: Train Loss: 1.7638668, Val Loss: 2.0467787
    [ Info: Train Loss: 1.7438501, Val Loss: 2.0521975
    [ Info: Train Loss: 1.7357632, Val Loss: 2.043563
    [ Info: Train Loss: 1.7333704, Val Loss: 2.0436978
    [ Info: Train Loss: 1.762108, Val Loss: 2.0362566
    [ Info: Train Loss: 1.7176626, Val Loss: 2.0604687
    [ Info: Train Loss: 1.717186, Val Loss: 2.0518222
    [ Info: Train Loss: 1.7292625, Val Loss: 2.0569487
    [ Info: Train Loss: 1.7006977, Val Loss: 2.0418108
    [ Info: Train Loss: 1.7316309, Val Loss: 2.0493767
    [ Info: Train Loss: 1.69463, Val Loss: 2.054217
    [ Info: Train Loss: 1.720155, Val Loss: 2.0379403
    [ Info: Train Loss: 1.7145555, Val Loss: 2.0562189
    [ Info: Train Loss: 1.7215083, Val Loss: 2.0379372
    [ Info: Train Loss: 1.7383448, Val Loss: 2.0564325
    [ Info: Train Loss: 1.7068391, Val Loss: 2.0735323
    [ Info: Train Loss: 1.7159855, Val Loss: 2.06642
    [ Info: Train Loss: 1.709202, Val Loss: 2.0742066
    [ Info: Train Loss: 1.7023453, Val Loss: 2.0456624
    [ Info: Train Loss: 1.6985327, Val Loss: 2.0623722
    [ Info: Train Loss: 1.6987822, Val Loss: 2.0451224
    [ Info: Train Loss: 1.7177141, Val Loss: 2.044373
    [ Info: Train Loss: 1.7015309, Val Loss: 2.0406048
    [ Info: Train Loss: 1.6836351, Val Loss: 2.0618813
    [ Info: Train Loss: 1.6850606, Val Loss: 2.0612392
    [ Info: Train Loss: 1.7006621, Val Loss: 2.047501
    [ Info: Train Loss: 1.6999978, Val Loss: 2.052405
    [ Info: Train Loss: 1.7009304, Val Loss: 2.0556684
    [ Info: Train Loss: 1.6887848, Val Loss: 2.0540125
    [ Info: Train Loss: 1.6855419, Val Loss: 2.0616734
    [ Info: Train Loss: 1.6916366, Val Loss: 2.0426404
    [ Info: Train Loss: 1.6893617, Val Loss: 2.0559156
    [ Info: Train Loss: 1.6977744, Val Loss: 2.0574336
    [ Info: Train Loss: 1.6968977, Val Loss: 2.0456219
    [ Info: Train Loss: 1.6823503, Val Loss: 2.062682
    [ Info: Train Loss: 1.6991038, Val Loss: 2.051137
    [ Info: Train Loss: 1.6861258, Val Loss: 2.032179
    [ Info: Train Loss: 1.6727321, Val Loss: 2.0421462
    [ Info: Train Loss: 1.6626164, Val Loss: 2.0548327
    [ Info: Train Loss: 1.6945255, Val Loss: 2.041975
    [ Info: Train Loss: 1.6579713, Val Loss: 2.0641537
    [ Info: Train Loss: 1.6617924, Val Loss: 2.058827
    [ Info: Train Loss: 1.6782551, Val Loss: 2.0603712
    [ Info: Train Loss: 1.6792405, Val Loss: 2.0408056
    [ Info: Train Loss: 1.6815602, Val Loss: 2.0537863
    [ Info: Train Loss: 1.6723149, Val Loss: 2.039872
(RNNLM{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, RNN{true, RNNCell{typeof(tanh), Matrix{Float32}, Matrix{Float32}, Vector{Float32}}}, @NamedTuple{num_hiddens::Int64, vocab_size::Int64}}(Dense(32 => 28), RNN(28 => 32, tanh), (num_hiddens = 32, vocab_size = 28)), (val_loss = Float32[2.0083926, 1.9724069, 2.137811, 1.8953362, 2.039872], val_acc = nothing))
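The trainer above pairs plain gradient descent, Descent(1.) with learning rate 1, with gradient_clip_val = 1. Assuming that, as in the original d2l trainer, d2lai rescales all gradients jointly so that their combined L2 norm does not exceed the threshold, one training update looks like the following plain-Julia sketch. The helpers clip_gradients! and sgd_step! and the toy arrays are hypothetical and are not part of the d2lai or Flux API.

```julia
# Hypothetical helper: rescale all gradients so that their joint L2 norm is at most max_norm
function clip_gradients!(gs, max_norm)
    total_norm = sqrt(sum(sum(abs2, g) for g in gs))
    if total_norm > max_norm
        for g in gs
            g .*= max_norm / total_norm
        end
    end
    return gs
end

# Hypothetical helper: one plain gradient-descent step, w ← w - η * g (what Descent(η) does)
function sgd_step!(ws, gs, η)
    for (w, g) in zip(ws, gs)
        w .-= η .* g
    end
    return ws
end

ws = [ones(Float32, 2, 2), ones(Float32, 2)]  # toy "parameters"
gs = [fill(3f0, 2, 2), fill(4f0, 2)]          # joint norm sqrt(4*9 + 2*16) = sqrt(68) ≈ 8.25
clip_gradients!(gs, 1f0)
clipped_norm = sqrt(sum(sum(abs2, g) for g in gs))
sgd_step!(ws, gs, 1f0)                        # learning rate 1, as in Descent(1.)
println(clipped_norm)                         # ≈ 1
```

Clipping by the joint norm rescales every gradient by the same factor, so the update direction is preserved and only its length is capped; gradients whose joint norm is already below the threshold are left untouched.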

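Compared with :numref:sec_rnn-scratch, this model achieves comparable perplexity but runs faster, thanks to the optimized library implementations. As before, we can generate predicted tokens following the specified prefix string. Since the trained model m stores Float32 parameters (see its type printed above) while the initial state zeros(num_hiddens) is a Float64 vector, Float64 values reach the Dense layer; Flux converts them and prints the warning below, which concerns speed, not the generated text.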

julia
prefix = "it has"
d2lai.prediction(prefix, m, data.vocab, 20, state = zeros(num_hiddens))
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(32 => 28)     # 924 parameters
│   summary(x) = "32×1 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/3711C/src/layers/stateless.jl:60

"it has dimension of the po"
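The comparison with :numref:sec_rnn-scratch is phrased in terms of perplexity, while the training log reports losses. Perplexity is the exponential of the mean per-token cross-entropy, so, assuming the logged losses are mean cross-entropies measured in nats, the final training and validation losses of about 1.67 and 2.04 correspond to perplexities of roughly 5.3 and 7.7. The plain-Julia sketch below (mean_cross_entropy and perplexity are our own helper names) performs this conversion and also checks a useful baseline: a model that predicts uniformly over the 28-token vocabulary has perplexity 28.

```julia
# Mean cross-entropy (in nats) of integer targets under softmax(logits); logits is vocab × tokens
function mean_cross_entropy(logits, targets)
    total = 0.0
    for (j, y) in enumerate(targets)
        z = logits[:, j]
        m = maximum(z)
        total += m + log(sum(exp.(z .- m))) - z[y]  # logsumexp(z) - z[y], computed stably
    end
    return total / length(targets)
end

# Perplexity is the exponential of the mean per-token cross-entropy
perplexity(loss) = exp(loss)

# Baseline: uniform predictions over a 28-token vocabulary give perplexity 28
uniform_ppl = perplexity(mean_cross_entropy(zeros(28, 10), rand(1:28, 10)))

# Last losses from the training log above, assuming they are mean cross-entropies in nats
train_ppl, val_ppl = perplexity(1.6723149), perplexity(2.039872)
println(round(uniform_ppl; digits = 2), " ", round(train_ppl; digits = 2), " ", round(val_ppl; digits = 2))
```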

Summary

High-level APIs in deep learning frameworks, such as Flux's RNN layer, provide implementations of standard RNNs. These libraries help you avoid wasting time reimplementing standard models. Moreover, framework implementations are often highly optimized, leading to significant computational performance gains over implementations from scratch.

Exercises

  1. Can you make the RNN model overfit using the high-level APIs?

  2. Implement the autoregressive model of :numref:sec_sequence using an RNN.
