Queries, Keys, and Values
So far all the networks we have reviewed crucially relied on the input being of a well-defined size. For instance, the images in ImageNet are of a fixed size to which CNNs are specifically tuned, and for sequence models the variable length of the input is handled by processing one token at a time, such as in the sequence-to-sequence transformation of text (:numref:sec_seq2seq) [153]. In particular, for long sequences it becomes quite difficult to keep track of everything that has already been generated or even viewed by the network. Even explicit tracking heuristics such as the one proposed by Yang et al. [185] only offer limited benefit.
Compare this to databases. In their simplest form they are collections of keys ($\mathbf{k}$) and values ($\mathbf{v}$), which we can interrogate with a query ($\mathbf{q}$) that is matched against the keys, either exactly or approximately. This quite simple setting nonetheless teaches us a number of useful things:

- We can design queries $\mathbf{q}$ that operate on ($\mathbf{k}$, $\mathbf{v}$) pairs in such a manner as to be valid regardless of the database size.
- The same query can receive different answers, according to the contents of the database.
- The "code" being executed for operating on a large state space (the database) can be quite simple (e.g., exact match, approximate match, top-$k$).
- There is no need to compress or simplify the database to make the operations effective.
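To make the database analogy concrete, here is a minimal sketch (our own illustration, not part of the original text) of exact and approximate key matching on a toy key-value store; the contents of `db` and the character-mismatch distance are made up.

```julia
# A toy database of (key, value) pairs; contents are purely illustrative.
db = Dict("cat" => 1, "car" => 2, "dog" => 3)

# Exact match: return the value stored under the key, or nothing if absent.
exact_match(db, q) = get(db, q, nothing)

# Approximate match: return the value whose key is "closest" to the query,
# using a crude character-mismatch count as the distance.
function approx_match(db, q)
    dist(a, b) = count(((x, y),) -> x != y, zip(a, b)) + abs(length(a) - length(b))
    best_key = argmin(k -> dist(k, q), collect(keys(db)))
    return db[best_key]
end

exact_match(db, "cat")   # 1
exact_match(db, "cap")   # nothing
approx_match(db, "cap")  # 1 or 2 (the tie between "cat" and "car" is broken arbitrarily)
```

Note that the same query code works no matter how many entries `db` contains.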
Clearly we would not have introduced a simple database here if it wasn't for the purpose of explaining deep learning. Indeed, this leads to one of the most exciting concepts introduced in deep learning in the past decade: the attention mechanism [183]. We will cover the specifics of its application to machine translation later. For now, simply consider the following: denote by $\mathcal{D} \stackrel{\textrm{def}}{=} \{(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)\}$ a database of $m$ tuples of keys and values, and denote by $\mathbf{q}$ a query. Then we can define the attention over $\mathcal{D}$ as

$$
\textrm{Attention}(\mathbf{q}, \mathcal{D}) \stackrel{\textrm{def}}{=} \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i,
$$
:eqlabel:eq_attention_pooling

where $\alpha(\mathbf{q}, \mathbf{k}_i) \in \mathbb{R}$ (for $i = 1, \ldots, m$) are scalar attention weights. The operation itself is typically referred to as attention pooling. The name attention derives from the fact that the operation pays particular attention to the terms for which the weight $\alpha$ is significant (i.e., large). As such, the attention over $\mathcal{D}$ generates a linear combination of the values contained in the database. We distinguish a number of special cases:
- The weights $\alpha(\mathbf{q}, \mathbf{k}_i)$ are nonnegative. In this case the output of the attention mechanism is contained in the convex cone spanned by the values $\mathbf{v}_i$.
- The weights $\alpha(\mathbf{q}, \mathbf{k}_i)$ form a convex combination, i.e., $\sum_i \alpha(\mathbf{q}, \mathbf{k}_i) = 1$ and $\alpha(\mathbf{q}, \mathbf{k}_i) \geq 0$ for all $i$. This is the most common setting in deep learning.
- Exactly one of the weights $\alpha(\mathbf{q}, \mathbf{k}_i)$ is $1$, while all others are $0$. This is akin to a traditional database query.
- All weights are equal, i.e., $\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{1}{m}$ for all $i$. This amounts to averaging across the entire database, also called average pooling in deep learning.
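As a small numerical illustration (a sketch we add here; the toy matrix and the name `attention_pool` are not from the original text), the snippet below evaluates $\sum_{i} \alpha(\mathbf{q}, \mathbf{k}_i)\,\mathbf{v}_i$ for a one-hot weight vector (a database lookup) and for uniform weights (average pooling):

```julia
# Attention pooling: a weighted sum of the value vectors.
# `values` is a d × m matrix whose columns are the m value vectors,
# `weights` is a length-m vector of attention weights.
attention_pool(weights, values) = values * weights

values = [1.0 2.0 3.0 4.0;   # d = 2, m = 4
          5.0 6.0 7.0 8.0]

onehot  = [0.0, 1.0, 0.0, 0.0]   # exactly one weight is 1: a database lookup
uniform = fill(1 / 4, 4)         # all weights equal: average pooling

attention_pool(onehot, values)   # picks the second value column: [2.0, 6.0]
attention_pool(uniform, values)  # averages all columns: [2.5, 6.5]
```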
A common strategy for ensuring that the weights sum up to $1$ is to normalize them via

$$
\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\alpha(\mathbf{q}, \mathbf{k}_i)}{\sum_j \alpha(\mathbf{q}, \mathbf{k}_j)}.
$$

In particular, to ensure that the weights are also nonnegative, one can resort to exponentiation. This means that we can now pick any function $a(\mathbf{q}, \mathbf{k})$ and then apply the softmax operation used for multinomial models to it via

$$
\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_j \exp(a(\mathbf{q}, \mathbf{k}_j))}.
$$
:eqlabel:eq_softmax_attention
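The following is a minimal Julia sketch of :eqref: eq_softmax_attention; the dot-product score and the name `softmax_weights` are our own choices for illustration (Flux also provides a `softmax`).

```julia
using LinearAlgebra: dot

# Turn arbitrary scalar scores a(q, kᵢ) into nonnegative weights that sum to one.
# Subtracting the maximum score keeps exp from overflowing; the shift cancels out.
function softmax_weights(scores)
    e = exp.(scores .- maximum(scores))
    return e ./ sum(e)
end

# A dot-product score a(q, kᵢ) = q ⋅ kᵢ, chosen purely for illustration.
q  = [1.0, 0.0]
ks = [[0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]]
α  = softmax_weights([dot(q, k) for k in ks])

sum(α) ≈ 1.0   # true: the weights form a convex combination
```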
This operation is readily available in all deep learning frameworks. It is differentiable and its gradient never vanishes, both of which are desirable properties in a model. Note, though, that the attention mechanism introduced above is not the only option. For instance, we can design a non-differentiable attention model that can be trained using reinforcement learning methods [186]. As one would expect, training such a model is quite complex. Consequently the bulk of modern attention research follows the framework outlined in the figure below. We thus focus our exposition on this family of differentiable mechanisms.
Figure: The attention mechanism computes a linear combination over values $\mathbf{v}_i$ via attention pooling, where weights are derived according to the compatibility between a query $\mathbf{q}$ and keys $\mathbf{k}_i$.
What is quite remarkable is that the actual "code" for executing on the set of keys and values, namely the query, can be quite concise, even though the space to operate on is significant. This is a desirable property for a network layer as it does not require too many parameters to learn. Just as convenient is the fact that attention can operate on arbitrarily large databases without the need to change the way the attention pooling operation is performed.
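To underline this point, here is a hedged sketch (the `attend` helper is ours, not part of the text) showing that the very same attention code runs unchanged on databases of different sizes; only the number of keys and values changes.

```julia
using LinearAlgebra: dot

# Softmax attention over a database given as a vector of keys and a vector of scalar values.
function attend(q, ks, vs)
    scores = [dot(q, k) for k in ks]
    w = exp.(scores .- maximum(scores))
    w ./= sum(w)
    return sum(w .* vs)
end

q = [1.0, 0.0]

# A small and a larger random database; the query "code" is identical for both.
small_ks, small_vs = [randn(2) for _ in 1:5],  randn(5)
large_ks, large_vs = [randn(2) for _ in 1:50], randn(50)

attend(q, small_ks, small_vs)
attend(q, large_ks, large_vs)
```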
using Pkg; Pkg.activate("../../d2lai")
using LinearAlgebra
using d2lai
using Flux
using Downloads
using StatsBase
using Plots

Visualization
One of the benefits of the attention mechanism is that it can be quite intuitive, particularly when the weights are nonnegative and sum to $1$. In this case we might interpret large attention weights as a way for the model to select components of relevance. As such, it is a good idea to visualize attention weights.
We thus define the show_heatmaps function. Note that it does not take a matrix (of attention weights) as its input but rather a tensor with four axes, allowing for an array of different queries and weights. Consequently the input matrices holds an individual matrix of attention weights in its first two axes, while its last two axes index the rows and columns of the display grid. This will come in handy later on when we want to visualize the workings of Transformers.
function show_heatmaps(matrices, xlabel, ylabel, titles = nothing)
    # The last two axes of `matrices` give the number of rows and columns of the display grid.
    num_rows, num_cols = size(matrices)[end-1:end]
    layout = (num_rows, num_cols)
    titles = isnothing(titles) ? Iterators.repeated(nothing) : titles
    # Draw one heatmap per slice along the grid axes, with an optional title.
    heatmaps = map(eachslice(matrices, dims = (3, 4)), titles) do matrix, title
        isnothing(title) && return heatmap(matrix; xlabel, ylabel)
        return heatmap(matrix; xlabel, ylabel, title)
    end
    plot(heatmaps...; layout)
end

As a quick sanity check we visualize the identity matrix, representing a case where the attention weight is one only when the query and the key are identical.

attention_weights = reshape(Matrix(I, 10, 10), 10, 10, 1, 1)
show_heatmaps(attention_weights, "Queries", "Keys")

Summary
The attention mechanism allows us to aggregate data from many (key, value) pairs. So far our discussion has been quite abstract, simply describing a way to pool data. We have not explained yet where those mysterious queries, keys, and values might arise from. Some intuition might help here: for instance, in a regression setting, the query might correspond to the location where the regression should be carried out. The keys are the locations where past data was observed and the values are the (regression) values themselves. This is the so-called Nadaraya–Watson estimator [187], [188] that we will be studying in the next section.
By design, the attention mechanism provides a differentiable means of control by which a neural network can select elements from a set and construct an associated weighted sum over representations.
Exercises
1. Suppose that you wanted to reimplement approximate (key, query) matches as used in classical databases. Which attention function would you pick?
2. Suppose that the attention function is given by $a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i$ and that $\mathbf{k}_i = \mathbf{v}_i$ for $i = 1, \ldots, m$. Denote by $p(\mathbf{k}_i; \mathbf{q})$ the probability distribution over keys when using the softmax normalization in :eqref: eq_softmax_attention. Prove that $\nabla_{\mathbf{q}} \mathrm{Attention}(\mathbf{q}, \mathcal{D}) = \mathrm{Cov}_{p(\mathbf{k}_i; \mathbf{q})}[\mathbf{k}_i]$.
3. Design a differentiable search engine using the attention mechanism.
4. Review the design of Squeeze-and-Excitation Networks [148] and interpret them through the lens of the attention mechanism.