Multi-Head Attention
In practice, given the same set of queries, keys, and values we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence. Thus, it may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.
To this end, instead of performing a single attention pooling, queries, keys, and values can be transformed with $h$ independently learned linear projections. Then these $h$ projected queries, keys, and values are fed into attention pooling in parallel. Finally, the $h$ attention-pooling outputs are concatenated and transformed with another learned linear projection to produce the final output. This design is called multi-head attention, where each of the $h$ attention-pooling outputs is a head.
Figure: Multi-head attention, where multiple heads are concatenated and then linearly transformed.
using Pkg; Pkg.activate("../../d2lai")
using LinearAlgebra
using d2lai
using Flux
using Downloads
using StatsBase
using Plots
using d2lai: DotProductAttention, Seq2Seq, AbstractEncoderDecoder, Seq2SeqEncoder, AdditiveAttention, StackedRNN, Seq2SeqAttentionDecoder
Model
Before providing the implementation of multi-head attention, let's formalize this model mathematically. Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$, and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as

$$\mathbf{h}_i = f\left(\mathbf{W}_i^{(q)}\mathbf{q}, \mathbf{W}_i^{(k)}\mathbf{k}, \mathbf{W}_i^{(v)}\mathbf{v}\right) \in \mathbb{R}^{p_v},$$

where $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$ are learnable parameters and $f$ is attention pooling, such as the additive attention and scaled dot product attention from the section on attention scoring functions. The multi-head attention output is another linear transformation, via learnable parameters $\mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}$, of the concatenation of the $h$ heads:

$$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h \end{bmatrix} \in \mathbb{R}^{p_o}.$$
Based on this design, each head may attend to different parts of the input. More sophisticated functions than the simple weighted average can be expressed.
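These definitions translate directly into a loop over heads. Below is a minimal sketch of that computation, assuming hypothetical per-head weight collections Wqs, Wks, Wvs, an output projection Wo, and a generic attention pooling function (none of these names are part of d2lai; the batched implementation further down avoids the explicit loop):

# Loop-over-heads sketch of the equations above (illustration only).
# `attention` stands in for the pooling f, e.g. scaled dot product attention.
function naive_multihead(q, k, v, Wqs, Wks, Wvs, Wo, attention)
    heads = [attention(Wqs[i] * q, Wks[i] * k, Wvs[i] * v) for i in eachindex(Wqs)]
    return Wo * vcat(heads...)   # concatenate the h heads, then project to p_o
end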
Implementation
In our implementation, we choose the scaled dot product attention for each head of the multi-head attention. To avoid significant growth of computational cost and parametrization cost, we set $p_q = p_k = p_v = p_o / h$. Note that all $h$ heads can be computed in parallel if we set the number of outputs of the linear transformations for the query, key, and value to $p_q h = p_k h = p_v h = p_o$. In the implementation below, $p_o$ is specified via the argument num_hiddens.
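As a quick check of this bookkeeping (a standalone sketch, not part of the model): with num_hiddens = 100 and num_heads = 5, the values used in the sanity check below, each head operates in a 20-dimensional subspace.

# Per-head dimension when all projection outputs share num_hiddens.
num_hiddens, num_heads = 100, 5
head_dim = num_hiddens ÷ num_heads            # 20: each head's query/key/value size
@assert head_dim * num_heads == num_hiddens   # num_hiddens must divide evenly by num_heads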
# weights: the four projections (W_q, W_k, W_v, W_o); attention: the pooling
# applied per head; args: hyperparameters (num_hiddens, num_heads, dropout).
struct MultiHeadedAttention{W, AT, A} <: AbstractModel
    weights::W
    attention::AT
    args::A
end
# Only the projection weights are trainable; the attention module holds just dropout.
Flux.@layer MultiHeadedAttention trainable = (weights, )
function MultiHeadedAttention(num_hiddens::Int64, num_heads::Int64, dropout::AbstractFloat; bias=false)
    # One fused projection per role: all heads are produced by a single matmul.
    W_q = Dense(num_hiddens, num_hiddens, bias = bias)
    W_k = Dense(num_hiddens, num_hiddens, bias = bias)
    W_v = Dense(num_hiddens, num_hiddens, bias = bias)
    W_o = Dense(num_hiddens, num_hiddens, bias = bias)
    attn = DotProductAttention(Dropout(dropout), (;))
    MultiHeadedAttention((; W_q, W_k, W_v, W_o), attn, (; num_hiddens, num_heads, dropout))
end
function (m::MultiHeadedAttention)(queries, keys, values, valid_lens)
    # queries -> q_d x num_queries x batch_size
    # keys    -> k_d x num_key_val x batch_size
    # values  -> v_d x num_key_val x batch_size
    queries = m.weights.W_q(queries)    # num_hiddens x num_queries x batch_size
    queries = transpose_qkv(m, queries) # (num_hiddens / num_heads) x num_queries x (num_heads * batch_size)
    keys = transpose_qkv(m, m.weights.W_k(keys))     # (num_hiddens / num_heads) x num_key_val x (num_heads * batch_size)
    values = transpose_qkv(m, m.weights.W_v(values)) # (num_hiddens / num_heads) x num_key_val x (num_heads * batch_size)
    # Expand valid_lens so that every head sees its example's valid length.
    valid_lens = if !isnothing(valid_lens)
        isa(valid_lens, AbstractVector) ? repeat(valid_lens, inner = m.args.num_heads) : repeat(valid_lens, inner = (m.args.num_heads, 1))
    end
    scores, attn_wts = m.attention(queries, keys, values, valid_lens) # (num_hiddens / num_heads) x num_queries x (num_heads * batch_size)
    # attn_wts -> num_key_val x num_queries x (num_heads * batch_size)
    output_concat = transpose_output(m, scores) # num_hiddens x num_queries x batch_size
    return m.weights.W_o(output_concat), attn_wts # num_hiddens x num_queries x batch_size
end
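Note how valid_lens is expanded in the forward pass above: transpose_qkv folds the heads into the batch axis with the head index varying fastest, so each example's valid length must be repeated num_heads times contiguously. For example, with the valid lengths [3, 2] and 5 heads:

# inner= repeats each entry consecutively, matching the head-fastest layout
# of the folded (num_heads * batch_size) axis.
repeat([3, 2], inner = 5)   # -> [3, 3, 3, 3, 3, 2, 2, 2, 2, 2]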
function transpose_qkv(m::MultiHeadedAttention, x)
    # x -> num_hiddens x (num_queries or num_key_val) x batch_size
    num_q_or_key_val = size(x, 2)
    batch_size = size(x, 3)
    x_ = reshape(x, :, m.args.num_heads, num_q_or_key_val, batch_size)
    x_permuted = permutedims(x_, [1, 3, 2, 4]) # (num_hiddens / num_heads) x num_queries x num_heads x batch_size
    return reshape(x_permuted, size(x_permuted)[1], size(x_permuted)[2], :) # (num_hiddens / num_heads) x num_queries x (num_heads * batch_size)
end
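To see why the reshape-permute-reshape in transpose_qkv splits heads correctly, it helps to trace a tiny standalone case (independent of the model): with 4 hidden units and 2 heads, the first two units go to head 1 and the last two to head 2, and permutedims then moves the head axis next to the batch axis so that the two can be folded together.

# Tiny trace of the head-splitting logic (head_dim = 2, 2 heads, 1 query, 1 batch).
x = reshape(1.0:4.0, 4, 1, 1)       # one "column" of 4 hidden units
x_ = reshape(x, 2, 2, 1, 1)         # split into (head_dim, num_heads, q, batch)
xp = permutedims(x_, (1, 3, 2, 4))  # -> (head_dim, q, num_heads, batch)
reshape(xp, 2, 1, :)                # 2x1x2: slices [1, 2] (head 1) and [3, 4] (head 2)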
function transpose_output(m::MultiHeadedAttention, x)
    # Reverse transpose_qkv: unfold the head axis and concatenate head outputs.
    x_ = reshape(x, size(x)[1], size(x)[2], m.args.num_heads, :)
    x_permuted = permutedims(x_, [1, 3, 2, 4]) # (num_hiddens / num_heads) x num_heads x num_queries x batch_size
    return reshape(x_permuted, :, size(x_permuted)[3], size(x_permuted)[4]) # num_hiddens x num_queries x batch_size
end
num_hiddens, num_heads = 100, 5
attention = MultiHeadedAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = [3, 2]
X = ones(Float32, (num_hiddens, num_queries, batch_size))
Y = ones(Float32, (num_hiddens, num_kvpairs, batch_size))
context, attn_wt = attention(X, Y, Y, valid_lens)
@assert size(context) == (num_hiddens, num_queries, batch_size)
To use multi-head attention in sequence-to-sequence learning, we reuse the Seq2SeqAttentionDecoder container and simply swap in MultiHeadedAttention for the additive attention used previously.
function Seq2SeqMultiAttentionDecoder(vocab_size::Int, embed_size::Int, num_hiddens, num_layers, num_heads, dropout=0.)
    embedding = Embedding(vocab_size => embed_size)
    rnn = StackedRNN(embed_size + num_hiddens, num_hiddens, num_layers; rnn = Flux.LSTM)
    attention = MultiHeadedAttention(num_hiddens, num_heads, dropout)
    dense = Dense(num_hiddens, vocab_size)
    args = (; vocab_size, embed_size, num_hiddens, num_layers)
    Seq2SeqAttentionDecoder(attention, embedding, rnn, dense, args)
end
data = d2lai.MTFraEng(128)
embed_size, num_hiddens, num_layers, dropout = 256, 256, 2, 0.2
num_heads = 8
encoder = Seq2SeqEncoder(length(data.src_vocab), embed_size, num_hiddens, num_layers)
decoder = Seq2SeqMultiAttentionDecoder(length(data.tgt_vocab), embed_size, num_hiddens, 1, num_heads)
model = Seq2Seq(encoder, decoder, data.tgt_vocab["<pad>"])
opt = Flux.Adam(0.01)
trainer = Trainer(model, data, opt; max_epochs = 30, gpu = true, gradient_clip_val = 1.)
m, _ = d2lai.fit(trainer);
During the first epoch, Zygote prints a few non-fatal warnings of the form "Internal error: during type inference of _pullback_generator(...) ... Encountered stack overflow. This might be caused by recursion over very long tuples or argument lists.", triggered by differentiating a cat call splatted over hundreds of small Bool matrices (the full dumps are omitted here). Training proceeds regardless:
[ Info: Train Loss: 4.8276343, Val Loss: 5.4717245
[ Info: Train Loss: 4.3027673, Val Loss: 5.6018443
[ Info: Train Loss: 3.273195, Val Loss: 4.5906806
[ Info: Train Loss: 2.963611, Val Loss: 4.9624586
[ Info: Train Loss: 2.6126888, Val Loss: 4.4686055
[ Info: Train Loss: 2.471468, Val Loss: 5.0693927
[ Info: Train Loss: 2.0902474, Val Loss: 4.507521
[ Info: Train Loss: 1.9265594, Val Loss: 4.5847497
[ Info: Train Loss: 1.6047107, Val Loss: 5.0555325
[ Info: Train Loss: 1.4116989, Val Loss: 4.5631604
[ Info: Train Loss: 1.2966719, Val Loss: 4.7002335
[ Info: Train Loss: 1.1375505, Val Loss: 4.794996
[ Info: Train Loss: 0.92881125, Val Loss: 4.6975164
[ Info: Train Loss: 0.91435474, Val Loss: 4.6757035
[ Info: Train Loss: 0.7080981, Val Loss: 4.946134
[ Info: Train Loss: 0.7106923, Val Loss: 4.8681574
[ Info: Train Loss: 0.60006666, Val Loss: 4.701098
[ Info: Train Loss: 0.6312529, Val Loss: 5.1001134
[ Info: Train Loss: 0.55986136, Val Loss: 4.9545436
[ Info: Train Loss: 0.5404057, Val Loss: 4.772108
[ Info: Train Loss: 0.5702013, Val Loss: 4.9377155
[ Info: Train Loss: 0.5331384, Val Loss: 5.2816014
[ Info: Train Loss: 0.48318955, Val Loss: 4.9302464
[ Info: Train Loss: 0.49934065, Val Loss: 5.0310044
[ Info: Train Loss: 0.49979848, Val Loss: 5.26014
[ Info: Train Loss: 0.5294968, Val Loss: 5.193127
[ Info: Train Loss: 0.46126893, Val Loss: 5.001597
[ Info: Train Loss: 0.38484415, Val Loss: 5.1434236
[ Info: Train Loss: 0.4936912, Val Loss: 5.6913314
[ Info: Train Loss: 0.5257853, Val Loss: 5.145835
engs = ["go .", "i lost .", "he's calm .", "i'm home ."]
fras = ["va !", "j'ai perdu .", "il est calme .", "je suis chez moi ."]
batch = d2lai.build(data, engs, fras)
preds, _ = d2lai.predict_step(m, batch, cpu, data.args.num_steps; save_attention_wts = true)
for (en, fr, p) in zip(engs, fras, eachcol(preds))
    translation = []
    for token in d2lai.to_tokens(data.tgt_vocab, p)
        if token == "<eos>"
            break
        end
        push!(translation, token)
    end
    bleu_score = d2lai.bleu(join(translation, " "), fr, 2)
    println("$en => $translation, bleu: $bleu_score")
end
go . => Any["va", "!"], bleu: 1.0
i lost . => Any["j'ai", "perdu", "."], bleu: 1.0
he's calm . => Any["j'en", "fais", "."], bleu: 0.0
i'm home . => Any["je", "suis", "chez", "moi", "."], bleu: 1.0
btch = d2lai.build(data, [engs[end]], [fras[end]])
_, dec_attention_weights = d2lai.predict_step(m, btch, cpu, data.args.num_steps; save_attention_wts = true);
attention_weights = cat([step[1] for step in dec_attention_weights]..., dims = 3)
attention_weights = reshape(attention_weights, :, data.args.num_steps, 1, 1)
Output: a 72×9×1×1 Array{Float64, 4} of decoder attention weights (9 key positions for each of the 8 heads stacked along the rows, one column per decoding step; full numeric printout omitted).
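To visualize what the heads attend to, these weights can be rendered as a heatmap. A minimal sketch using Plots (loaded at the top of this section), assuming attention_weights as computed in the previous cell:

# Rows stack the key positions of all heads; columns are decoding steps.
heatmap(attention_weights[:, :, 1, 1],
        xlabel = "decoding step", ylabel = "key position (stacked per head)")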