Deep Recurrent Neural Networks

In a deep RNN we stack several recurrent layers: the sequence of hidden states produced by each layer serves as the input sequence of the layer above it.
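Concretely, with $L$ layers and $\mathbf{H}_t^{(0)} = \mathbf{X}_t$ denoting the input at time step $t$, layer $l$ updates its hidden state as

$$\mathbf{H}_t^{(l)} = \phi_l\left(\mathbf{H}_t^{(l-1)} \mathbf{W}_{xh}^{(l)} + \mathbf{H}_{t-1}^{(l)} \mathbf{W}_{hh}^{(l)} + \mathbf{b}_h^{(l)}\right),$$

where $\phi_l$ is the layer's update function (here the full GRU update plays this role), and only the top states $\mathbf{H}_t^{(L)}$ are passed to the output layer. Below we build such a stack twice: first from the scratch GRU of the previous section, then from Flux's built-in GRU.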

julia
using Pkg; Pkg.activate("../../d2lai")
using d2lai
using Flux 
using Downloads
using StatsBase
using Plots
using CUDA, cuDNN
import d2lai: GRUScratch
  Activating project at `/workspace/d2l-julia/d2lai`
julia

struct StackedRNNScratch{N, A} <: AbstractModel
    net::N   # collection of per-layer recurrent cells
    args::A  # constructor arguments kept for bookkeeping
end

Flux.@layer StackedRNNScratch trainable = (net,)

function StackedRNNScratch(num_inputs, num_hiddens, num_layers; sigma = 0.01)
    layers = map(1:num_layers) do i
        if i == 1
            # The first layer reads the raw inputs ...
            return GRUScratch(num_inputs, num_hiddens; sigma)
        else
            # ... every deeper layer reads the hidden states of the layer below.
            return GRUScratch(num_hiddens, num_hiddens; sigma)
        end
    end
    StackedRNNScratch(layers, construct_nt_args(; num_inputs, num_hiddens, num_layers, sigma))
end
StackedRNNScratch
julia
function (rnn::StackedRNNScratch)(x, state = nothing)
    # Without an incoming state, let every layer initialize its own.
    states = isnothing(state) ? [nothing for _ in 1:length(rnn.net)] : state
    new_states = []
    for (layer, st) in zip(rnn.net, states)
        # Each layer's output sequence is the next layer's input.
        x, new_st = layer(x, st)
        # Concatenate instead of push! so the loop stays mutation-free for Zygote.
        new_states = [new_states; [new_st]]
    end
    x, new_states
end
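As a quick sanity check we can push a dummy batch through the stack. A minimal sketch: the features × batch × steps input layout and the `nothing` default state are assumptions carried over from the GRUScratch section, so verify them against your own setup.

julia
toy = StackedRNNScratch(8, 16, 3)  # 8 input features, 16 hidden units, 3 layers
x = rand(8, 4, 5)                  # assumed layout: features × batch × steps
y, sts = toy(x)                    # every layer starts from a `nothing` state
length(toy.net), length(sts)       # (3, 3): one GRUScratch and one new state per layer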
julia
data = d2lai.TimeMachine(1024, 32) |> f64
num_hiddens = 32
num_layers = 2
stacked_rnns = StackedRNNScratch(length(data.vocab), num_hiddens, num_layers)

model = RNNLMScratch(stacked_rnns, length(data.vocab)) |> f64

opt = Descent(4.)
trainer = Trainer(model, data, opt; max_epochs = 50, gpu = true, board_yscale = :identity, gradient_clip_val = 1.)
m, _ = d2lai.fit(trainer);
    [ Info: Train Loss: 2.9601334558165466, Val Loss: 2.9517217666830966
    [ Info: Train Loss: 3.083218267768948, Val Loss: 3.0733841376993163
    [ Info: Train Loss: 2.8951596061949756, Val Loss: 2.8912486456702347
    [ Info: Train Loss: 3.201351587714034, Val Loss: 3.175039186751496
    [ Info: Train Loss: 2.718827714649281, Val Loss: 2.720499446484515
    [ Info: Train Loss: 3.8262812256137186, Val Loss: 3.854138404494324
    [ Info: Train Loss: 2.555370779938557, Val Loss: 2.5738663453725583
    [ Info: Train Loss: 2.731417246063586, Val Loss: 2.733773122408332
    [ Info: Train Loss: 2.71089042383076, Val Loss: 2.738603978195346
    [ Info: Train Loss: 2.44044572594663, Val Loss: 2.462528609008958
    [ Info: Train Loss: 2.3516735432179914, Val Loss: 2.4067911823277437
    [ Info: Train Loss: 2.3687606868636766, Val Loss: 2.47717199329585
    [ Info: Train Loss: 2.3588690160925685, Val Loss: 2.4418052758133952
    [ Info: Train Loss: 2.237565755244201, Val Loss: 2.3642804555051993
    [ Info: Train Loss: 2.235600284005974, Val Loss: 2.3784332127024483
    [ Info: Train Loss: 2.142417782203248, Val Loss: 2.3473351758904135
    [ Info: Train Loss: 2.093294573725641, Val Loss: 2.3372173455936847
    [ Info: Train Loss: 2.0842137555451457, Val Loss: 2.279965692638798
    [ Info: Train Loss: 2.0549474983694065, Val Loss: 2.2777371572671283
    [ Info: Train Loss: 2.0442055475662606, Val Loss: 2.3023100889111268
    [ Info: Train Loss: 1.9507728314106645, Val Loss: 2.1966278901863476
    [ Info: Train Loss: 1.9032525890733238, Val Loss: 2.1696957569150443
    [ Info: Train Loss: 1.875100311762148, Val Loss: 2.163587692999274
    [ Info: Train Loss: 1.9025953221921292, Val Loss: 2.178660075334394
    [ Info: Train Loss: 1.8476598491168619, Val Loss: 2.106155456126516
    [ Info: Train Loss: 1.7921675184256025, Val Loss: 2.135495719844115
    [ Info: Train Loss: 1.8056898625599576, Val Loss: 2.1558411177375083
    [ Info: Train Loss: 1.7371011036515749, Val Loss: 2.1562540780961754
    [ Info: Train Loss: 1.7239404400320422, Val Loss: 2.125930291733888
    [ Info: Train Loss: 1.659753328312827, Val Loss: 2.148482260531729
    [ Info: Train Loss: 1.651105685856042, Val Loss: 2.1208701279929336
    [ Info: Train Loss: 1.6196358531986192, Val Loss: 2.1608039697391925
    [ Info: Train Loss: 1.6178051380319713, Val Loss: 2.1488381415850943
    [ Info: Train Loss: 1.6018263977563578, Val Loss: 2.158062634013532
    [ Info: Train Loss: 1.5643781226789049, Val Loss: 2.1945494274982638
    [ Info: Train Loss: 1.5663794729658038, Val Loss: 2.1729942268961193
    [ Info: Train Loss: 1.5039514623306511, Val Loss: 2.163342152744304
    [ Info: Train Loss: 1.5164812564868821, Val Loss: 2.139802582799324
    [ Info: Train Loss: 1.5037808116243674, Val Loss: 2.1911284154473316
    [ Info: Train Loss: 1.4939950577778571, Val Loss: 2.1633152335120442
    [ Info: Train Loss: 1.4915562326575769, Val Loss: 2.21846338539298
    [ Info: Train Loss: 1.4469348669530806, Val Loss: 2.095370027284382
    [ Info: Train Loss: 1.4274980833798925, Val Loss: 2.1127067080767774
    [ Info: Train Loss: 1.4072615721373687, Val Loss: 2.182247864599464
    [ Info: Train Loss: 1.4257045378597692, Val Loss: 2.218274727056704
    [ Info: Train Loss: 1.4431864775869827, Val Loss: 2.187672692860164
    [ Info: Train Loss: 1.4082344801302875, Val Loss: 2.1680913003725406
    [ Info: Train Loss: 1.3714887976310248, Val Loss: 2.2072342000236693
    [ Info: Train Loss: 1.3759805603462967, Val Loss: 2.2025357399595213
    [ Info: Train Loss: 1.3484777880112848, Val Loss: 2.1770381725964487
julia
prefix = "it has"
d2lai.prediction(prefix, m, data.vocab, 20)
"it has explere and have no"
julia
struct StackedRNN{N, A} <: AbstractModel
    net::N
    args::A
end
Flux.@layer StackedRNN trainable = (net,)

function StackedRNN(num_inputs, num_hiddens, num_layers)
    layers = map(1:num_layers) do i
        if i == 1
            # return_state = true makes Flux's GRU return (output, state).
            return GRU(num_inputs => num_hiddens; return_state = true)
        else
            return GRU(num_hiddens => num_hiddens; return_state = true)
        end
    end
    StackedRNN(layers, construct_nt_args(; num_inputs, num_hiddens, num_layers))
end

function (rnn::StackedRNN)(x, state = nothing)
    # Fall back to each layer's own initial state when none is supplied.
    states = isnothing(state) ? [Flux.initialstates(n) for n in rnn.net] : state
    new_states = []
    for (layer, st) in zip(rnn.net, states)
        x, new_st = layer(x, st)
        new_states = [new_states; [new_st]]  # mutation-free accumulation for AD
    end
    x, new_states
end
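The concise stack admits the same shape check. A sketch assuming Flux's features × sequence length × batch layout for recurrent layers:

julia
demo = StackedRNN(8, 16, 2)
xd = rand(Float32, 8, 10, 4)  # features × steps × batch
yd, sts = demo(xd)
size(yd)                      # expected (16, 10, 4): hidden width of the top GRU
length(sts)                   # 2, one state per layer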
julia
srnn = StackedRNN(length(data.vocab), num_hiddens, 2) |> f64
model = RNNModelConcise(srnn, num_hiddens, length(data.vocab)) |> f64

opt = Flux.Optimiser(Descent(2.))
trainer = Trainer(model, data, opt; max_epochs = 100, gpu = true, board_yscale = :identity, gradient_clip_val = 1.)
m, _ = d2lai.fit(trainer);
    [ Info: Train Loss: 3.0729582, Val Loss: 3.0826774
    [ Info: Train Loss: 2.9488046, Val Loss: 2.9387672
    [ Info: Train Loss: 2.8525553, Val Loss: 2.8506098
    [ Info: Train Loss: 2.6630747, Val Loss: 2.6527565
    [ Info: Train Loss: 2.5859373, Val Loss: 2.5650566
    [ Info: Train Loss: 2.4740438, Val Loss: 2.514804
    [ Info: Train Loss: 2.371892, Val Loss: 2.394975
    [ Info: Train Loss: 2.354868, Val Loss: 2.4037442
    [ Info: Train Loss: 2.254007, Val Loss: 2.3529797
    [ Info: Train Loss: 2.2311308, Val Loss: 2.3084953
    [ Info: Train Loss: 2.147948, Val Loss: 2.234294
    [ Info: Train Loss: 2.0833817, Val Loss: 2.2096674
    [ Info: Train Loss: 2.0624454, Val Loss: 2.2006884
    [ Info: Train Loss: 2.0190148, Val Loss: 2.141577
    [ Info: Train Loss: 1.9770054, Val Loss: 2.1524167
    [ Info: Train Loss: 1.9207356, Val Loss: 2.123338
    [ Info: Train Loss: 1.8788599, Val Loss: 2.0964837
    [ Info: Train Loss: 1.8548547, Val Loss: 2.0287824
    [ Info: Train Loss: 1.7822961, Val Loss: 2.0009334
    [ Info: Train Loss: 1.7661173, Val Loss: 2.0087724
    [ Info: Train Loss: 1.7366441, Val Loss: 2.009952
    [ Info: Train Loss: 1.6795175, Val Loss: 1.984611
    [ Info: Train Loss: 1.6710591, Val Loss: 1.9936811
    [ Info: Train Loss: 1.6375403, Val Loss: 1.9641424
    [ Info: Train Loss: 1.623228, Val Loss: 2.0116076
    [ Info: Train Loss: 1.6005237, Val Loss: 1.9941324
    [ Info: Train Loss: 1.5355363, Val Loss: 1.9425739
    [ Info: Train Loss: 1.5244751, Val Loss: 1.9511333
    [ Info: Train Loss: 1.4960856, Val Loss: 1.9564431
    [ Info: Train Loss: 1.4842346, Val Loss: 1.9604983
    [ Info: Train Loss: 1.4594462, Val Loss: 1.9397213
    [ Info: Train Loss: 1.4230841, Val Loss: 1.9380797
    [ Info: Train Loss: 1.4064045, Val Loss: 1.9675212
    [ Info: Train Loss: 1.3824363, Val Loss: 1.9790069
    [ Info: Train Loss: 1.3710134, Val Loss: 1.9749012
    [ Info: Train Loss: 1.3684355, Val Loss: 2.009786
    [ Info: Train Loss: 1.3642248, Val Loss: 1.981412
    [ Info: Train Loss: 1.3158046, Val Loss: 2.0184858
    [ Info: Train Loss: 1.336281, Val Loss: 2.035481
    [ Info: Train Loss: 1.3019139, Val Loss: 2.0148616
    [ Info: Train Loss: 1.2883389, Val Loss: 2.0252836
    [ Info: Train Loss: 1.2609012, Val Loss: 2.0469217
    [ Info: Train Loss: 1.2739772, Val Loss: 2.053692
    [ Info: Train Loss: 1.2369365, Val Loss: 2.0801237
    [ Info: Train Loss: 1.2299112, Val Loss: 2.0566382
    [ Info: Train Loss: 1.2436033, Val Loss: 2.0560641
    [ Info: Train Loss: 1.1817329, Val Loss: 2.0845814
    [ Info: Train Loss: 1.1996968, Val Loss: 2.1041818
    [ Info: Train Loss: 1.1991175, Val Loss: 2.0899494
    [ Info: Train Loss: 1.1808488, Val Loss: 2.111517
    [ Info: Train Loss: 1.1769359, Val Loss: 2.1382341
    [ Info: Train Loss: 1.1749823, Val Loss: 2.1326122
    [ Info: Train Loss: 1.1592147, Val Loss: 2.121256
    [ Info: Train Loss: 1.1497483, Val Loss: 2.1279364
    [ Info: Train Loss: 1.1569502, Val Loss: 2.1380363
    [ Info: Train Loss: 1.1212716, Val Loss: 2.1456661
    [ Info: Train Loss: 1.1321361, Val Loss: 2.1739123
    [ Info: Train Loss: 1.1194346, Val Loss: 2.1772566
    [ Info: Train Loss: 1.1209718, Val Loss: 2.2002127
    [ Info: Train Loss: 1.1135013, Val Loss: 2.1820202
    [ Info: Train Loss: 1.0901783, Val Loss: 2.18443
    [ Info: Train Loss: 1.0887899, Val Loss: 2.2044551
    [ Info: Train Loss: 1.0749239, Val Loss: 2.1647375
    [ Info: Train Loss: 1.0597292, Val Loss: 2.2246313
    [ Info: Train Loss: 1.097184, Val Loss: 2.244104
    [ Info: Train Loss: 1.0836931, Val Loss: 2.2037764
    [ Info: Train Loss: 1.0905948, Val Loss: 2.2060766
    [ Info: Train Loss: 1.0692631, Val Loss: 2.2625248
    [ Info: Train Loss: 1.0678359, Val Loss: 2.2287636
    [ Info: Train Loss: 1.0676397, Val Loss: 2.247726
    [ Info: Train Loss: 1.0601895, Val Loss: 2.2661674
    [ Info: Train Loss: 1.0503038, Val Loss: 2.2502162
    [ Info: Train Loss: 1.0421187, Val Loss: 2.2681298
    [ Info: Train Loss: 1.0481176, Val Loss: 2.2715967
    [ Info: Train Loss: 1.0424931, Val Loss: 2.2882733
    [ Info: Train Loss: 1.0234843, Val Loss: 2.2857475
    [ Info: Train Loss: 1.0312903, Val Loss: 2.3094137
    [ Info: Train Loss: 1.0422728, Val Loss: 2.3331096
    [ Info: Train Loss: 1.0208472, Val Loss: 2.3397734
    [ Info: Train Loss: 1.0243846, Val Loss: 2.2890894
    [ Info: Train Loss: 1.0057296, Val Loss: 2.3031518
    [ Info: Train Loss: 1.0093327, Val Loss: 2.3217697
    [ Info: Train Loss: 1.0281, Val Loss: 2.2965145
    [ Info: Train Loss: 1.0152854, Val Loss: 2.3509705
    [ Info: Train Loss: 1.0198553, Val Loss: 2.3523006
    [ Info: Train Loss: 0.99818, Val Loss: 2.3264742
    [ Info: Train Loss: 1.0130087, Val Loss: 2.3380609
    [ Info: Train Loss: 1.0043898, Val Loss: 2.3419886
    [ Info: Train Loss: 0.99594694, Val Loss: 2.348446
    [ Info: Train Loss: 0.9855084, Val Loss: 2.346001
    [ Info: Train Loss: 0.97566634, Val Loss: 2.333736
    [ Info: Train Loss: 0.9899276, Val Loss: 2.357852
    [ Info: Train Loss: 0.97932696, Val Loss: 2.3778205
    [ Info: Train Loss: 0.96937585, Val Loss: 2.351455
    [ Info: Train Loss: 0.9782952, Val Loss: 2.363977
    [ Info: Train Loss: 0.974543, Val Loss: 2.3918114
    [ Info: Train Loss: 0.9872866, Val Loss: 2.376842
    [ Info: Train Loss: 0.9691477, Val Loss: 2.3819165
    [ Info: Train Loss: 0.96983945, Val Loss: 2.388229
    [ Info: Train Loss: 0.95775, Val Loss: 2.4141488
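Note that the validation loss bottoms out just under 2.0 around epoch 30 and then climbs steadily while the training loss keeps falling toward 1.0: over 100 epochs on this small corpus, the two-layer GRU overfits.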
julia
d2lai.prediction(prefix, m, data.vocab, 20)
"it has of staying is all r"