The Bahdanau Attention Mechanism
When we encountered machine translation in :numref:sec_seq2seq, we designed an encoder–decoder architecture for sequence-to-sequence learning based on two RNNs [153]. Specifically, the RNN encoder transforms a variable-length sequence into a fixed-shape context variable. Then, the RNN decoder generates the output (target) sequence token by token based on the generated tokens and the context variable.
Recall the sequence-to-sequence model from that section, which we repeat below with some additional detail. Conventionally, in an RNN all relevant information about a source sequence is translated into some internal fixed-dimensional state representation by the encoder. It is this very state that is used by the decoder as the complete and exclusive source of information for generating the translated sequence. In other words, the sequence-to-sequence mechanism treats the intermediate state as a sufficient statistic of whatever string might have served as input.
Sequence-to-sequence model. The state, as generated by the encoder, is the only piece of information shared between the encoder and the decoder.
While this is quite reasonable for short sequences, it is clearly infeasible for long ones, such as a book chapter or even just a very long sentence. After all, before too long there will simply not be enough "space" in the intermediate representation to store all that is important in the source sequence. Consequently the decoder will fail to translate long and complex sentences. One of the first to encounter this was Graves [193], who tried to design an RNN to generate handwritten text. Since the source text has arbitrary length, they designed a differentiable attention model to align text characters with the much longer pen trace, where the alignment moves only in one direction. This, in turn, draws on decoding algorithms in speech recognition, e.g., hidden Markov models [194].
Inspired by the idea of learning to align, Bahdanau et al. [183] proposed a differentiable attention model without the unidirectional alignment limitation. When predicting a token, if not all the input tokens are relevant, the model aligns (or attends) only to parts of the input sequence that are deemed relevant to the current prediction. This is then used to update the current state before generating the next token. While quite innocuous in its description, this Bahdanau attention mechanism has arguably turned into one of the most influential ideas of the past decade in deep learning, giving rise to Transformers [142] and many related new architectures.
using Pkg; Pkg.activate("../../d2lai")
using d2lai
using Flux
using Plots
using CUDA, cuDNN
import d2lai: Seq2Seq, AbstractEncoderDecoder, Seq2SeqEncoder, AdditiveAttention, StackedRNN

Model
We follow the notation introduced by the sequence-to-sequence architecture of :numref:sec_seq2seq, in particular :eqref:eq_seq2seq_s_t. The key idea is that instead of keeping the state, i.e., the context variable $\mathbf{c}$ summarizing the source sentence, fixed, we dynamically update it as a function of both the original text (encoder hidden states $\mathbf{h}_t$) and the text that was already generated (decoder hidden states $\mathbf{s}_{t'-1}$). This yields $\mathbf{c}_{t'}$, which is updated after any decoding time step $t'$. Suppose that the input sequence is of length $T$. In this case the context variable is the output of attention pooling:

$$\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t.$$

We used $\mathbf{s}_{t'-1}$ as the query, and $\mathbf{h}_t$ as both the key and the value. Note that $\mathbf{c}_{t'}$ is then used to generate the state $\mathbf{s}_{t'}$ and a new token: see :eqref:eq_seq2seq_s_t. In particular, the attention weight $\alpha$ is computed as in :eqref:eq_attn-scoring-alpha using the additive attention scoring function defined by :eqref:eq_additive-attn. This RNN encoder–decoder architecture using attention is depicted in the figure below. Note that later this model was modified so as to include the already generated tokens in the decoder as further context (i.e., the attention sum does not stop at $T$ but rather proceeds up to $t'-1$).
Layers in an RNN encoder–decoder model with the Bahdanau attention mechanism.
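To make the pooling step concrete before we implement it, here is a small toy sketch (not part of d2lai; the sizes and random inputs are made up for illustration). It computes a single context vector as a weighted sum of encoder hidden states; for brevity the weights come from a plain dot-product score passed through a softmax, whereas the actual model below uses the additive attention scoring function.

using Flux  # softmax is re-exported from NNlib

num_hiddens, num_steps = 4, 7                  # toy sizes, for illustration only
h = randn(Float32, num_hiddens, num_steps)     # encoder hidden states h_1, ..., h_T (one per column)
s_prev = randn(Float32, num_hiddens)           # previous decoder state s_{t'-1}, used as the query

scores = vec(s_prev' * h)                      # one unnormalized score per source time step
α = Flux.softmax(scores)                       # attention weights α_t, summing to one
c = h * α                                      # context c_{t'} = Σ_t α_t h_t (a vector of length num_hiddens)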
Defining the Decoder with Attention
To implement the RNN encoder–decoder with attention, we only need to redefine the decoder (omitting the generated symbols from the attention function simplifies the design). Let's begin with the base interface for decoders with attention by defining the quite unsurprisingly named AttentionDecoder abstract type.
abstract type AttentionDecoder <: AbstractModel end

We need to implement the RNN decoder in the Seq2SeqAttentionDecoder struct. The state of the decoder is initialized with (i) the hidden states of the last layer of the encoder at all time steps, used as keys and values for attention; (ii) the hidden state of the encoder at all layers at the final time step, which serves to initialize the hidden state of the decoder; and (iii) the valid length of the encoder, to exclude the padding tokens in attention pooling. At each decoding time step, the hidden state of the final layer of the decoder, obtained at the previous time step, is used as the query of the attention mechanism. Both the output of the attention mechanism and the input embedding are concatenated to serve as the input of the RNN decoder.
struct Seq2SeqAttentionDecoder{AT, E, R, D, A} <: AbstractModel
attention::AT
embedding::E
rnn::R
dense::D
args::A
end
Flux.@layer Seq2SeqAttentionDecoder trainable = (attention, embedding, rnn, dense)
function d2lai.init_state(::Seq2SeqAttentionDecoder, enc_all_out, enc_valid_lens)
outputs, hidden_state = enc_all_out
return outputs, hidden_state, enc_valid_lens
end
function Seq2SeqAttentionDecoder(vocab_size::Int, embed_size::Int, num_hiddens, num_layers, dropout=0.)
embedding = Embedding(vocab_size => embed_size)
rnn = StackedRNN(embed_size + num_hiddens, num_hiddens, num_layers; rnn = Flux.LSTM)
attention = AdditiveAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, dropout)
dense = Dense(num_hiddens, vocab_size)
args = (; vocab_size, embed_size, num_hiddens, num_layers)
Seq2SeqAttentionDecoder(attention, embedding, rnn, dense, args)
end

function (m::Seq2SeqAttentionDecoder)(X, state; return_attention_wts = false)
enc_outputs, hidden_state, enc_valid_lens = state
embeds = m.embedding(X) # num_embeds x num_steps x batch_size
outs = map(eachslice(embeds; dims = 2)) do x
query = hidden_state[end] # num_hiddens x batch_size
query = Flux.unsqueeze(query, dims = 2) # num_hiddens x 1 x batch_size (num_queries = 1)
# enc_outputs -> num_hiddens x num_steps x batch_size
context, attention_wts = m.attention(query, enc_outputs, enc_outputs, enc_valid_lens) # num_hiddens x 1 x batch_size
embs_and_context = vcat(Flux.unsqueeze(x, dims = 2), context) # (num_embeds + num_hiddens) x 1 x batch_size
out, hidden_state = m.rnn(embs_and_context, hidden_state)
out, attention_wts
end
outputs = first.(outs)
attention_wts = last.(outs)
outputs_cat = cat(outputs..., dims = 2)
out_dense = m.dense(outputs_cat)
if !return_attention_wts
return out_dense, (enc_outputs, hidden_state, enc_valid_lens)
else
return out_dense, (enc_outputs, hidden_state, enc_valid_lens), attention_wts
end
end

In the following, we test the implemented decoder with attention using a minibatch of four sequences, each of which is seven time steps long.
vocab_size, embed_size, num_hiddens, num_layers = 10, 8, 16, 2
batch_size, num_steps = 4, 7
encoder = Seq2SeqEncoder(vocab_size, embed_size, num_hiddens, 2)
decoder = Seq2SeqAttentionDecoder(vocab_size, embed_size, num_hiddens, 2)
X = ones(Int64, num_steps, batch_size)
state = d2lai.init_state(decoder, encoder(X, nothing), nothing)
output, state = decoder(X, state);
@assert size(output) == (vocab_size, num_steps, batch_size)
@assert size(state[1]) == (num_hiddens, num_steps, batch_size)

Training
Now that we specified the new decoder we can proceed analogously to :numref:sec_seq2seq_training: specify the hyperparameters, instantiate a regular encoder and a decoder with attention, and train this model for machine translation.
data = d2lai.MTFraEng(128)
embed_size, num_hiddens, num_layers, dropout = 256, 256, 2, 0.2
encoder = Seq2SeqEncoder(length(data.src_vocab), embed_size, num_hiddens, num_layers)
decoder = Seq2SeqAttentionDecoder(length(data.tgt_vocab), embed_size, num_hiddens, 1)
model = Seq2Seq(encoder, decoder, data.tgt_vocab["<pad>"]) |> Flux.f64
opt = Flux.Adam(0.01)
trainer = Trainer(model, data, opt; max_epochs = 30, gpu = true, gradient_clip_val = 1.)
m, _ = d2lai.fit(trainer);

┌ Warning: Layer with Float32 parameters got Float64 input.
│ The input will be converted, but any earlier layers may be very slow.
│ layer = Dense(256 => 256; bias=false) # 65_536 parameters
│ summary(x) = "256×128 CuArray{Float64, 2, CUDA.DeviceMemory}"
└ @ Flux ~/.julia/packages/Flux/3711C/src/layers/stateless.jl:60
[ Info: Train Loss: 4.0412164, Val Loss: 4.9143953
[ Info: Train Loss: 2.9848623, Val Loss: 4.5563307
[ Info: Train Loss: 2.596702, Val Loss: 4.536007
[ Info: Train Loss: 2.3639255, Val Loss: 4.4321485
[ Info: Train Loss: 1.9325426, Val Loss: 4.532488
[ Info: Train Loss: 1.6510118, Val Loss: 4.667524
[ Info: Train Loss: 1.4277399, Val Loss: 4.7502885
[ Info: Train Loss: 1.2472821, Val Loss: 4.89551
[ Info: Train Loss: 1.1188064, Val Loss: 5.0064697
[ Info: Train Loss: 0.9380719, Val Loss: 4.6692905
[ Info: Train Loss: 0.86440516, Val Loss: 4.6082397
[ Info: Train Loss: 0.76707596, Val Loss: 4.9298
[ Info: Train Loss: 0.7151689, Val Loss: 4.6978455
[ Info: Train Loss: 0.6187289, Val Loss: 4.97536
[ Info: Train Loss: 0.61237884, Val Loss: 4.873481
[ Info: Train Loss: 0.55913705, Val Loss: 4.797181
[ Info: Train Loss: 0.5458334, Val Loss: 4.754832
[ Info: Train Loss: 0.5831154, Val Loss: 4.7641854
[ Info: Train Loss: 0.45402354, Val Loss: 4.937228
[ Info: Train Loss: 0.54235154, Val Loss: 5.2440014
[ Info: Train Loss: 0.47185406, Val Loss: 4.8063936
[ Info: Train Loss: 0.45345232, Val Loss: 4.969394
[ Info: Train Loss: 0.3592736, Val Loss: 5.0772567
[ Info: Train Loss: 0.42924955, Val Loss: 5.0097284
[ Info: Train Loss: 0.40380043, Val Loss: 5.061416
[ Info: Train Loss: 0.45879066, Val Loss: 4.876223
[ Info: Train Loss: 0.4700213, Val Loss: 5.086087
[ Info: Train Loss: 0.5087924, Val Loss: 5.028875
[ Info: Train Loss: 0.4338621, Val Loss: 5.279842
[ Info: Train Loss: 0.47916874, Val Loss: 5.3080935

After the model is trained, we use it to translate a few English sentences into French and compute their BLEU scores.
engs = ["go .", "i lost .", "he's calm .", "i'm home ."]
fras = ["va !", "j'ai perdu .", "il est calme .", "je suis chez moi ."]
batch = d2lai.build(data, engs, fras)
preds, _ = d2lai.predict_step(m, batch, cpu, data.args.num_steps; save_attention_wts = true)
for (en, fr, p) in zip(engs, fras, eachcol(preds))
translation = []
for token in d2lai.to_tokens(data.tgt_vocab, p)
if token == "<eos>"
break
end
push!(translation, token)
end
bleu_score = d2lai.bleu(join(translation, " "), fr, 2)
println("$en => $translation", "bleu: $bleu_score")
end

go . => Any["va", "!"]bleu: 1.0
i lost . => Any["j'ai", "perdu", "."]bleu: 1.0
he's calm . => Any["il", "est", "rouge", "."]bleu: 0.6580370064762462
i'm home . => Any["je", "suis", "chez", "moi", "."]bleu: 1.0

Finally, let's visualize the attention weights when translating the last English sentence.

btch = d2lai.build(data, [engs[end]], [fras[end]])
_, dec_attention_weights = d2lai.predict_step(m, btch, cpu, data.args.num_steps; save_attention_wts = true);
attention_weights = cat([step[1] for step in dec_attention_weights]..., dims = 3)
attention_weights = reshape(attention_weights, :, data.args.num_steps, 1, 1)

9×9×1×1 Array{Float64, 4}:
[:, :, 1, 1] =
0.00361594 0.000894255 0.0163937 … 0.00230653 0.00123777 0.00303721
0.736385 0.747663 0.737758 0.75808 0.745777 0.712745
0.161395 0.14444 0.15841 0.156805 0.16003 0.183647
0.0986039 0.107003 0.0874382 0.0828081 0.0929559 0.100571
0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 … 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0 0.0

d2lai.show_heatmaps(attention_weights[1:length(split(engs[end])), :, :, :], "Queries", "Keys")

Summary
When predicting a token, if not all the input tokens are relevant, the RNN encoder–decoder with the Bahdanau attention mechanism selectively aggregates different parts of the input sequence. This is achieved by treating the state (context variable) as an output of additive attention pooling. In the RNN encoder–decoder, the Bahdanau attention mechanism treats the decoder hidden state at the previous time step as the query, and the encoder hidden states at all the time steps as both the keys and values.
Exercises
Replace GRU with LSTM in the experiment.
Modify the experiment to replace the additive attention scoring function with the scaled dot-product. How does it influence the training efficiency?
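For the second exercise, the following is a minimal sketch (our own, not provided by d2lai) of a scaled dot-product scoring layer that follows the same calling convention as the attention module used above: a query of shape num_hiddens × 1 × batch_size, keys and values of shape num_hiddens × num_steps × batch_size, and a return value of (context, attention weights). For simplicity it ignores valid_lens (no masking of padding tokens) and has no dropout; a full replacement would need both.

using Flux
using NNlib: batched_mul, batched_transpose

struct DotProductAttentionSketch end  # hypothetical layer name, not part of d2lai

function (::DotProductAttentionSketch)(queries, keys, values, valid_lens = nothing)
    d = size(queries, 1)  # feature dimension shared by queries and keys
    # scores[q, t, b] = <queries[:, q, b], keys[:, t, b]> / sqrt(d)
    scores = batched_mul(batched_transpose(queries), keys) ./ sqrt(eltype(queries)(d))
    # NOTE: valid_lens is ignored in this sketch; padding positions are not masked out.
    weights = Flux.softmax(scores; dims = 2)                      # normalize over key positions
    context = batched_mul(values, batched_transpose(weights))     # num_hiddens × num_queries × batch_size
    return context, weights
end

Swapping such a layer in for AdditiveAttention in the Seq2SeqAttentionDecoder constructor (after restoring masking and dropout) lets you compare training speed and accuracy against the additive variant.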