Bidirectional Recurrent Neural Networks
So far, our working example of a sequence learning task has been language modeling, where we aim to predict the next token given all previous tokens in a sequence. In this scenario, we wish only to condition upon the leftward context, and thus the unidirectional chaining of a standard RNN seems appropriate. However, there are many other sequence learning contexts where it is perfectly fine to condition the prediction at every time step on both the leftward and the rightward context. Consider, for example, part-of-speech detection. Why shouldn't we take the context in both directions into account when assessing the part of speech associated with a given word?
Another common task, often useful as a pretraining exercise prior to fine-tuning a model on an actual task of interest, is to mask out random tokens in a text document and then to train a sequence model to predict the values of the missing tokens. Note that depending on what comes after the blank, the likely value of the missing token changes dramatically:
I am ___.
I am ___ hungry.
I am ___ hungry, and I can eat half a pig.
In the first sentence, "happy" seems a likely candidate. The words "not" and "very" seem plausible in the second sentence, but "not" seems incompatible with the third sentence.
Fortunately, a simple technique transforms any unidirectional RNN into a bidirectional RNN [169]. We simply implement two unidirectional RNN layers chained together in opposite directions and acting on the same input (see the figure below). For the first RNN layer, the first input is $\mathbf{x}_1$ and the last input is $\mathbf{x}_T$; for the second RNN layer, the first input is $\mathbf{x}_T$ and the last input is $\mathbf{x}_1$. To produce the output of this bidirectional RNN layer, we simply concatenate the corresponding outputs of the two underlying unidirectional RNN layers.
Figure: Architecture of a bidirectional RNN.
Formally, for any time step $t$, we consider a minibatch input $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ (number of examples $n$, number of inputs in each example $d$) and let the hidden layer activation function be $\phi$. In the bidirectional architecture, the forward and backward hidden states for this time step are $\overrightarrow{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$ and $\overleftarrow{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$, respectively, where $h$ is the number of hidden units. Their updates are as follows:

$$
\begin{aligned}
\overrightarrow{\mathbf{H}}_t &= \phi(\mathbf{X}_t \mathbf{W}_{\textrm{xh}}^{(f)} + \overrightarrow{\mathbf{H}}_{t-1} \mathbf{W}_{\textrm{hh}}^{(f)} + \mathbf{b}_\textrm{h}^{(f)}),\\
\overleftarrow{\mathbf{H}}_t &= \phi(\mathbf{X}_t \mathbf{W}_{\textrm{xh}}^{(b)} + \overleftarrow{\mathbf{H}}_{t+1} \mathbf{W}_{\textrm{hh}}^{(b)} + \mathbf{b}_\textrm{h}^{(b)}),
\end{aligned}
$$

where the weights $\mathbf{W}_{\textrm{xh}}^{(f)} \in \mathbb{R}^{d \times h}$, $\mathbf{W}_{\textrm{hh}}^{(f)} \in \mathbb{R}^{h \times h}$, $\mathbf{W}_{\textrm{xh}}^{(b)} \in \mathbb{R}^{d \times h}$, and $\mathbf{W}_{\textrm{hh}}^{(b)} \in \mathbb{R}^{h \times h}$, together with the biases $\mathbf{b}_\textrm{h}^{(f)} \in \mathbb{R}^{1 \times h}$ and $\mathbf{b}_\textrm{h}^{(b)} \in \mathbb{R}^{1 \times h}$, are all model parameters.

Next, we concatenate the forward and backward hidden states $\overrightarrow{\mathbf{H}}_t$ and $\overleftarrow{\mathbf{H}}_t$ to obtain the hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times 2h}$ that is fed into the output layer. Last, the output layer computes the output $\mathbf{O}_t \in \mathbb{R}^{n \times q}$ (number of outputs $q$):

$$
\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{\textrm{hq}} + \mathbf{b}_\textrm{q}.
$$

Here, the weight matrix $\mathbf{W}_{\textrm{hq}} \in \mathbb{R}^{2h \times q}$ and the bias $\mathbf{b}_\textrm{q} \in \mathbb{R}^{1 \times q}$ are the model parameters of the output layer.
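To make these shapes concrete, the following is a minimal numerical sketch of the two recursions and the output layer in plain Julia, written with the row-vector convention of the formulas above. The toy sizes and variable names are illustrative only and are not part of d2lai.
# Minimal sketch of the bidirectional update equations (toy sizes, illustrative names).
n, d, h, q, T = 2, 3, 4, 5, 6                  # examples, inputs, hiddens, outputs, steps
ϕ = tanh
W_xh_f, W_hh_f, b_h_f = randn(d, h), randn(h, h), randn(1, h)   # forward-direction parameters
W_xh_b, W_hh_b, b_h_b = randn(d, h), randn(h, h), randn(1, h)   # backward-direction parameters
W_hq, b_q = randn(2h, q), randn(1, q)                           # output-layer parameters
X = [randn(n, d) for _ in 1:T]                                  # each X_t is n × d

H_f = [zeros(n, h) for _ in 1:T+1]   # H_f[t+1] is the forward state at step t; H_f[1] is the initial state
H_b = [zeros(n, h) for _ in 1:T+1]   # H_b[t] is the backward state at step t; H_b[T+1] is the initial state
for t in 1:T                         # left-to-right recursion
    H_f[t+1] = ϕ.(X[t] * W_xh_f .+ H_f[t] * W_hh_f .+ b_h_f)
end
for t in T:-1:1                      # right-to-left recursion
    H_b[t] = ϕ.(X[t] * W_xh_b .+ H_b[t+1] * W_hh_b .+ b_h_b)
end
# Concatenate the two states and apply the output layer; each O_t is n × q.
O = [hcat(H_f[t+1], H_b[t]) * W_hq .+ b_q for t in 1:T]
size(O[1])   # (n, q)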
using Pkg; Pkg.activate("../../d2lai")
using d2lai
using Flux
using Downloads
using StatsBase
using Plots
using CUDA, cuDNN
import d2lai: RNNScratch

Implementation from Scratch
If we want to implement a bidirectional RNN from scratch, we can include two unidirectional RNNScratch instances with separate learnable parameters.
struct BiRNNScratch{N, A} <: AbstractModel
    net::N    # named tuple holding the forward and backward RNNs
    args::A   # stored hyperparameters
end
Flux.@layer BiRNNScratch trainable=(net,)
function BiRNNScratch(num_inputs::Int, num_hiddens::Int; sigma = 0.01)
    frnn = RNNScratch(num_inputs, num_hiddens; sigma)   # left-to-right RNN
    brnn = RNNScratch(num_inputs, num_hiddens; sigma)   # right-to-left RNN
    # The effective hidden size doubles because the two outputs are concatenated.
    BiRNNScratch((; frnn, brnn), (; num_hiddens = num_hiddens * 2, num_inputs, sigma))
end

The states of the forward and backward RNNs are updated separately, while the outputs of the two RNNs are concatenated.
function (m::BiRNNScratch)(x, state = nothing)
    f_state, b_state = isnothing(state) ? (nothing, nothing) : state
    # Forward RNN reads the sequence left to right.
    out_f, f_state = m.net.frnn(x, f_state)
    # Backward RNN reads the time-reversed sequence (time along dims = 2).
    out_b, b_state = m.net.brnn(reverse(x, dims = 2), b_state)
    # Restore the original time order and concatenate along the feature dimension.
    out = cat(out_f, reverse(out_b, dims = 2), dims = 1)
    return out, (f_state, b_state)
end

Concise Implementation
Using the high-level APIs, we can implement bidirectional RNNs more concisely. Here we take a GRU model as an example.
model = GRU(num_inputs => num_hiddens, bidirectional = true)
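As a sanity check, we can push a random batch through the from-scratch model defined earlier and confirm that the feature dimension of the output doubles. This is only a sketch: the toy sizes are made up, and the input layout (num_inputs, num_steps, batch_size), with time along the second dimension, is an assumption inferred from the reverse(x, dims = 2) call in the forward pass.
# Hedged shape check for BiRNNScratch (toy sizes; input layout is assumed).
toy_inputs, toy_hiddens, toy_steps, toy_batch = 8, 16, 5, 4
birnn = BiRNNScratch(toy_inputs, toy_hiddens)
x = randn(Float32, toy_inputs, toy_steps, toy_batch)   # assumed (features, time, batch)
out, state = birnn(x)
size(out)   # expected (2 * toy_hiddens, toy_steps, toy_batch) under the assumed layout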