
Bidirectional Recurrent Neural Networks

So far, our working example of a sequence learning task has been language modeling, where we aim to predict the next token given all previous tokens in a sequence. In this scenario, we wish only to condition upon the leftward context, and thus the unidirectional chaining of a standard RNN seems appropriate. However, there are many other sequence learning contexts where it is perfectly fine to condition the prediction at every time step on both the leftward and the rightward context. Consider, for example, part-of-speech detection. Why shouldn't we take the context in both directions into account when assessing the part of speech associated with a given word?

Another common task, often useful as a pretraining exercise prior to fine-tuning a model on an actual task of interest, is to mask out random tokens in a text document and then to train a sequence model to predict the values of the missing tokens. Note that depending on what comes after the blank, the likely value of the missing token changes dramatically:

  • I am ___.

  • I am ___ hungry.

  • I am ___ hungry, and I can eat half a pig.

In the first sentence, "happy" seems to be a likely candidate. The words "not" and "very" seem plausible in the second sentence, but "not" seems incompatible with the third.

Fortunately, a simple technique transforms any unidirectional RNN into a bidirectional RNN [169]. We simply implement two unidirectional RNN layers chained together in opposite directions and acting on the same input (see the figure below). For the first RNN layer, the first input is $\mathbf{x}_1$ and the last input is $\mathbf{x}_T$, but for the second RNN layer, the first input is $\mathbf{x}_T$ and the last input is $\mathbf{x}_1$. To produce the output of this bidirectional RNN layer, we simply concatenate together the corresponding outputs of the two underlying unidirectional RNN layers.

Architecture of a bidirectional RNN.

Formally, for any time step $t$, we consider a minibatch input $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ (number of examples $= n$; number of inputs in each example $= d$) and let the hidden layer activation function be $\phi$. In the bidirectional architecture, the forward and backward hidden states for this time step are $\overrightarrow{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$ and $\overleftarrow{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$, respectively, where $h$ is the number of hidden units. The forward and backward hidden state updates are as follows:

$$\overrightarrow{\mathbf{H}}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh}^{(f)} + \overrightarrow{\mathbf{H}}_{t-1} \mathbf{W}_{hh}^{(f)} + \mathbf{b}_h^{(f)}),$$
$$\overleftarrow{\mathbf{H}}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh}^{(b)} + \overleftarrow{\mathbf{H}}_{t+1} \mathbf{W}_{hh}^{(b)} + \mathbf{b}_h^{(b)}),$$

where the weights $\mathbf{W}_{xh}^{(f)} \in \mathbb{R}^{d \times h}$, $\mathbf{W}_{hh}^{(f)} \in \mathbb{R}^{h \times h}$, $\mathbf{W}_{xh}^{(b)} \in \mathbb{R}^{d \times h}$, and $\mathbf{W}_{hh}^{(b)} \in \mathbb{R}^{h \times h}$, and the biases $\mathbf{b}_h^{(f)} \in \mathbb{R}^{1 \times h}$ and $\mathbf{b}_h^{(b)} \in \mathbb{R}^{1 \times h}$ are all the model parameters.

Next, we concatenate the forward and backward hidden states $\overrightarrow{\mathbf{H}}_t$ and $\overleftarrow{\mathbf{H}}_t$ to obtain the hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times 2h}$ for feeding into the output layer. In deep bidirectional RNNs with multiple hidden layers, such information is passed on as input to the next bidirectional layer. Last, the output layer computes the output $\mathbf{O}_t \in \mathbb{R}^{n \times q}$ (number of outputs $= q$):

$$\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{hq} + \mathbf{b}_q.$$

Here, the weight matrix $\mathbf{W}_{hq} \in \mathbb{R}^{2h \times q}$ and the bias $\mathbf{b}_q \in \mathbb{R}^{1 \times q}$ are the model parameters of the output layer. While, technically, the two directions can have different numbers of hidden units, this design choice is seldom made in practice. The short sketch below checks these shapes numerically, and we then demonstrate a simple implementation.
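As a quick, hedged illustration, the following sketch evaluates the two recurrences and the output layer for a single time step in plain Julia, using made-up sizes; the variable names are illustrative and not part of the d2lai code.

julia
# Illustrative shape check with arbitrary sizes (not part of d2lai).
n, d, h, q = 4, 3, 5, 2            # examples, inputs, hidden units, outputs
ϕ = tanh                           # hidden-layer activation

Xt = randn(n, d)                   # minibatch input at time step t
Hf_prev = zeros(n, h)              # forward state from time step t-1
Hb_next = zeros(n, h)              # backward state from time step t+1

Wxh_f, Whh_f, bh_f = randn(d, h), randn(h, h), randn(1, h)
Wxh_b, Whh_b, bh_b = randn(d, h), randn(h, h), randn(1, h)

Hf = ϕ.(Xt * Wxh_f .+ Hf_prev * Whh_f .+ bh_f)   # forward state,  n × h
Hb = ϕ.(Xt * Wxh_b .+ Hb_next * Whh_b .+ bh_b)   # backward state, n × h

Ht = hcat(Hf, Hb)                                # concatenation,  n × 2h
Whq, bq = randn(2h, q), randn(1, q)
Ot = Ht * Whq .+ bq                              # output,         n × q
size(Ot) == (n, q)                               # true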

julia
using Pkg; Pkg.activate("../../d2lai")
using d2lai
using Flux 
using Downloads
using StatsBase
using Plots
using CUDA, cuDNN
import d2lai: RNNScratch
  Activating project at `/workspace/d2l-julia/d2lai`

Implementation from Scratch

If we want to implement a bidirectional RNN from scratch, we can include two unidirectional RNNScratch instances with separate learnable parameters.

julia
# Container for a forward and a backward RNN plus the stored hyperparameters.
struct BiRNNScratch{N, A} <: AbstractModel
    net::N     # named tuple holding the forward and backward RNNs
    args::A    # hyperparameters (num_hiddens, num_inputs, sigma)
end

Flux.@layer BiRNNScratch trainable=(net,)

function BiRNNScratch(num_inputs::Int, num_hiddens::Int; sigma = 0.01)
    frnn = RNNScratch(num_inputs, num_hiddens; sigma)   # forward-direction RNN
    brnn = RNNScratch(num_inputs, num_hiddens; sigma)   # backward-direction RNN
    # The concatenated output has twice as many hidden units as each direction.
    BiRNNScratch((; frnn, brnn), (; num_hiddens = num_hiddens * 2, num_inputs, sigma))
end

The states of the forward and backward RNNs are updated separately. The backward RNN consumes the input sequence in reverse time order, and its outputs are reversed back into the original order before being concatenated with the forward outputs along the feature dimension.

julia
function (m::BiRNNScratch)(x, state = nothing)
    f_state, b_state = isnothing(state) ? (nothing, nothing) : state
    # Forward RNN reads the sequence in its original time order.
    out_f, f_state = m.net.frnn(x, f_state)
    # Backward RNN reads the sequence reversed along the time dimension (dims = 2).
    out_b, b_state = m.net.brnn(reverse(x, dims = 2), b_state)
    # Restore time order for the backward outputs, then concatenate along the
    # feature dimension (dims = 1).
    out = cat(out_f, reverse(out_b, dims = 2), dims = 1)
    return out, (f_state, b_state)
end
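As a quick sanity check, we could run a random minibatch through the scratch model. This is only a sketch: it assumes the (num_inputs, num_steps, batch_size) layout implied by `reverse(x, dims = 2)` above, and the sizes are arbitrary; the key point is that the output carries twice as many features as each underlying RNN.

julia
# Sketch of a forward pass; the sizes and the (inputs, steps, batch) layout
# are assumptions for illustration.
num_inputs, num_hiddens = 8, 16
birnn = BiRNNScratch(num_inputs, num_hiddens)
x = randn(Float32, num_inputs, 10, 4)
out, state = birnn(x)
size(out, 1)      # expected: 2 * num_hiddens = 32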

Concise Implementation

Using the high-level APIs, we can implement bidirectional RNNs more concisely. Here we take a GRU model as an example.

julia
model = GRU(num_inputs => num_hiddens, bidirectional = true)
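As with the scratch implementation, the outputs of the two directions are concatenated, so the representation at each time step has 2 * num_hiddens features; any layer consuming the model's output needs to account for this doubled width.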