Attention: Part 1

Understanding attention in deep learning


Note: You can run (and edit) this code in the browser thanks to PyScript!

Typically, you'll see attention written in this form:

$$ \mathrm{Attention}(Q,K,V) = \mathrm{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V \label{eq:attention} \tag{1} $$ where
  • \( Q \) is the query matrix
  • \( K \) is the key matrix
  • \( V \) is the value matrix
  • \( d_k \) is the dimension of the keys and queries
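For reference, Eq. \( \eqref{eq:attention} \) is only a few lines of code. Here is a minimal NumPy sketch (the shapes are assumptions; we will unpack where \( Q \), \( K \), and \( V \) actually come from below):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Q, K: (n, d_k); V: (n, d_v)
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n, n) similarity scores
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V                   # (n, d_v) weighted combination of values
```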

I found this equation confusing. What are these matrices? Are they trainable weights? How are they related to the input?

These matrices are actually a function of the input \(X \). So for the query matrix, we would define:

$$ \underbrace{Q}_{n \times d_k} = \overbrace{X}^{n \times d_x} \underbrace{W_Q}_{d_x \times d_k} \label{eq:query} \tag{2} $$

where \( W_Q \) is a trainable matrix for the queries. This same approach applies to the key and value matrices. Let's break this down more by first understanding the input matrix \( X \).

What is the input \( X \)?

Say we have some input \( X \in \mathbb{R}^{n \times d_x}\). This is a matrix with \( n \) tokens, where each token has \( d_x \) feature dimensions (or embedding size). A token could, for example, represent a word in a sentence or a vectorized patch in an image. The corresponding features/embeddings are the numbers representing that token.
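For example, a tiny made-up input with \( n = 3 \) tokens and \( d_x = 4 \) features per token could look like this:

```python
import numpy as np

# A made-up input with n = 3 tokens (rows), each with d_x = 4 features (columns).
X = np.array([
    [1.0, 0.0, 2.0, 1.0],   # token 1, e.g. the word "the"
    [0.0, 1.0, 0.0, 3.0],   # token 2, e.g. the word "cat"
    [1.0, 1.0, 1.0, 0.0],   # token 3, e.g. the word "sat"
])
print(X.shape)  # (3, 4)
```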


Okay! Let's assume we have this input matrix \( X \) now. This \( X \) could be the original encoded input or a learned representation from a previous step.

Now that we understand the \( X \) from Eq. \( \eqref{eq:query} \), let's talk about the weights \( W_Q \).

Weight matrix

This is where the "learning" is actually done. We'll focus on the query weight matrix \( W_Q \) for now, but the same approach applies to the "keys" and "values".

The query weight matrix \( W_Q \) is of size \( d_x \times d_k \), where \( d_x \) is the number of features in the input \( X \), and \( d_k \) is a size that you choose (the dimension of the new feature vectors to learn). So the weight matrix would be something like:
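Continuing the made-up example from above (\( d_x = 4 \)) and choosing \( d_k = 2 \), here is an illustrative \( W_Q \). The entries are placeholders picked for easy arithmetic; in a real model they would be learned:

```python
# W_Q has shape (d_x, d_k) = (4, 2).
# These entries are placeholders; during training they would be learned.
W_Q = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.0, 0.0],
])
print(W_Q.shape)  # (4, 2)
```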


Now that we understand the different components, we are ready to compute the query matrix \( Q \) as we defined in Eq. \( \eqref{eq:query} \).

Computing the Query matrix \( Q \)

We'll repeat Eq. \( \eqref{eq:query} \) here for clarity, where we perform matrix multiplication on \( X \) and \( W_Q \):

$$ \underbrace{Q}_{n \times d_k} = \overbrace{X}^{n \times d_x} \underbrace{W_Q}_{d_x \times d_k} \label{eq:query2} \tag{2} $$

We now have the query matrix \( Q \), and we can see that the resulting \( Q \) has dimensions \( n \times d_k \), where \( n \) is the number of tokens and \( d_k \) is the dimension we chose to represent our new feature/embedding space.

It's worth mentioning how matrix multiplication works. We take a row of \( X \) (which represents a token), multiply it element-wise with a column of \( W_Q \), and sum the products. We do that for every row-column combination.

This is simpler to see with an example. Let's take our input \( X \) and our weight matrix \( W_Q \) that we defined earlier. We'll go through a simple example where we compute this by hand.
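Using the made-up \( X \) and \( W_Q \) from above: the entry in row 1, column 1 of \( Q \) is the dot product of the first row of \( X \) with the first column of \( W_Q \), i.e. \( 1 \cdot 1 + 0 \cdot 0 + 2 \cdot 1 + 1 \cdot 0 = 3 \). The full product, checked with NumPy:

```python
# Q = X @ W_Q has shape (n, d_k) = (3, 2).
Q = X @ W_Q
print(Q)
# [[3. 2.]
#  [0. 1.]
#  [2. 2.]]
```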


Nice! Now in practice, you would not compute this by hand; instead, you would use something like "nn.Linear()" in PyTorch, or a plain matrix multiplication in NumPy.
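Here is a minimal sketch continuing the toy matrices above (note that "nn.Linear" stores its weight transposed and adds a bias by default, which we disable here to match the plain matrix product):

```python
import torch
import torch.nn as nn

# NumPy: the same matrix product we just did by hand.
Q_numpy = X @ W_Q                                   # shape (n, d_k)

# PyTorch: nn.Linear holds a trainable weight of shape (d_k, d_x)
# and computes X @ weight.T under the hood.
query_proj = nn.Linear(4, 2, bias=False)
with torch.no_grad():
    query_proj.weight.copy_(torch.tensor(W_Q.T))    # reuse our toy weights

Q_torch = query_proj(torch.tensor(X, dtype=torch.float32))
print(np.allclose(Q_numpy, Q_torch.detach().numpy()))  # True
```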


But this simple manual example lets us observe a few things:

  • We compute a new feature representation for each token using weights
  • These weights will be learned (not shown here)
  • We set how many new features we want by the number of columns in \( W \)
  • The tokens themselves are NOT mixed and know nothing about each other
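To see that last point concretely, here is a small check using the toy matrices from above: perturbing one token's row of \( X \) changes only the corresponding row of \( Q \).

```python
# Perturb token 2 and recompute: only row 2 of Q changes,
# because each row of Q depends only on the same row of X.
X_perturbed = X.copy()
X_perturbed[1] += 10.0

print(np.isclose(X @ W_Q, X_perturbed @ W_Q).all(axis=1))  # [ True False  True]
```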

So the next question you might ask is, what if I want to have tokens interact with each other? This is where attention comes in.

We will discuss this in Part 2, coming soon!