Attention Mechanisms
- Explain the query-key-value abstraction: what each vector represents and how the dot product of a query with all keys produces an attention distribution over values
- Compute scaled dot-product attention Softmax(QK⊤/√d_k)V for small matrices and explain why the √d_k scaling factor is necessary
- Explain multi-head attention: why projecting into h separate subspaces and concatenating the results gives richer representations than a single large attention operation
- Distinguish self-attention, masked self-attention, and cross-attention by their source of queries, keys, and values, and identify where each appears in transformer encoder and decoder blocks
The Bottleneck Problem in Seq2Seq
The original seq2seq architecture (Sutskever et al., 2014) compressed an entire input sentence into a single fixed-size context vector — the final hidden state of the encoder. For long sequences, this vector becomes a bottleneck: it must encode all relevant information, but its capacity is fixed.
Bahdanau et al. (2015) proposed a solution: at each decoder step, instead of always reading from the same context vector, let the decoder attend to different encoder states depending on what it is currently generating. This was attention.
The Query-Key-Value Framework
Attention can be viewed as a soft, differentiable version of retrieval: you have a query (what you're looking for), a set of keys (labels on stored information), and values (the stored information itself). Rather than returning only the single best match, the attention operation retrieves a weighted combination of values, where the weights reflect how well each key matches the query.
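Formally, given:
- a query $q \in \mathbb{R}^{d_k}$
- keys $K \in \mathbb{R}^{n \times d_k}$, one row $k_i$ per position
- values $V \in \mathbb{R}^{n \times d_v}$, one row $v_i$ per position

the attention output is

$$\alpha = \mathrm{softmax}\left(\frac{Kq}{\sqrt{d_k}}\right) \in \mathbb{R}^{n}, \qquad \mathrm{Attention}(q, K, V) = \sum_{i=1}^{n} \alpha_i v_i = \alpha^\top V$$

The dot products $q \cdot k_i$ score how well the query matches each key. Softmax converts these scores into a probability distribution (the attention weights $\alpha_i$). The output is the corresponding weighted sum of values.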
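A minimal NumPy sketch of the single-query case, with each dot product divided by $\sqrt{d_k}$ (the scaling discussed in the next section). The helper names `softmax` and `attend` and the toy keys and values are illustrative assumptions:

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)  # shifting by the max leaves the result unchanged
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attend(q: np.ndarray, K: np.ndarray, V: np.ndarray):
    """Single-query attention. q: (d_k,), K: (n, d_k), V: (n, d_v)."""
    d_k = q.shape[-1]
    scores = K @ q / np.sqrt(d_k)  # score for position i: q . k_i / sqrt(d_k)
    alpha = softmax(scores)        # attention distribution over the n positions
    return alpha, alpha @ V        # output: sum_i alpha_i * v_i

# Toy data: n = 3 positions, d_k = 2, d_v = 3
K = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])
V = np.array([[10.0, 0.0, 0.0],
              [0.0, 10.0, 0.0],
              [0.0, 0.0, 10.0]])
q = np.array([1.0, 0.0])

alpha, out = attend(q, K, V)
print(alpha.round(3))
print(out.round(3))
```

Keys 1 and 3 both have dot product 1 with the query, so they receive equal weight (about 0.401 each), while key 2 (dot product 0) gets about 0.198. Because `V` is 10 times the identity, the output is simply 10 times the weight vector.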
Scaled Dot-Product Attention
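For a batch of $m$ queries packed into a matrix $Q \in \mathbb{R}^{m \times d_k}$ (one query per row):

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where the softmax is applied to each row of the $m \times n$ score matrix.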
Why divide by $\sqrt{d_k}$? The dot product is $q \cdot k = \sum_{j=1}^{d_k} q_j k_j$. If the components $q_j$ and $k_j$ are independent with mean 0 and variance 1, each product $q_j k_j$ has mean 0 and variance 1, so the sum has variance $d_k$. Without scaling, the dot products grow with $d_k$, pushing softmax into regions with very small gradients (the "sharp softmax" problem: one position gets weight ≈ 1, all others ≈ 0, and gradients vanish). Dividing by $\sqrt{d_k}$ restores unit variance.
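The variance argument can be checked empirically. This sketch (sample size, seed, and dimensions are arbitrary choices) draws i.i.d. standard-normal queries and keys, compares the variance of raw and scaled dot products, and then compares how peaked the softmax is with and without scaling:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

results = {}
for d_k in (4, 64, 1024):
    q = rng.standard_normal((2000, d_k))   # 2,000 queries with N(0, 1) components
    k = rng.standard_normal((2000, d_k))   # 2,000 independent keys
    dots = (q * k).sum(axis=1)             # raw dot products q.k
    scaled = dots / np.sqrt(d_k)           # divided by sqrt(d_k)
    results[d_k] = (dots.var(), scaled.var())
    print(f"d_k={d_k:5d}  var(q.k)={dots.var():8.1f}  var(q.k/sqrt(d_k))={scaled.var():.2f}")

# Sharpness: one query against 10 keys with d_k = 1024
d_k, n = 1024, 10
q = rng.standard_normal(d_k)
K = rng.standard_normal((n, d_k))
w_raw = softmax(K @ q)                     # unscaled scores have std around sqrt(d_k) = 32
w_scaled = softmax(K @ q / np.sqrt(d_k))   # scaled scores have std around 1
print(f"max weight: unscaled={w_raw.max():.3f}  scaled={w_scaled.max():.3f}")
```

The raw variance tracks $d_k$ while the scaled variance stays near 1. For any fixed scores, multiplying them by a factor greater than 1 can only increase the largest softmax weight, which is why unscaled high-dimensional scores tend toward a near one-hot distribution.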
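Complexity: for $n$ positions with head dimension $d_k$, computing all scores takes $O(n^2 d_k)$ time (all pairs of positions interact), and storing the attention matrix takes $O(n^2)$ memory. This quadratic scaling is what limits standard transformers to moderate sequence lengths.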
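To make the quadratic memory term concrete, here is a back-of-the-envelope sketch. It assumes the full $n \times n$ weight matrix is materialized in 32-bit floats, as in a naive implementation; fused kernels such as FlashAttention avoid storing this matrix, although time remains quadratic. The helper `attention_matrix_bytes` is hypothetical:

```python
def attention_matrix_bytes(n: int, heads: int = 1, batch: int = 1, bytes_per_elem: int = 4) -> int:
    """Bytes needed to store the n x n attention weights for every head and batch element."""
    return batch * heads * n * n * bytes_per_elem

for n in (512, 2048, 8192, 32768):
    mib = attention_matrix_bytes(n) / 2**20
    print(f"n={n:6d}: {mib:8.0f} MiB per head")   # 1, 16, 256, 4096 MiB
```

Each 4× increase in $n$ multiplies memory by 16; under these assumptions, a 12-head layer at $n = 8192$ would need 3 GiB for the weights alone.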
Multi-Head Attention
A single attention operation computes one weighted combination of values. But different aspects of a relationship (syntactic vs. semantic, local vs. global) may need to be captured simultaneously.
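Multi-head attention projects $Q$, $K$, $V$ into $h$ different lower-dimensional subspaces, computes attention in each, then concatenates and projects:

$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

where $W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$ are learned projections, and $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$ is the output projection.

Typically $d_k = d_v = d_{\text{model}}/h$, so the total computation is similar to single-head attention at the full dimension $d_{\text{model}}$, but distributed across $h$ parallel subspaces.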
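A NumPy sketch of multi-head attention under the convention $d_k = d_v = d_{\text{model}}/h$. As an implementation assumption, the $h$ per-head matrices $W_i^Q$ (and likewise $W_i^K$, $W_i^V$) are stored side by side as column blocks of one $d_{\text{model}} \times d_{\text{model}}$ matrix, so one matrix product followed by a reshape yields all heads. Random weights stand in for learned ones, and `multi_head_attention` is a hypothetical helper:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X_q, X_kv, W_q, W_k, W_v, W_o, h):
    """
    X_q: (T_q, d_model) source of queries; X_kv: (T_k, d_model) source of keys and values.
    W_q, W_k, W_v: (d_model, d_model), column block i holds head i's projection.
    W_o: (d_model, d_model) output projection. Assumes d_k = d_v = d_model // h.
    """
    d_model = X_q.shape[-1]
    d_k = d_model // h

    def split_heads(M):                       # (T, d_model) -> (h, T, d_k)
        return M.reshape(M.shape[0], h, d_k).transpose(1, 0, 2)

    Q = split_heads(X_q @ W_q)                # head i uses columns i*d_k:(i+1)*d_k
    K = split_heads(X_kv @ W_k)
    V = split_heads(X_kv @ W_v)

    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, T_q, T_k)
    weights = softmax(scores)                          # one attention distribution per head
    heads = weights @ V                                # (h, T_q, d_k)

    concat = heads.transpose(1, 0, 2).reshape(X_q.shape[0], d_model)  # Concat(head_1..head_h)
    return concat @ W_o, weights

rng = np.random.default_rng(0)
d_model, h, T = 16, 4, 5
X = rng.standard_normal((T, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

out, weights = multi_head_attention(X, X, W_q, W_k, W_v, W_o, h)
print(out.shape, weights.shape)   # (5, 16) (4, 5, 5)
```

Each head produces its own $T \times T$ weight matrix, which is what lets different heads attend to different positions for the same query.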
What each head learns: Voita et al. (2019) analyzed trained translation models and found that different heads specialize, with some attending to adjacent positions, some tracking syntactic dependencies, and some attending to rare words. The multi-head structure allows the model to jointly attend to information from different representation subspaces.
Self-Attention, Masked Self-Attention, and Cross-Attention
The attention mechanism is general, and the source of $Q$, $K$, and $V$ determines the variant:
Self-Attention
$Q$, $K$, and $V$ all come from the same sequence: for input $X$, $Q = XW^Q$, $K = XW^K$, $V = XW^V$. Each position attends to every position in the sequence, including itself.
- Used in transformer encoders (e.g., BERT)
- Every input position can attend to every other input position — bidirectional
- Captures global dependencies in a single operation, regardless of sequence length
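The bidirectional property can be seen directly. In this NumPy sketch, random matrices stand in for learned projections and `self_attention` is a hypothetical helper; $Q$, $K$, and $V$ are all projected from the same sequence `X`, and perturbing only the last token changes the output at the first position:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head self-attention: Q, K, V all come from the same sequence X of shape (T, d)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]))  # (T, T): every position scores every position
    return weights @ V, weights

rng = np.random.default_rng(1)
T, d = 6, 8
X = rng.standard_normal((T, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

out, weights = self_attention(X, W_q, W_k, W_v)
X2 = X.copy()
X2[-1] += 1.0                                  # perturb only the LAST token
out2, _ = self_attention(X2, W_q, W_k, W_v)

print(weights.shape)                           # (6, 6)
print(np.abs(out2[0] - out[0]).max())          # nonzero: position 0 "sees" the last token
```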
Masked Self-Attention (Causal Attention)
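Same as self-attention, but future positions are masked so that position $i$ can only attend to positions $j \le i$:

$$\mathrm{MaskedAttention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V, \qquad M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$

Adding $-\infty$ before the softmax makes the corresponding attention weight exactly 0, since $e^{-\infty} = 0$.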
- Used in transformer decoders for autoregressive generation (GPT, LLaMA)
- Enforces that the model cannot "cheat" by looking at future tokens
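A NumPy sketch of the same single-head computation with an additive causal mask (0 on and below the diagonal, $-\infty$ above it). Random matrices stand in for learned projections, and `causal_self_attention` is a hypothetical helper. Perturbing the last token now leaves every earlier output unchanged:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis (rows containing -inf are fine)."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def causal_self_attention(X, W_q, W_k, W_v):
    """Masked self-attention: position i may only attend to positions j <= i."""
    T = X.shape[0]
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    M = np.triu(np.full((T, T), -np.inf), k=1)   # 0 on/below the diagonal, -inf above it
    weights = softmax(scores + M)                # exp(-inf) = 0, so future weights are exactly 0
    return weights @ V, weights

rng = np.random.default_rng(2)
T, d = 6, 8
X = rng.standard_normal((T, d))
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

out, weights = causal_self_attention(X, W_q, W_k, W_v)
X2 = X.copy()
X2[-1] += 1.0                                    # perturb only the LAST token
out2, _ = causal_self_attention(X2, W_q, W_k, W_v)

print(np.triu(weights, k=1).max())               # 0.0: no weight on future positions
print(np.abs(out2[:-1] - out[:-1]).max())        # earlier positions are unaffected
```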
Cross-Attention
$Q$ comes from one sequence (the decoder); $K$ and $V$ come from a different sequence (the encoder output).
- Used in seq2seq transformers (T5, original transformer for translation)
- Allows the decoder to selectively read from the encoder's representations
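- Plays the same role as the original Bahdanau attention (decoder states querying encoder states), although Bahdanau et al. scored positions with a small feed-forward network (additive attention) rather than a scaled dot product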
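A NumPy sketch of cross-attention: queries come from `dec` (standing in for decoder states) and keys and values from `enc` (standing in for encoder output), so the weight matrix is $T_{\text{dec}} \times T_{\text{enc}}$ and each row is a distribution over encoder positions. For contrast, it also scores the same states with an additive function $e_{ij} = v_a^\top \tanh(W_a s_i + U_a h_j)$ of the kind Bahdanau et al. used; for brevity the sketch feeds the current decoder state $s_i$, whereas their model used the previous decoder state. All weights are random placeholders:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(3)
d, T_dec, T_enc = 8, 3, 7
dec = rng.standard_normal((T_dec, d))   # decoder states: source of queries
enc = rng.standard_normal((T_enc, d))   # encoder output: source of keys and values
W_q, W_k, W_v = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

# Scaled dot-product cross-attention
Q, K, V = dec @ W_q, enc @ W_k, enc @ W_v
w_dot = softmax(Q @ K.T / np.sqrt(d))   # (T_dec, T_enc): row i is a distribution over encoder positions
ctx_dot = w_dot @ V                     # (T_dec, d): one context vector per decoder position

# Additive scoring e_ij = v_a^T tanh(W_a s_i + U_a h_j), with s_i = dec[i] and h_j = enc[j]
W_a = rng.standard_normal((d, d)) / np.sqrt(d)
U_a = rng.standard_normal((d, d)) / np.sqrt(d)
v_a = rng.standard_normal(d)
hidden = np.tanh((dec @ W_a.T)[:, None, :] + (enc @ U_a.T)[None, :, :])  # (T_dec, T_enc, d)
w_add = softmax(hidden @ v_a)           # (T_dec, T_enc)
ctx_add = w_add @ enc                   # weighted sum of the encoder states themselves

print(w_dot.shape, ctx_dot.shape, w_add.shape, ctx_add.shape)  # (3, 7) (3, 8) (3, 7) (3, 8)
```

Only the scoring function differs; both variants turn scores into a distribution over encoder positions and read a weighted sum.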
Attention vs. Recurrence
| Property | RNN | Self-Attention |
|---|---|---|
| Max path length (signal between any two positions) | $O(n)$ | $O(1)$ |
| Parallelizable | No (sequential over positions) | Yes (all positions at once) |
| Complexity per layer | $O(n \cdot d^2)$ | $O(n^2 \cdot d)$ |
| Long-range dependencies | Hard (LSTM mitigates) | Easy (direct connection) |
Here $n$ is the sequence length and $d$ the representation dimension. Self-attention connects any two positions with a path of length 1: there are no intermediate steps for the signal to travel through. This direct connection is a key reason transformers capture long-range dependencies in language much better than LSTMs.
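The path-length difference can be illustrated (not proven) with a small NumPy experiment: perturb only the first token and measure how much the output at the last position changes. The assumptions are loud ones: a vanilla RNN whose recurrent matrix is contractive (spectral norm exactly 0.5), and self-attention with $Q = K = V = X$ and no projections. A non-contractive RNN or an LSTM would behave differently, and the helper names are hypothetical:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def rnn_last_state(X, W_h, W_x):
    """Vanilla RNN h_t = tanh(h_{t-1} W_h + x_t W_x); returns the final state."""
    h = np.zeros(W_h.shape[0])
    for x_t in X:                  # T dependent steps: step t needs h_{t-1}
        h = np.tanh(h @ W_h + x_t @ W_x)
    return h

def attention_last_output(X):
    """Self-attention output at the last position (Q = K = V = X, no projections)."""
    scores = X[-1] @ X.T / np.sqrt(X.shape[1])   # the last position scores every position directly
    return softmax(scores) @ X

rng = np.random.default_rng(0)
T, d = 50, 16
X = rng.standard_normal((T, d))
orth, _ = np.linalg.qr(rng.standard_normal((d, d)))
W_h = 0.5 * orth                   # assumption: contractive recurrence, spectral norm exactly 0.5
W_x = rng.standard_normal((d, d)) / np.sqrt(d)

X_pert = X.copy()
X_pert[0] += 1.0                   # perturb only the FIRST token

rnn_change = np.abs(rnn_last_state(X_pert, W_h, W_x) - rnn_last_state(X, W_h, W_x)).max()
attn_change = np.abs(attention_last_output(X_pert) - attention_last_output(X)).max()
print(f"RNN: {rnn_change:.1e}   attention: {attn_change:.1e}")
```

Because $\tanh$ is 1-Lipschitz, the RNN's response to the first token shrinks by at least a factor of 2 per step, so after 49 steps it is at floating-point noise level, while the attention output reads token 0 through a single weighted sum.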
PyTorch and TensorFlow
PyTorch — scaled dot-product attention and nn.MultiheadAttention:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Scaled dot-product attention — from scratch
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (B, heads, T_q, d_k)
    K: (B, heads, T_k, d_k)
    V: (B, heads, T_k, d_v)
    mask: optional, broadcastable to (B, heads, T_q, T_k); True/1 = may attend, False/0 = blocked
    """
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (B, heads, T_q, T_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))  # blocked positions get weight 0 after softmax
    weights = F.softmax(scores, dim=-1)  # attention weights; each row sums to 1
    return weights @ V  # (B, heads, T_q, d_v)
# PyTorch 2.0+ built-in (fused, memory-efficient)
Q = torch.randn(2, 8, 10, 64) # (B, heads, T, d_k)
K = torch.randn(2, 8, 10, 64)
V = torch.randn(2, 8, 10, 64)
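out = F.scaled_dot_product_attention(Q, K, V)  # may dispatch to a fused kernel (e.g. FlashAttention) on supported hardware
assert torch.allclose(out, scaled_dot_product_attention(Q, K, V), atol=1e-5)  # agrees with the from-scratch version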
# nn.MultiheadAttention — handles projection, splitting into heads, and output projection
d_model = 512
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=8, batch_first=True)
x = torch.randn(2, 10, d_model) # (B, T, d_model)
# Self-attention: query=key=value=x
out, attn_weights = mha(x, x, x)  # out: (2, 10, 512); weights averaged over heads by default: (2, 10, 10)
# Cross-attention: query from decoder, key/value from encoder
enc_out = torch.randn(2, 20, d_model) # encoder output
dec_query = torch.randn(2, 5, d_model) # decoder query
out, _ = mha(dec_query, enc_out, enc_out)
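# Causal mask for autoregressive (decoder-only) models
T = 10
mask = torch.tril(torch.ones(T, T)).bool()  # lower-triangular: True where position i may attend to j <= i
# nn.MultiheadAttention's boolean attn_mask uses the opposite convention (True = blocked),
# so pass the complement to mask out future positions
out, _ = mha(x, x, x, attn_mask=~mask)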
TensorFlow / Keras:
import tensorflow as tf
# MultiHeadAttention layer
mha = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64) # key_dim = d_k per head
x = tf.random.normal((2, 10, 512)) # (B, T, d_model)
out = mha(query=x, key=x, value=x) # self-attention; out: (2, 10, 512)
# Cross-attention
enc = tf.random.normal((2, 20, 512))
out = mha(query=x, key=enc, value=enc)
# Causal mask
out = mha(query=x, key=x, value=x, use_causal_mask=True) # masks future tokens automatically