Supplement · Neural Network Architectures

Attention Mechanisms

By the end of this reading you will be able to:
  • Explain the query-key-value abstraction: what each vector represents and how the dot product of a query with all keys produces an attention distribution over values
  • Compute scaled dot-product attention Softmax(QK⊤/√d_k)V for small matrices and explain why the √d_k scaling factor is necessary
  • Explain multi-head attention: why projecting into h separate subspaces and concatenating the results gives richer representations than a single large attention operation
  • Distinguish self-attention, masked self-attention, and cross-attention by their source of queries, keys, and values, and identify where each appears in transformer encoder and decoder blocks

The Bottleneck Problem in Seq2Seq

The original seq2seq architecture (Sutskever et al., 2014) compressed an entire input sentence into a single fixed-size context vector — the final hidden state of the encoder. For long sequences, this vector becomes a bottleneck: it must encode all relevant information, but its capacity is fixed.

Bahdanau et al. (2015) proposed a solution: at each decoder step, instead of always reading from the same context vector, let the decoder attend to different encoder states depending on what it is currently generating. This was attention.


The Query-Key-Value Framework

Attention generalizes the retrieval metaphor: you have a query (what you're looking for), a set of keys (labels on stored information), and values (the stored information itself). The attention operation retrieves a weighted combination of values, where the weights reflect how well each key matches the query.

Formally, given:

  • Query $\mathbf{q} \in \mathbb{R}^{d_k}$
  • Keys $K \in \mathbb{R}^{n \times d_k}$, one row per position
  • Values $V \in \mathbb{R}^{n \times d_v}$, one row per position

$$\text{Attention}(\mathbf{q}, K, V) = \text{Softmax}\!\left(\frac{\mathbf{q} K^\top}{\sqrt{d_k}}\right) V$$

The dot products $\mathbf{q} K^\top \in \mathbb{R}^n$ score how well the query matches each key. Softmax converts these to a probability distribution (attention weights). The output is a weighted sum of values.
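The single-query formula can be traced by hand on a toy example. The following is a minimal sketch in plain Python; the helper names `softmax` and `attend` and the numbers in `q`, `K`, and `V` are illustrative assumptions, not part of the reading.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, K, V):
    """Single-query scaled dot-product attention.

    q: list of d_k floats; K: n rows of d_k floats; V: n rows of d_v floats.
    Returns (weights, output).
    """
    d_k = len(q)
    # One score per key: dot(q, k) / sqrt(d_k)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
    weights = softmax(scores)
    # Output is the weighted sum of the value rows
    d_v = len(V[0])
    output = [sum(w * v[j] for w, v in zip(weights, V)) for j in range(d_v)]
    return weights, output

# Toy data (assumed values): the query points in the direction of the first key
q = [1.0, 0.0]
K = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
V = [[10.0, 0.0], [0.0, 10.0], [-10.0, 0.0]]

weights, output = attend(q, K, V)
print(weights)  # largest weight on the first key, smallest on the third
print(output)   # pulled toward the first value row
```

The key that best matches the query gets the largest weight, so the output is dominated by that key's value while still blending in the others.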


Scaled Dot-Product Attention

For a batch of queries packed into a matrix $Q \in \mathbb{R}^{m \times d_k}$:

$$\text{Attention}(Q, K, V) = \text{Softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Why divide by $\sqrt{d_k}$? The dot product is $\mathbf{q} \cdot \mathbf{k} = \sum_{i=1}^{d_k} q_i k_i$. If the components $q_i$ and $k_i$ are independent with mean 0 and variance 1, each product $q_i k_i$ also has mean 0 and variance 1, so the sum has variance $d_k$. Without scaling, the typical magnitude of the dot products grows like $\sqrt{d_k}$, pushing softmax into saturated regions with very small gradients (the "sharp softmax" problem: one position gets weight ≈ 1, all others ≈ 0, and gradients vanish). Dividing by $\sqrt{d_k}$ restores unit variance.
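The variance argument can be checked empirically. The sketch below, which assumes standard-normal components and uses a hypothetical helper `dot_variance`, estimates the variance of raw and scaled dot products for two values of $d_k$.

```python
import random

def dot_variance(d_k, trials=4000, seed=0):
    """Estimate Var(q . k) for q, k with i.i.d. N(0, 1) components.

    Returns (raw_variance, scaled_variance), where the scaled variance
    corresponds to dividing every dot product by sqrt(d_k).
    """
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0, 1) for _ in range(d_k)]
        k = [rng.gauss(0, 1) for _ in range(d_k)]
        dots.append(sum(a * b for a, b in zip(q, k)))
    mean = sum(dots) / trials
    var = sum((x - mean) ** 2 for x in dots) / trials
    # Var(x / sqrt(d_k)) = Var(x) / d_k
    return var, var / d_k

for d_k in (4, 64):
    raw, scaled = dot_variance(d_k)
    print(f"d_k={d_k}: raw variance ~ {raw:.1f}, scaled variance ~ {scaled:.2f}")
```

The raw variance should land near $d_k$ and the scaled variance near 1, up to sampling noise.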

Complexity: for a sequence of length $n$ attending to itself, $O(n^2 d_k)$ in time (all pairs of positions interact) and $O(n^2)$ in memory for the attention matrix. This quadratic scaling is what makes long sequences expensive for transformers: doubling $n$ quadruples both costs.
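To make the quadratic memory term concrete, here is a small back-of-the-envelope sketch. It assumes 4-byte (float32) entries and a fully materialized attention matrix for one head of one sequence; fused kernels such as FlashAttention avoid storing the full matrix, so treat the numbers as an upper bound for the naive implementation. The function name is hypothetical.

```python
def attention_matrix_bytes(n, bytes_per_element=4):
    """Bytes needed to store one n x n attention matrix (one head, one sequence)."""
    return n * n * bytes_per_element

for n in (1024, 2048, 4096):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"n={n}: {mib:.0f} MiB")
# n=1024: 4 MiB
# n=2048: 16 MiB
# n=4096: 64 MiB
```

Multiply by the number of heads and the batch size to get a rough per-layer figure.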


Multi-Head Attention

A single attention operation computes one weighted combination of values. But different aspects of a relationship (syntactic vs. semantic, local vs. global) may need to be captured simultaneously.

Multi-head attention projects $Q$, $K$, $V$ into $h$ different lower-dimensional subspaces, computes attention in each, then concatenates and projects:

$$\text{head}_i = \text{Attention}(Q W_i^Q, K W_i^K, V W_i^V)$$

$$\text{MultiHead}(Q, K, V) = [\text{head}_1; \ldots; \text{head}_h]\, W^O$$

where $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are learned projections, and $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ is the output projection.

Typically $d_k = d_v = d_{\text{model}} / h$: the total computation is similar to single-head attention with the same dimension, but distributed across $h$ parallel subspaces.
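The two multi-head equations can be followed step by step in a small pure-Python sketch. Everything here is an illustrative assumption: the helper names, the random (untrained) weight matrices, the toy sizes, and the choice of self-attention ($Q = K = V = X$) with $d_v = d_k$.

```python
import math
import random

def matmul(A, B):
    """Multiply A (r x k) by B (k x c); matrices are lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Scaled dot-product attention: Softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(Q[0])
    K_T = [list(col) for col in zip(*K)]
    scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, K_T)]
    weights = [softmax(row) for row in scores]
    return matmul(weights, V)

def rand_matrix(rows, cols, rng):
    """Random stand-in for a learned projection (assumed initialization)."""
    return [[rng.gauss(0, 1) / math.sqrt(rows) for _ in range(cols)] for _ in range(rows)]

def multi_head(X, heads, W_O):
    """heads: list of (W_Q, W_K, W_V) triples, each d_model x d_k."""
    outputs = [attention(matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V))
               for W_Q, W_K, W_V in heads]
    # Concatenate head outputs along the feature axis: n x (h * d_v)
    concat = [sum((head[i] for head in outputs), []) for i in range(len(X))]
    return matmul(concat, W_O)          # project back to d_model

rng = random.Random(0)
n, d_model, h = 3, 8, 2
d_k = d_model // h                      # 4 dimensions per head
X = rand_matrix(n, d_model, rng)        # toy input sequence
heads = [tuple(rand_matrix(d_model, d_k, rng) for _ in range(3)) for _ in range(h)]
W_O = rand_matrix(h * d_k, d_model, rng)

out = multi_head(X, heads, W_O)
print(len(out), len(out[0]))            # 3 8: one d_model-sized row per position
```

With $d_k = d_{\text{model}} / h$, the $h$ query projections together hold $h \cdot d_{\text{model}} \cdot d_k = d_{\text{model}}^2$ parameters, the same as one full-width projection, which is the sense in which the total computation matches single-head attention.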

What each head learns: Analysis (Voita et al., 2019) shows different heads specialize — some track syntactic dependencies, some track coreference, some attend locally. The multi-head structure allows the model to jointly attend to information from different representation subspaces.


Self-Attention, Masked Self-Attention, and Cross-Attention

The attention mechanism is general: the source of $Q$, $K$, and $V$ determines the variant.

Self-Attention

$Q$, $K$, $V$ all come from the same sequence. Each position attends to every position in that sequence, including itself.

  • Used in transformer encoders (e.g., BERT)
  • Every input position can attend to every other input position — bidirectional
  • Captures global dependencies in a single operation, regardless of sequence length

Masked Self-Attention (Causal Attention)

Same as self-attention, but future positions are masked: position $i$ can only attend to positions $1, \ldots, i$. The mask is added to the scores before softmax:

$$\text{mask}_{ij} = \begin{cases} 0 & j \leq i \\ -\infty & j > i \end{cases}$$

Adding $-\infty$ before softmax makes the corresponding attention weight exactly 0.

  • Used in transformer decoders for autoregressive generation (GPT, LLaMA)
  • Enforces that the model cannot "cheat" by looking at future tokens
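The effect of the causal mask can be seen on a tiny score matrix. This is a minimal sketch with a hypothetical helper `causal_attention_weights`; it uses all-zero scores so that each row's weights come out uniform over the allowed positions. Indices in the code are 0-based, whereas the text counts positions from 1.

```python
import math

def causal_attention_weights(scores):
    """Apply the causal mask to an n x n score matrix, then softmax each row."""
    n = len(scores)
    weights = []
    for i in range(n):
        # Keep scores for j <= i; replace future positions (j > i) with -inf
        masked = [scores[i][j] if j <= i else float("-inf") for j in range(n)]
        m = max(masked)
        exps = [math.exp(s - m) for s in masked]   # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

scores = [[0.0] * 4 for _ in range(4)]   # assumed toy scores: every pair ties
W = causal_attention_weights(scores)
for row in W:
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0, 0.0]
# [0.5, 0.5, 0.0, 0.0]
# [0.333, 0.333, 0.333, 0.0]
# [0.25, 0.25, 0.25, 0.25]
```

Every entry above the diagonal is exactly zero, so no position receives information from a later one.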

Cross-Attention

$Q$ comes from one sequence (the decoder); $K$ and $V$ come from a different sequence (the encoder output).

  • Used in seq2seq transformers (T5, original transformer for translation)
  • Allows the decoder to selectively read from the encoder's representations
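  • Plays the same role as the original Bahdanau attention (decoder queries over encoder states), although Bahdanau et al. scored each match with a small additive network rather than a scaled dot product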
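Because queries and keys come from different sequences, the attention matrix in cross-attention is generally rectangular. The sketch below illustrates this with assumed, already-projected toy vectors: 2 decoder positions querying 3 encoder positions. The names `dec_Q`, `enc_K`, and `enc_V` are hypothetical.

```python
import math

def attention(Q, K, V):
    """Scaled dot-product attention on lists of rows; returns (weights, output)."""
    d_k = len(Q[0])
    weights = []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    d_v = len(V[0])
    output = [[sum(w * v[j] for w, v in zip(row, V)) for j in range(d_v)]
              for row in weights]
    return weights, output

dec_Q = [[1.0, 0.0], [0.0, 1.0]]                    # queries: 2 decoder positions
enc_K = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]        # keys: 3 encoder positions
enc_V = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]      # values: 3 encoder positions

weights, out = attention(dec_Q, enc_K, enc_V)
print(len(weights), len(weights[0]))   # 2 3: one row per query, one column per key
print(len(out), len(out[0]))           # 2 2: one output row per decoder position
```

The output has as many rows as there are decoder positions, and each row is a mixture of encoder values: the first decoder position reads mostly from encoder position 0, the second from encoder position 1.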

Attention vs. Recurrence

| Property | RNN | Self-Attention |
|---|---|---|
| Max path length (signal between positions $i$, $j$) | $O(\lvert i-j \rvert)$ | $O(1)$ |
| Parallelizable | No (sequential) | Yes (all positions at once) |
| Complexity per layer | $O(n \cdot d^2)$ | $O(n^2 \cdot d)$ |
| Long-range dependencies | Hard (LSTM mitigates) | Easy (direct connection) |

Self-attention connects any two positions with a path of length 1: there are no intermediate steps for the signal to travel through. This direct connection is a main reason transformers handle long-range dependencies in language better than LSTMs, at the cost of per-layer compute that grows quadratically in $n$.
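The per-layer complexity comparison can be explored numerically. The sketch below simply evaluates the two asymptotic expressions $n \cdot d^2$ and $n^2 \cdot d$ with constant factors dropped, so the results indicate growth rates rather than real operation counts. The function names and the choice $d = 512$ are assumptions for illustration.

```python
def rnn_cost(n, d):
    """Leading term of per-layer RNN cost: n steps, each a d x d matrix product."""
    return n * d * d

def self_attention_cost(n, d):
    """Leading term of per-layer self-attention cost: n x n scores, each over d dims."""
    return n * n * d

d = 512
for n in (128, 512, 4096):
    ratio = self_attention_cost(n, d) / rnn_cost(n, d)
    print(f"n={n}: self-attention / RNN cost ratio = {ratio}")
# n=128: self-attention / RNN cost ratio = 0.25
# n=512: self-attention / RNN cost ratio = 1.0
# n=4096: self-attention / RNN cost ratio = 8.0
```

The ratio is simply $n / d$: self-attention is the cheaper of the two while the sequence is shorter than the model dimension and becomes the more expensive one beyond it.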


PyTorch and TensorFlow

PyTorch — scaled dot-product attention and nn.MultiheadAttention:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Scaled dot-product attention — from scratch
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (B, heads, T_q, d_k)
    K: (B, heads, T_k, d_k)
    V: (B, heads, T_k, d_v)
    """
    d_k    = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (B, heads, T_q, T_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)                  # attention weights
    return weights @ V                                   # (B, heads, T_q, d_v)

# PyTorch 2.0+ built-in (fused, memory-efficient)
Q = torch.randn(2, 8, 10, 64)   # (B, heads, T, d_k)
K = torch.randn(2, 8, 10, 64)
V = torch.randn(2, 8, 10, 64)
out = F.scaled_dot_product_attention(Q, K, V)            # uses FlashAttention when available

# nn.MultiheadAttention — handles projection, splitting into heads, and output projection
d_model = 512
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=8, batch_first=True)

x        = torch.randn(2, 10, d_model)   # (B, T, d_model)
# Self-attention: query=key=value=x
out, attn_weights = mha(x, x, x)         # out: (2,10,512)  weights: (2,10,10)

# Cross-attention: query from decoder, key/value from encoder
enc_out  = torch.randn(2, 20, d_model)   # encoder output
dec_query = torch.randn(2, 5, d_model)   # decoder query
out, _   = mha(dec_query, enc_out, enc_out)

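# Causal mask for autoregressive (decoder-only) models
T    = 10
mask = torch.tril(torch.ones(T, T)).bool()   # lower-triangular: True = may attend
# nn.MultiheadAttention reads True in a boolean attn_mask as "blocked",
# so invert the mask to block future positions
out, _ = mha(x, x, x, attn_mask=~mask)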

TensorFlow / Keras:

import tensorflow as tf

# MultiHeadAttention layer
mha = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64)  # key_dim = d_k per head

x   = tf.random.normal((2, 10, 512))   # (B, T, d_model)
out = mha(query=x, key=x, value=x)    # self-attention; out: (2, 10, 512)

# Cross-attention
enc = tf.random.normal((2, 20, 512))
out = mha(query=x, key=enc, value=enc)

# Causal mask
out = mha(query=x, key=x, value=x, use_causal_mask=True)  # masks future tokens automatically