Attention Scoring Functions
In :numref:sec_attention-pooling, we used a number of different distance-based kernels, including a Gaussian kernel, to model interactions between queries and keys. As it turns out, distance functions are slightly more expensive to compute than dot products. As such, with the softmax operation ensuring nonnegative attention weights, much of the work has gone into attention scoring functions $a$ in :eqref:eq_softmax_attention (illustrated in the figure below) that are simpler to compute.
Figure: Computing the output of attention pooling as a weighted average of values, where weights are computed with the attention scoring function $a$.
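As a reminder (restating :eqref:eq_softmax_attention from :numref:sec_attention-pooling for convenience), attention pooling averages values with softmax-normalized scores:
$$\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j} \exp(a(\mathbf{q}, \mathbf{k}_j))}, \qquad \mathrm{Attention}(\mathbf{q}, \mathcal{D}) = \sum_{i} \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i.$$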
using Pkg; Pkg.activate("../../d2lai")
using d2lai
using Flux
using Downloads
using StatsBase
using Plots
using CUDA, cuDNN

Dot Product Attention
Let's review the attention function (without exponentiation) from the Gaussian kernel for a moment:
$$a(\mathbf{q}, \mathbf{k}_i) = -\tfrac{1}{2} \|\mathbf{q} - \mathbf{k}_i\|^2 = \mathbf{q}^\top \mathbf{k}_i - \tfrac{1}{2} \|\mathbf{k}_i\|^2 - \tfrac{1}{2} \|\mathbf{q}\|^2.$$
First, note that the final term depends on $\mathbf{q}$ only and is identical for all $(\mathbf{q}, \mathbf{k}_i)$ pairs. Normalizing the attention weights to $1$, as is done in :eqref:eq_softmax_attention, ensures that this term disappears entirely. Second, note that both batch and layer normalization (to be discussed later) lead to activations that have well-bounded, and often constant, norms $\|\mathbf{k}_i\|$. This is the case, for instance, whenever the keys were generated by a layer norm. As such, we can drop the term $\tfrac{1}{2}\|\mathbf{k}_i\|^2$ from the definition of $a$ without any major change in the outcome.
Last, we need to keep the order of magnitude of the arguments in the exponential function under control. Assume that all the elements of the query $\mathbf{q} \in \mathbb{R}^d$ and the key $\mathbf{k}_i \in \mathbb{R}^d$ are independent and identically distributed random variables with zero mean and unit variance. The dot product between both vectors then has zero mean and a variance of $d$. To ensure that the variance of the dot product remains $1$ regardless of vector length, we rescale it by $1/\sqrt{d}$. This yields the scaled dot product attention scoring function used, for example, in Transformers:
$$a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i / \sqrt{d}.$$
:eqlabel:eq_dot_product_attention
Note that the attention weights $\alpha$ still need normalizing. We can simplify this further, via :eqref:eq_softmax_attention, by using the softmax operation:
$$\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(\mathbf{q}^\top \mathbf{k}_i / \sqrt{d})}{\sum_{j} \exp(\mathbf{q}^\top \mathbf{k}_j / \sqrt{d})}.$$
:eqlabel:eq_attn-scoring-alpha
As it turns out, all popular attention mechanisms use the softmax, hence we will limit ourselves to that in the remainder of this chapter.
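To see why the $1/\sqrt{d}$ rescaling matters, here is a quick empirical check (a minimal sketch, not part of the d2lai package): for standard normal queries and keys of dimension $d$, the raw dot products have variance close to $d$, whereas the rescaled scores have variance close to $1$.
using Statistics
d, num_samples = 256, 10_000
qs = randn(d, num_samples)
ks = randn(d, num_samples)
raw_scores = vec(sum(qs .* ks, dims = 1))   # unscaled dot products, variance ≈ d
scaled_scores = raw_scores ./ sqrt(d)       # scaled dot products, variance ≈ 1
var(raw_scores), var(scaled_scores)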
Convenience Functions
We need a few functions to make the attention mechanism efficient to deploy. This includes tools for dealing with strings of variable lengths (common for natural language processing) and tools for efficient evaluation on minibatches (batch matrix multiplication).
Masked Softmax Operation
One of the most popular applications of the attention mechanism is to sequence models. Hence we need to be able to deal with sequences of different lengths. In some cases, such sequences may end up in the same minibatch, necessitating padding with dummy tokens for shorter sequences (see :numref:sec_machine_translation for an example). These special tokens do not carry meaning. For instance, assume that we have the following three sentences:
Dive into Deep Learning
Learn to code <blank>
Hello world <blank> <blank>

Since we do not want blanks in our attention model, we simply need to limit $\sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i$ to $\sum_{i=1}^l \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i$ for however long, $l \leq n$, the actual sentence is. Since this is such a common problem, it has a name: the masked softmax operation.
Let's implement it. Actually, the implementation cheats ever so slightly by setting the values of $\mathbf{v}_i$, for $i > l$, to zero. Moreover, it sets the attention weights to a large negative number, such as $-10^{6}$, in order to make their contribution to gradients and values vanish in practice.
function _sequence_mask_(X::AbstractArray, valid_len, value::T = 0) where {T}
n, q, b = size(X)
device = isa(X, CuArray) ? cu : identity
key_ids = reshape(device(collect(1:n)), n, 1, 1)
if eltype(valid_len) <: AbstractVector
# valid_len is Vector of Vectors
# Build a (n, q, b) mask with broadcasting
valid_mat = zeros(Int, q, b)
for j in 1:b
valid_mat[:, j] .= valid_len[j]
end
valid_mat = reshape(device(valid_mat), 1, q, b)
else
# valid_len is simple vector
valid_mat = reshape(device(collect(valid_len)), 1, 1, b)
end
mask = key_ids .<= valid_mat # shape: (n, q, b)
mask_f = T.(mask)
return X .* mask_f .+ value .* (1 .- mask_f)
end
# Softmax along the key dimension, ignoring entries beyond each valid length
function masked_softmax(X, valid_lens)
if isnothing(valid_lens)
return softmax(X, dims = 1)
else
X_ = _sequence_mask_(X, valid_lens, -1e6)
return softmax(X_, dims = 1)
end
end

X = rand(4, 2, 2)
valid_lens = [2, 3]
masked_softmax(X, valid_lens)

4×2×2 Array{Float64, 3}:
[:, :, 1] =
0.342791 0.525284
0.657209 0.474716
0.0 0.0
0.0 0.0
[:, :, 2] =
0.391309 0.331539
0.361107 0.283358
0.247584 0.385103
0.0       0.0

If we need more fine-grained control to specify the valid length for each query of every example, we simply pass a vector of per-query valid lengths (a vector of vectors). This yields:
masked_softmax(X, [[1, 3], [2, 4]])

4×2×2 Array{Float64, 3}:
[:, :, 1] =
1.0 0.384194
0.0 0.347209
0.0 0.268596
0.0 0.0
[:, :, 2] =
0.52007 0.273421
0.47993 0.233685
0.0 0.317595
0.0       0.175299

Batch Matrix Multiplication
Another commonly used operation is to multiply batches of matrices by one another. This comes in handy when we have minibatches of queries, keys, and values. More specifically, assume that
$$\mathbf{Q} = [\mathbf{Q}_1, \mathbf{Q}_2, \ldots, \mathbf{Q}_n] \in \mathbb{R}^{a \times b \times n}, \qquad \mathbf{K} = [\mathbf{K}_1, \mathbf{K}_2, \ldots, \mathbf{K}_n] \in \mathbb{R}^{b \times c \times n},$$
where the batch dimension comes last, matching Julia's column-major layout. Then the batch matrix multiplication (BMM) computes the elementwise product over the batch index,
$$\mathrm{BMM}(\mathbf{Q}, \mathbf{K}) = [\mathbf{Q}_1 \mathbf{K}_1, \mathbf{Q}_2 \mathbf{K}_2, \ldots, \mathbf{Q}_n \mathbf{K}_n] \in \mathbb{R}^{a \times c \times n}.$$
:eqlabel:eq_batch-matrix-mul
Let's see this in action in a deep learning framework.
Q = randn(3, 4, 2)
K = randn(4, 6, 2)
M = batched_mul(Q, K)
@assert size(M) == (3, 6, 2)
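As a quick sanity check (a minimal sketch, not part of the book's code), batched_mul with the batch dimension last agrees with multiplying the individual slices:
for i in 1:size(Q, 3)
    @assert M[:, :, i] ≈ Q[:, :, i] * K[:, :, i]
end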
Scaled Dot Product Attention

Let's return to the dot product attention introduced in :eqref:eq_dot_product_attention. In general, it requires that both the query and the key have the same vector length, say $d$, even though this can be addressed easily by replacing $\mathbf{q}^\top \mathbf{k}$ with $\mathbf{q}^\top \mathbf{M} \mathbf{k}$, where $\mathbf{M}$ is a matrix suitably chosen for translating between both spaces. For now assume that the dimensions match.
In practice, we often think of minibatches for efficiency, such as computing attention for $n$ queries and $m$ key-value pairs, where queries and keys are of length $d$ and values are of length $v$. The scaled dot product attention of queries $\mathbf{Q} \in \mathbb{R}^{n \times d}$, keys $\mathbf{K} \in \mathbb{R}^{m \times d}$, and values $\mathbf{V} \in \mathbb{R}^{m \times v}$ can thus be written as
$$\mathrm{softmax}\left(\frac{\mathbf{Q} \mathbf{K}^\top}{\sqrt{d}}\right) \mathbf{V} \in \mathbb{R}^{n \times v}.$$
:eqlabel:eq_softmax_QK_V
Note that when applying this to a minibatch, we need the batch matrix multiplication introduced in :eqref:eq_batch-matrix-mul. In the following implementation of the scaled dot product attention, we use dropout for model regularization.
struct DotProductAttention{D, A}
dropout::D
args::A
end
function (m::DotProductAttention)(queries, keys, values, valid_len = nothing)
# keys -> d x num_keys x batch_size
# queries -> d x num_queries x batch_size
d = size(queries, 1)
scores = batched_mul(batched_transpose(keys), queries) ./ sqrt(d)
# scores -> num_keys x num_queries x batch_size
attention_weights = masked_softmax(scores, valid_len)
# attention_weights -> num_keys x num_queries x batch_size
return batched_mul(values, m.dropout(attention_weights)), attention_weights
end

To illustrate how DotProductAttention works, we use a toy example with a minibatch size of 2, ten keys and values per example, and values of dimensionality 4. The valid lengths are 2 and 6, respectively. We will reuse the same keys, values, and valid lengths for additive attention below.
batch_size = 2
num_key_val = 10
num_queries = 1
d = d_k = d_q = 2
d_v = 4
queries = randn(d, num_queries, batch_size)
keys = randn(d, num_key_val, batch_size)
vals = randn(d_v, num_key_val, batch_size)
valid_len = [2, 6]
dot_product_attn = DotProductAttention(Flux.Dropout(0.5), nothing)
scores, attn_weights = dot_product_attn(queries, keys, vals, valid_len)
@assert size(scores) == (4, 1, 2)

d2lai.show_heatmaps(reshape(attn_weights, 10, 2, 1, 1), "Queries", "Keys")

Additive Attention
When queries $\mathbf{q} \in \mathbb{R}^q$ and keys $\mathbf{k} \in \mathbb{R}^k$ are vectors of different dimension, we can either use a matrix to address the mismatch via $\mathbf{q}^\top \mathbf{M} \mathbf{k}$, or we can use additive attention as the scoring function. Another benefit is that, as its name indicates, the attention is additive, which can lead to some minor computational savings. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the additive attention scoring function is given by
$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}) \in \mathbb{R},$$
:eqlabel:eq_additive-attn
where $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, and $\mathbf{w}_v \in \mathbb{R}^{h}$ are the learnable parameters. This term is then fed into a softmax to ensure both nonnegativity and normalization. An equivalent interpretation of :eqref:eq_additive-attn is that the query and key are concatenated and fed into an MLP with a single hidden layer. Using $\tanh$ as the activation function and disabling bias terms, we implement additive attention as follows:
struct AdditiveAttention{W, A, D} <: AbstractModel
weights::W
dropout::D
args::A
end
Flux.@layer AdditiveAttention trainable = (weights,)
function AdditiveAttention(k_d, q_d, v_d, num_hiddens::Int64, dropout::Float64; kw...)
W_k = Dense(k_d => num_hiddens; bias = false)
W_q = Dense(q_d => num_hiddens; bias = false)
W_v = Dense(num_hiddens => 1; bias = false)
AdditiveAttention((; W_k, W_q, W_v), Flux.Dropout(dropout), (;))
end
function (m::AdditiveAttention)(queries, keys, values, valid_lens)
queries = m.weights.W_q(queries) # num_hiddens x num_queries x batch_size
keys = m.weights.W_k(keys) # num_hiddens x num_keys x batch_size
features = Flux.unsqueeze(queries, 2) .+ Flux.unsqueeze(keys, 3) # num_hiddens x num_keys x num_queries x batch_size
features = tanh.(features)
scores = m.weights.W_v(features) # 1 x num_keys x num_queries x batch_size
scores = dropdims(scores, dims = 1) # num_keys x num_queries x batch_size
attention_weights = masked_softmax(scores, valid_lens) # num_keys x num_queries x batch_size
return batched_mul(values, m.dropout(attention_weights)), attention_weights
# num_hidden x num_queries x batch_size , num_keys x num_queries x batch_size
end

Let's see how AdditiveAttention works. In our toy example we pick the same keys, values, and valid lengths as for DotProductAttention, except that now the queries are 20-dimensional. We use 8 hidden units and a dropout rate of 0.
d_q = 20
queries = randn(d_q, num_queries, batch_size)
m = AdditiveAttention(d_k, d_q, d_v, 8, 0.)
scores, attention_wts = m(queries, keys, vals, valid_len)
@assert size(scores) == (d_v, num_queries, batch_size)

When reviewing the attention weights we see a behavior that is qualitatively quite similar to that of DotProductAttention: only terms within the chosen valid lengths (2 and 6, respectively) are nonzero.
d2lai.show_heatmaps(reshape(attention_wts, 10, 2, 1, 1), "Queries", "Keys")

Summary
In this section we introduced the two key attention scoring functions: dot product and additive attention. They are effective tools for aggregating across sequences of variable length. In particular, the dot product attention is the mainstay of modern Transformer architectures. When queries and keys are vectors of different lengths, we can use the additive attention scoring function instead. Optimizing these layers is one of the key areas of advance in recent years. For instance, NVIDIA's Transformer Library and Megatron [192] crucially rely on efficient variants of the attention mechanism. We will dive into this in quite a bit more detail as we review Transformers in later sections.
Exercises
1. Implement distance-based attention by modifying the DotProductAttention code. Note that you only need the squared norms of the keys, $\|\mathbf{k}_i\|^2$, for an efficient implementation (see the sketch after this list).
2. Modify the dot product attention to allow for queries and keys of different dimensionalities by employing a matrix to adjust dimensions.
3. How does the computational cost scale with the dimensionality of the keys, queries, and values, and with their number? What about the memory bandwidth requirements?
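One possible approach to the first exercise (a minimal sketch under this section's conventions; DistanceAttention is a hypothetical name, not part of d2lai): compute $a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i - \tfrac{1}{2}\|\mathbf{k}_i\|^2$ and let the softmax absorb the query-only term $-\tfrac{1}{2}\|\mathbf{q}\|^2$.
struct DistanceAttention{D}
    dropout::D
end
function (m::DistanceAttention)(queries, keys, values, valid_len = nothing)
    # keys: d × num_keys × batch, queries: d × num_queries × batch
    n_keys, batch = size(keys, 2), size(keys, 3)
    # squared key norms, reshaped to broadcast over the query dimension
    sq_key_norms = reshape(sum(abs2, keys; dims = 1), n_keys, 1, batch) ./ 2
    scores = batched_mul(batched_transpose(keys), queries) .- sq_key_norms
    attention_weights = masked_softmax(scores, valid_len)
    return batched_mul(values, m.dropout(attention_weights)), attention_weights
end

q_demo, k_demo, v_demo = randn(2, 1, 2), randn(2, 10, 2), randn(4, 10, 2)
out, _ = DistanceAttention(Flux.Dropout(0.0))(q_demo, k_demo, v_demo, [2, 6])
@assert size(out) == (4, 1, 2)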