Deep Recurrent Neural Networks
julia
using Pkg; Pkg.activate("../../d2lai")
using d2lai
using Flux
using Downloads
using StatsBase
using Plots
using CUDA, cuDNN
import d2lai: GRUScratch
Activating project at `/workspace/d2l-julia/d2lai`
[ Info: Precompiling d2lai [749b8817-cd67-416c-8a57-830ea19f3cc4] (cache misses: include_dependency fsize change (2))
julia
# A deep (multi-layer) RNN built from scratch: `net` holds the per-layer
# GRUScratch cells (bottom to top), `args` the constructor hyperparameters.
struct StackedRNNScratch{N, A} <: AbstractModel
net::N    # collection of recurrent layers, applied in order
args::A   # named tuple of (num_inputs, num_hiddens, num_layers, sigma)
end
# Only the layers in `net` carry trainable parameters; `args` is metadata.
Flux.@layer StackedRNNScratch trainable = (net,)
# Build a stack of `num_layers` GRUScratch cells. The bottom layer maps the
# input dimension to the hidden dimension; every higher layer maps
# hidden → hidden. `sigma` is the weight-initialization scale passed through
# to each cell.
function StackedRNNScratch(num_inputs, num_hiddens, num_layers; sigma = 0.01)
    layers = map(1:num_layers) do layer_idx
        in_dim = layer_idx == 1 ? num_inputs : num_hiddens
        GRUScratch(in_dim, num_hiddens; sigma)
    end
    StackedRNNScratch(layers, construct_nt_args(; num_inputs, num_hiddens, num_layers, sigma))
end
# Forward pass through the stack: the output sequence of each layer becomes
# the input of the next. `state` is a vector with one entry per layer;
# `nothing` lets every layer initialize its own hidden state.
# Returns the top layer's output and the vector of new per-layer states.
function (rnn::StackedRNNScratch)(x, state = nothing)
    states = isnothing(state) ? [nothing for _ in 1:length(rnn.net)] : state
    new_states = []
    # Renamed loop variable (was `rnn`, shadowing the model argument).
    for (layer, st) in zip(rnn.net, states)
        x, new_st = layer(x, st)
        # push! grows the vector in place instead of re-concatenating
        # ([new_states; [new_st]]) on every iteration.
        push!(new_states, new_st)
    end
    x, new_states
end
# Train the from-scratch stacked GRU language model on The Time Machine.
data = d2lai.TimeMachine(1024, 32) |> f64
num_hiddens = 32
num_layers = 2
# Use `num_layers` (was a hard-coded 2) so the depth is set in one place.
stacked_rnns = StackedRNNScratch(length(data.vocab), num_hiddens, num_layers)
model = RNNLMScratch(stacked_rnns, length(data.vocab)) |> f64
opt = Descent(4.)
trainer = Trainer(model, data, opt; max_epochs = 50, gpu = true, board_yscale = :identity, gradient_clip_val = 1.)
m, _ = d2lai.fit(trainer); [ Info: Train Loss: 2.9601334558165466, Val Loss: 2.9517217666830966
[ Info: Train Loss: 3.083218267768948, Val Loss: 3.0733841376993163
[ Info: Train Loss: 2.8951596061949756, Val Loss: 2.8912486456702347
[ Info: Train Loss: 3.201351587714034, Val Loss: 3.175039186751496
[ Info: Train Loss: 2.718827714649281, Val Loss: 2.720499446484515
[ Info: Train Loss: 3.8262812256137186, Val Loss: 3.854138404494324
[ Info: Train Loss: 2.555370779938557, Val Loss: 2.5738663453725583
[ Info: Train Loss: 2.731417246063586, Val Loss: 2.733773122408332
[ Info: Train Loss: 2.71089042383076, Val Loss: 2.738603978195346
[ Info: Train Loss: 2.44044572594663, Val Loss: 2.462528609008958
[ Info: Train Loss: 2.3516735432179914, Val Loss: 2.4067911823277437
[ Info: Train Loss: 2.3687606868636766, Val Loss: 2.47717199329585
[ Info: Train Loss: 2.3588690160925685, Val Loss: 2.4418052758133952
[ Info: Train Loss: 2.237565755244201, Val Loss: 2.3642804555051993
[ Info: Train Loss: 2.235600284005974, Val Loss: 2.3784332127024483
[ Info: Train Loss: 2.142417782203248, Val Loss: 2.3473351758904135
[ Info: Train Loss: 2.093294573725641, Val Loss: 2.3372173455936847
[ Info: Train Loss: 2.0842137555451457, Val Loss: 2.279965692638798
[ Info: Train Loss: 2.0549474983694065, Val Loss: 2.2777371572671283
[ Info: Train Loss: 2.0442055475662606, Val Loss: 2.3023100889111268
[ Info: Train Loss: 1.9507728314106645, Val Loss: 2.1966278901863476
[ Info: Train Loss: 1.9032525890733238, Val Loss: 2.1696957569150443
[ Info: Train Loss: 1.875100311762148, Val Loss: 2.163587692999274
[ Info: Train Loss: 1.9025953221921292, Val Loss: 2.178660075334394
[ Info: Train Loss: 1.8476598491168619, Val Loss: 2.106155456126516
[ Info: Train Loss: 1.7921675184256025, Val Loss: 2.135495719844115
[ Info: Train Loss: 1.8056898625599576, Val Loss: 2.1558411177375083
[ Info: Train Loss: 1.7371011036515749, Val Loss: 2.1562540780961754
[ Info: Train Loss: 1.7239404400320422, Val Loss: 2.125930291733888
[ Info: Train Loss: 1.659753328312827, Val Loss: 2.148482260531729
[ Info: Train Loss: 1.651105685856042, Val Loss: 2.1208701279929336
[ Info: Train Loss: 1.6196358531986192, Val Loss: 2.1608039697391925
[ Info: Train Loss: 1.6178051380319713, Val Loss: 2.1488381415850943
[ Info: Train Loss: 1.6018263977563578, Val Loss: 2.158062634013532
[ Info: Train Loss: 1.5643781226789049, Val Loss: 2.1945494274982638
[ Info: Train Loss: 1.5663794729658038, Val Loss: 2.1729942268961193
[ Info: Train Loss: 1.5039514623306511, Val Loss: 2.163342152744304
[ Info: Train Loss: 1.5164812564868821, Val Loss: 2.139802582799324
[ Info: Train Loss: 1.5037808116243674, Val Loss: 2.1911284154473316
[ Info: Train Loss: 1.4939950577778571, Val Loss: 2.1633152335120442
[ Info: Train Loss: 1.4915562326575769, Val Loss: 2.21846338539298
[ Info: Train Loss: 1.4469348669530806, Val Loss: 2.095370027284382
[ Info: Train Loss: 1.4274980833798925, Val Loss: 2.1127067080767774
[ Info: Train Loss: 1.4072615721373687, Val Loss: 2.182247864599464
[ Info: Train Loss: 1.4257045378597692, Val Loss: 2.218274727056704
[ Info: Train Loss: 1.4431864775869827, Val Loss: 2.187672692860164
[ Info: Train Loss: 1.4082344801302875, Val Loss: 2.1680913003725406
[ Info: Train Loss: 1.3714887976310248, Val Loss: 2.2072342000236693
[ Info: Train Loss: 1.3759805603462967, Val Loss: 2.2025357399595213
[ Info: Train Loss: 1.3484777880112848, Val Loss: 2.1770381725964487julia
# Greedily generate a 20-character continuation of `prefix` with the trained model.
prefix = "it has"
d2lai.prediction(prefix, m, data.vocab, 20)"it has explere and have no"julia
# Concise deep RNN: same shape as StackedRNNScratch but `net` holds
# Flux's built-in GRU layers instead of hand-written cells.
struct StackedRNN{N, A} <: AbstractModel
net::N    # per-layer Flux GRU layers, applied bottom to top
args::A   # named tuple of (num_inputs, num_hiddens, num_layers)
end
# Only `net` is trainable; `args` stores construction hyperparameters.
Flux.@layer StackedRNN trainable = (net,)
# Build `num_layers` Flux GRU layers with `return_state = true` so each call
# yields (output, new_state). The first layer maps inputs → hidden, the rest
# map hidden → hidden.
function StackedRNN(num_inputs, num_hiddens, num_layers)
    layers = map(1:num_layers) do layer_idx
        input_dim = layer_idx == 1 ? num_inputs : num_hiddens
        GRU(input_dim => num_hiddens; return_state = true)
    end
    StackedRNN(layers, construct_nt_args(; num_inputs, num_hiddens, num_layers))
end
# Forward pass through the concise stack, threading one hidden state per
# layer. With `state === nothing` each layer starts from Flux's initial state.
# Returns the top layer's output and the vector of new per-layer states.
function (rnn::StackedRNN)(x, state = nothing)
    states = isnothing(state) ? [Flux.initialstates(n) for n in rnn.net] : state
    new_states = []
    for (layer, st) in zip(rnn.net, states)
        x, new_st = layer(x, st)
        # push! appends in place instead of rebuilding the whole vector with
        # vcat ([new_states; [new_st]]) on every iteration.
        push!(new_states, new_st)
    end
    x, new_states
end
# Train the concise stacked GRU; reuse `num_layers` instead of a literal 2.
srnn = StackedRNN(length(data.vocab), num_hiddens, num_layers) |> f64
model = RNNModelConcise(srnn, num_hiddens, length(data.vocab)) |> f64
# Plain Descent, matching the scratch run above: wrapping a single rule in
# the deprecated `Flux.Optimiser` is redundant.
opt = Descent(2.)
trainer = Trainer(model, data, opt; max_epochs = 100, gpu = true, board_yscale = :identity, gradient_clip_val = 1.)
m, _ = d2lai.fit(trainer); [ Info: Train Loss: 3.0729582, Val Loss: 3.0826774
[ Info: Train Loss: 2.9488046, Val Loss: 2.9387672
[ Info: Train Loss: 2.8525553, Val Loss: 2.8506098
[ Info: Train Loss: 2.6630747, Val Loss: 2.6527565
[ Info: Train Loss: 2.5859373, Val Loss: 2.5650566
[ Info: Train Loss: 2.4740438, Val Loss: 2.514804
[ Info: Train Loss: 2.371892, Val Loss: 2.394975
[ Info: Train Loss: 2.354868, Val Loss: 2.4037442
[ Info: Train Loss: 2.254007, Val Loss: 2.3529797
[ Info: Train Loss: 2.2311308, Val Loss: 2.3084953
[ Info: Train Loss: 2.147948, Val Loss: 2.234294
[ Info: Train Loss: 2.0833817, Val Loss: 2.2096674
[ Info: Train Loss: 2.0624454, Val Loss: 2.2006884
[ Info: Train Loss: 2.0190148, Val Loss: 2.141577
[ Info: Train Loss: 1.9770054, Val Loss: 2.1524167
[ Info: Train Loss: 1.9207356, Val Loss: 2.123338
[ Info: Train Loss: 1.8788599, Val Loss: 2.0964837
[ Info: Train Loss: 1.8548547, Val Loss: 2.0287824
[ Info: Train Loss: 1.7822961, Val Loss: 2.0009334
[ Info: Train Loss: 1.7661173, Val Loss: 2.0087724
[ Info: Train Loss: 1.7366441, Val Loss: 2.009952
[ Info: Train Loss: 1.6795175, Val Loss: 1.984611
[ Info: Train Loss: 1.6710591, Val Loss: 1.9936811
[ Info: Train Loss: 1.6375403, Val Loss: 1.9641424
[ Info: Train Loss: 1.623228, Val Loss: 2.0116076
[ Info: Train Loss: 1.6005237, Val Loss: 1.9941324
[ Info: Train Loss: 1.5355363, Val Loss: 1.9425739
[ Info: Train Loss: 1.5244751, Val Loss: 1.9511333
[ Info: Train Loss: 1.4960856, Val Loss: 1.9564431
[ Info: Train Loss: 1.4842346, Val Loss: 1.9604983
[ Info: Train Loss: 1.4594462, Val Loss: 1.9397213
[ Info: Train Loss: 1.4230841, Val Loss: 1.9380797
[ Info: Train Loss: 1.4064045, Val Loss: 1.9675212
[ Info: Train Loss: 1.3824363, Val Loss: 1.9790069
[ Info: Train Loss: 1.3710134, Val Loss: 1.9749012
[ Info: Train Loss: 1.3684355, Val Loss: 2.009786
[ Info: Train Loss: 1.3642248, Val Loss: 1.981412
[ Info: Train Loss: 1.3158046, Val Loss: 2.0184858
[ Info: Train Loss: 1.336281, Val Loss: 2.035481
[ Info: Train Loss: 1.3019139, Val Loss: 2.0148616
[ Info: Train Loss: 1.2883389, Val Loss: 2.0252836
[ Info: Train Loss: 1.2609012, Val Loss: 2.0469217
[ Info: Train Loss: 1.2739772, Val Loss: 2.053692
[ Info: Train Loss: 1.2369365, Val Loss: 2.0801237
[ Info: Train Loss: 1.2299112, Val Loss: 2.0566382
[ Info: Train Loss: 1.2436033, Val Loss: 2.0560641
[ Info: Train Loss: 1.1817329, Val Loss: 2.0845814
[ Info: Train Loss: 1.1996968, Val Loss: 2.1041818
[ Info: Train Loss: 1.1991175, Val Loss: 2.0899494
[ Info: Train Loss: 1.1808488, Val Loss: 2.111517
[ Info: Train Loss: 1.1769359, Val Loss: 2.1382341
[ Info: Train Loss: 1.1749823, Val Loss: 2.1326122
[ Info: Train Loss: 1.1592147, Val Loss: 2.121256
[ Info: Train Loss: 1.1497483, Val Loss: 2.1279364
[ Info: Train Loss: 1.1569502, Val Loss: 2.1380363
[ Info: Train Loss: 1.1212716, Val Loss: 2.1456661
[ Info: Train Loss: 1.1321361, Val Loss: 2.1739123
[ Info: Train Loss: 1.1194346, Val Loss: 2.1772566
[ Info: Train Loss: 1.1209718, Val Loss: 2.2002127
[ Info: Train Loss: 1.1135013, Val Loss: 2.1820202
[ Info: Train Loss: 1.0901783, Val Loss: 2.18443
[ Info: Train Loss: 1.0887899, Val Loss: 2.2044551
[ Info: Train Loss: 1.0749239, Val Loss: 2.1647375
[ Info: Train Loss: 1.0597292, Val Loss: 2.2246313
[ Info: Train Loss: 1.097184, Val Loss: 2.244104
[ Info: Train Loss: 1.0836931, Val Loss: 2.2037764
[ Info: Train Loss: 1.0905948, Val Loss: 2.2060766
[ Info: Train Loss: 1.0692631, Val Loss: 2.2625248
[ Info: Train Loss: 1.0678359, Val Loss: 2.2287636
[ Info: Train Loss: 1.0676397, Val Loss: 2.247726
[ Info: Train Loss: 1.0601895, Val Loss: 2.2661674
[ Info: Train Loss: 1.0503038, Val Loss: 2.2502162
[ Info: Train Loss: 1.0421187, Val Loss: 2.2681298
[ Info: Train Loss: 1.0481176, Val Loss: 2.2715967
[ Info: Train Loss: 1.0424931, Val Loss: 2.2882733
[ Info: Train Loss: 1.0234843, Val Loss: 2.2857475
[ Info: Train Loss: 1.0312903, Val Loss: 2.3094137
[ Info: Train Loss: 1.0422728, Val Loss: 2.3331096
[ Info: Train Loss: 1.0208472, Val Loss: 2.3397734
[ Info: Train Loss: 1.0243846, Val Loss: 2.2890894
[ Info: Train Loss: 1.0057296, Val Loss: 2.3031518
[ Info: Train Loss: 1.0093327, Val Loss: 2.3217697
[ Info: Train Loss: 1.0281, Val Loss: 2.2965145
[ Info: Train Loss: 1.0152854, Val Loss: 2.3509705
[ Info: Train Loss: 1.0198553, Val Loss: 2.3523006
[ Info: Train Loss: 0.99818, Val Loss: 2.3264742
[ Info: Train Loss: 1.0130087, Val Loss: 2.3380609
[ Info: Train Loss: 1.0043898, Val Loss: 2.3419886
[ Info: Train Loss: 0.99594694, Val Loss: 2.348446
[ Info: Train Loss: 0.9855084, Val Loss: 2.346001
[ Info: Train Loss: 0.97566634, Val Loss: 2.333736
[ Info: Train Loss: 0.9899276, Val Loss: 2.357852
[ Info: Train Loss: 0.97932696, Val Loss: 2.3778205
[ Info: Train Loss: 0.96937585, Val Loss: 2.351455
[ Info: Train Loss: 0.9782952, Val Loss: 2.363977
[ Info: Train Loss: 0.974543, Val Loss: 2.3918114
[ Info: Train Loss: 0.9872866, Val Loss: 2.376842
[ Info: Train Loss: 0.9691477, Val Loss: 2.3819165
[ Info: Train Loss: 0.96983945, Val Loss: 2.388229
[ Info: Train Loss: 0.95775, Val Loss: 2.4141488julia
# Sample a 20-character continuation from the concise model for comparison.
d2lai.prediction(prefix, m, data.vocab, 20)"it has of staying is all r"julia