Self-Attention and Positional Encoding
In deep learning, we often use CNNs or RNNs to encode sequences. Now with attention mechanisms in mind, imagine feeding a sequence of tokens into an attention mechanism such that at every step, each token has its own query, keys, and values. Here, when computing the value of a token's representation at the next layer, the token can attend (via its query vector) to every other token (matching based on their key vectors). Using the full set of query-key compatibility scores, we can compute, for each token, a representation by building the appropriate weighted sum over the other tokens. Because every token attends to every other token (unlike the case where decoder steps attend to encoder steps), such architectures are typically described as self-attention models [196], [142], and elsewhere as intra-attention models [197], [198], [199]. In this section, we will discuss sequence encoding using self-attention, including the use of additional information for the sequence order.
using Pkg; Pkg.activate("../../d2lai")
using LinearAlgebra
using d2lai
using Flux
using Downloads
using StatsBase
using Plots
using CUDA, cuDNN

Self-Attention
Given a sequence of input tokens $\mathbf{x}_1, \ldots, \mathbf{x}_n$ where any $\mathbf{x}_i \in \mathbb{R}^d$ ($1 \leq i \leq n$), its self-attention outputs a sequence of the same length $\mathbf{y}_1, \ldots, \mathbf{y}_n$, where

$$\mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d$$

according to the definition of attention pooling in :eqref:`eq_attention_pooling`. Using multi-head attention, the following code snippet computes the self-attention of a tensor with shape (feature dimension $d$, number of time steps or sequence length in tokens, batch size), following Julia's feature-first layout. The output tensor has the same shape.
num_hiddens, num_heads = 100, 5
attention = d2lai.MultiHeadedAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, [3, 2]
X = ones(num_hiddens, num_queries, batch_size)
out, attention_weights = attention(X, X, X, valid_lens)
@assert size(out) == (num_hiddens, num_queries, batch_size)
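To make explicit what the layer above is doing, here is a minimal from-scratch sketch of single-head scaled dot-product self-attention for one unbatched sequence, with no masking and no learned projections. The variable names and toy sizes are illustrative only and are not part of the d2lai API; `softmax` comes from Flux (via NNlib).

# A from-scratch sketch: every token acts as its own query, key, and value.
using Flux: softmax

d, n = 4, 6                           # embedding size and sequence length (assumed toy values)
X_toy = randn(Float32, d, n)          # one token embedding per column

Q, K, V = X_toy, X_toy, X_toy         # identity projections: pure self-attention
scores = (K' * Q) ./ sqrt(Float32(d)) # (n, n) matrix of query-key compatibilities
A = softmax(scores; dims = 1)         # column i holds the attention weights for query i
Y = V * A                             # each output column is a weighted sum of the values

@assert size(Y) == (d, n)             # one d-dimensional representation per token, as in the formula above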
Comparing CNNs, RNNs, and Self-Attention

Let's compare architectures for mapping a sequence of $n$ tokens to another sequence of equal length, where each input or output token is represented by a $d$-dimensional vector. Specifically, we will consider CNNs, RNNs, and self-attention, comparing their computational complexity, sequential operations, and maximum path lengths. Note that sequential operations prevent parallel computation, while a shorter path between any combination of sequence positions makes it easier to learn long-range dependencies within the sequence.
Comparing CNN (padding tokens are omitted), RNN, and self-attention architectures.
Let's regard any text sequence as a "one-dimensional image". Similarly, one-dimensional CNNs can process local features such as $n$-grams in text. Given a sequence of length $n$, consider a convolutional layer whose kernel size is $k$ and whose numbers of input and output channels are both $d$. The computational complexity of the convolutional layer is $\mathcal{O}(knd^2)$. As the figure above shows, CNNs are hierarchical, so there are $\mathcal{O}(1)$ sequential operations and the maximum path length is $\mathcal{O}(n/k)$. For example, $\mathbf{x}_1$ and $\mathbf{x}_5$ are within the receptive field of a two-layer CNN with kernel size 3.
When updating the hidden state of RNNs, multiplication of the $d \times d$ weight matrix and the $d$-dimensional hidden state has a computational complexity of $\mathcal{O}(d^2)$. Since the sequence length is $n$, the computational complexity of the recurrent layer is $\mathcal{O}(nd^2)$. As the figure shows, there are $\mathcal{O}(n)$ sequential operations that cannot be parallelized, and the maximum path length is also $\mathcal{O}(n)$.
In self-attention, the queries, keys, and values are all $n \times d$ matrices. Consider the scaled dot product attention in :eqref:`eq_softmax_QK_V`, where an $n \times d$ matrix is multiplied by a $d \times n$ matrix, and then the resulting $n \times n$ matrix is multiplied by an $n \times d$ matrix. As a result, self-attention has $\mathcal{O}(n^2d)$ computational complexity. As the figure shows, each token is directly connected to any other token via self-attention. Therefore, computation can be parallel with $\mathcal{O}(1)$ sequential operations, and the maximum path length is also $\mathcal{O}(1)$.
All in all, both CNNs and self-attention enjoy parallel computation and self-attention has the shortest maximum path length. However, the quadratic computational complexity with respect to the sequence length makes self-attention prohibitively slow for very long sequences.
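As a rough illustration of these asymptotics (the sizes below are assumed, not taken from any benchmark), we can plug a few sequence lengths into the per-layer operation counts and watch the quadratic term of self-attention take over:

# Illustrative operation counts only; constants and memory traffic are ignored.
d_model, kernel_size = 512, 3            # embedding size and CNN kernel size (assumed values)
for n in (100, 1_000, 10_000)            # sequence lengths
    cnn  = kernel_size * n * d_model^2   # convolutional layer: O(k n d^2)
    rnn  = n * d_model^2                 # recurrent layer:     O(n d^2)
    attn = n^2 * d_model                 # self-attention:      O(n^2 d)
    println("n = $n: CNN ≈ $cnn, RNN ≈ $rnn, self-attention ≈ $attn")
end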
Positional Encoding
Unlike RNNs, which recurrently process tokens of a sequence one-by-one, self-attention dispenses with sequential operations in favor of parallel computation. Note that self-attention by itself does not preserve the order of the sequence. What do we do if it really matters that the model knows in which order the input sequence arrived?
The dominant approach for preserving information about the order of tokens is to represent this to the model as an additional input associated with each token. These inputs are called positional encodings, and they can either be learned or fixed a priori; a minimal sketch of a learned variant appears just below.
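Purely as an aside (nothing below is used in the rest of the section), a learned positional encoding can be as simple as a trainable embedding table indexed by position. The sketch assumes Flux's Embedding layer and illustrative sizes.

# A minimal learned positional encoding: one trainable vector per position (illustrative only).
using Flux

max_len_l, hidden_l = 1000, 32
pos_table = Flux.Embedding(max_len_l => hidden_l)     # trainable (hidden_l, max_len_l) table

X_l = zeros(Float32, hidden_l, 10, 2)                 # (num_hiddens, num_steps, batch_size)
P_learned = pos_table(collect(1:size(X_l, 2)))        # look up the first num_steps positions
@assert size(X_l .+ P_learned) == size(X_l)           # added to the inputs, broadcast over the batch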
We now describe a simple scheme for fixed positional encodings based on sine and cosine functions [142].

Suppose that the input representation $\mathbf{X} \in \mathbb{R}^{n \times d}$ contains the $d$-dimensional embeddings for $n$ tokens of a sequence. The positional encoding outputs $\mathbf{X} + \mathbf{P}$ using a positional embedding matrix $\mathbf{P} \in \mathbb{R}^{n \times d}$ of the same shape, whose element on the $i^\textrm{th}$ row and the $(2j)^\textrm{th}$ or the $(2j+1)^\textrm{th}$ column is

$$\begin{aligned} p_{i, 2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right),\\ p_{i, 2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right).\end{aligned}$$
:eqlabel:`eq_positional-encoding-def`
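Read element-wise, this is just a pair of scalar functions of the position $i$ and the dimension-pair index $j$; the following one-liners (ad hoc names, not used later) transcribe the definition directly:

# Direct transcription of eq_positional-encoding-def for a single (i, j); i and j are zero-based as in the equation.
p_even(i, j, d) = sin(i / 10000^(2j / d))   # p_{i, 2j}
p_odd(i, j, d)  = cos(i / 10000^(2j / d))   # p_{i, 2j+1}

p_even(3, 0, 32), p_odd(3, 0, 32)           # encoding of position i = 3 in its first two dimensions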
At first glance, this trigonometric function design looks weird. Before we explain this design, let's first implement it in the following PositionalEncoding struct.
struct PositionalEncoding{P,D,A} <: AbstractModel
dropout::D
P::P
args::A
end
function PositionalEncoding(num_hiddens::Int, dropout::AbstractFloat, max_len = 1000)
dropout_layer = Dropout(dropout)   # wrap the rate in a Dropout layer
P = zeros(num_hiddens, max_len, 1)   # (encoding dimension, position, 1): transposed w.r.t. the n-by-d matrix above
# Row k of X holds position / 10000^(2(k-1)/num_hiddens) for positions 1:max_len, so frequencies decrease down the rows
X = reshape(collect(1:max_len), 1, :) ./ 10000 .^ ((0:2:num_hiddens-1) / num_hiddens)
P[1:2:end, :, :] .= sin.(X)   # odd rows: sine
P[2:2:end, :, :] .= cos.(X)   # even rows: cosine
PositionalEncoding(dropout_layer, P, (; num_hiddens, dropout, max_len))
end
function (pos::PositionalEncoding)(x)
device = isa(x, CuArray) ? gpu : cpu
P_device = pos.P |> device
pos.dropout(x .+ P_device[:, 1:size(x, 2), :])
end

In the positional embedding matrix P (stored here with shape (num_hiddens, max_len, 1), i.e. transposed relative to the $n \times d$ matrix in the math above), rows correspond to the encoding dimensions and columns to positions within a sequence. In the example below, we plot encoding dimensions 6 through 10 across positions. The 7th and the 8th dimensions have a higher frequency than the 9th and the 10th. The offset between the 7th and the 8th dimensions (same for the 9th and the 10th) is due to the alternation of sine and cosine functions.
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0.3)
X = pos_encoding(zeros(encoding_dim, num_steps, 1))
P = pos_encoding.P[:, 1:size(X, 2), :]
plot(P[6:10, :, 1]', label = reshape(["col $i" for i in 6:10], 1, :), xlabel = "Row Position")

Absolute Positional Information
To see how the monotonically decreased frequency along the encoding dimension relates to absolute positional information, let's print out the binary representations of $0, 1, \ldots, 7$ (note that digits lists the least significant bit first). As we can see, the lowest bit alternates on every number, the second-lowest bit on every two numbers, and the third-lowest bit on every four numbers.
for i in 0:7
string.(digits(i, base=2, pad = 3)) |> join |> println
end
000
100
010
110
001
101
011
111

In binary representations, a higher bit has a lower frequency than a lower bit. Similarly, as demonstrated in the heat map below, the positional encoding decreases frequencies along the encoding dimension by using trigonometric functions. Since the outputs are float numbers, such continuous representations are more space-efficient than binary representations.
P_ = Flux.unsqueeze(Flux.unsqueeze(P[:, :, 1], 3), 3)
d2lai.show_heatmaps(P_, "Column (encoding dimension)", "Row (position)")

Relative Positional Information
Besides capturing absolute positional information, the above positional encoding also allows a model to easily learn to attend by relative positions. This is because for any fixed position offset $\delta$, the positional encoding at position $i + \delta$ can be represented by a linear projection of that at position $i$.
This projection can be explained mathematically. Denoting $\omega_j = 1/10000^{2j/d}$, any pair of $(p_{i, 2j}, p_{i, 2j+1})$ in :eqref:`eq_positional-encoding-def` can be linearly projected to $(p_{i+\delta, 2j}, p_{i+\delta, 2j+1})$ for any fixed offset $\delta$:

$$\begin{aligned}
\begin{bmatrix} \cos(\delta \omega_j) & \sin(\delta \omega_j) \\  -\sin(\delta \omega_j) & \cos(\delta \omega_j) \\ \end{bmatrix}
\begin{bmatrix} p_{i, 2j} \\  p_{i, 2j+1} \\ \end{bmatrix}
&= \begin{bmatrix} \cos(\delta \omega_j) \sin(i \omega_j) + \sin(\delta \omega_j) \cos(i \omega_j) \\  -\sin(\delta \omega_j) \sin(i \omega_j) + \cos(\delta \omega_j) \cos(i \omega_j) \\ \end{bmatrix}\\
&= \begin{bmatrix} \sin\left((i+\delta) \omega_j\right) \\  \cos\left((i+\delta) \omega_j\right) \\ \end{bmatrix}\\
&= \begin{bmatrix} p_{i+\delta, 2j} \\  p_{i+\delta, 2j+1} \\ \end{bmatrix},
\end{aligned}$$

where the $2\times 2$ projection matrix does not depend on any position index $i$.
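The identity above is easy to verify numerically: for an arbitrary dimension-pair index $j$ and offset $\delta$, the $2\times 2$ rotation matrix built from $\delta\omega_j$ maps the encoding at position $i$ onto the encoding at position $i+\delta$. The values below are arbitrary illustrative choices.

# Numerical check of the rotation identity; i, δ, j, d are arbitrary illustrative values.
i, δ, j, d = 7, 3, 2, 32
ω = 1 / 10000^(2j / d)

rot = [cos(δ * ω)  sin(δ * ω);
       -sin(δ * ω)  cos(δ * ω)]
p_i     = [sin(i * ω), cos(i * ω)]               # (p_{i, 2j}, p_{i, 2j+1})
p_shift = [sin((i + δ) * ω), cos((i + δ) * ω)]   # (p_{i+δ, 2j}, p_{i+δ, 2j+1})

@assert rot * p_i ≈ p_shift                      # the projection depends on δ and j, but not on i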
Summary
In self-attention, the queries, keys, and values all come from the same place. Both CNNs and self-attention enjoy parallel computation and self-attention has the shortest maximum path length. However, the quadratic computational complexity with respect to the sequence length makes self-attention prohibitively slow for very long sequences. To use the sequence order information, we can inject absolute or relative positional information by adding positional encoding to the input representations.
Exercises
Suppose that we design a deep architecture to represent a sequence by stacking self-attention layers with positional encoding. What could the possible issues be?
Can you design a learnable positional encoding method?
Can we assign different learned embeddings according to different offsets between queries and keys that are compared in self-attention? Hint: you may refer to relative position embeddings [200], [201].