NLP & Advanced Activations
- Explain adaptive softmax cluster assignment and state why it reduces computation for large-vocabulary language modelling
- Trace the multi-head attention forward pass from Q, K, V inputs to the output projection and identify where softmax and scaling appear
- State the SwiGLU formula and explain why it has become the default FFN activation in modern transformer architectures
- Compare the FFN activation choices across GPT-2, LLaMA, and PaLM and relate each to a specific activation function family
LogSigmoid
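LogSigmoid outputs $\log \sigma(x) = -\log(1 + e^{-x})$, the log of a probability, so its values are never positive. It plays the same role for binary classification that nn.LogSoftmax plays for multi-class: nn.LogSoftmax feeds nn.NLLLoss directly, while the binary negative log-likelihood is built from LogSigmoid(x) for the positive class and LogSigmoid(-x) for the negative class. nn.NLLLoss itself expects one log-probability per class, so a single LogSigmoid output cannot be passed to it unchanged.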
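Numerically stable implementation: evaluating $\log(1 + e^{-x})$ directly overflows for large negative $x$, so PyTorch uses a form equivalent to:
$$\text{LogSigmoid}(x) = \begin{cases} -\log(1 + e^{-x}) & x \ge 0 \\ x - \log(1 + e^{x}) & x < 0 \end{cases}$$
The two-branch form avoids overflow: for large positive $x$, computing $e^{-x}$ is safe (it's tiny); for large negative $x$, computing $e^{x}$ is safe (also tiny).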
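The overflow argument can be checked with a few lines of plain Python. This is a minimal sketch, not PyTorch's actual kernel; the function names `log_sigmoid_naive` and `log_sigmoid_stable` are illustrative choices.

```python
import math

def log_sigmoid_naive(x):
    # Direct translation of log(1 / (1 + e^{-x})); exp(-x) overflows for very negative x.
    return math.log(1.0 / (1.0 + math.exp(-x)))

def log_sigmoid_stable(x):
    # Two-branch form: the argument of exp is never positive, so it cannot overflow.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

for x in (-1000.0, -3.0, 0.0, 3.0, 1000.0):
    try:
        naive = log_sigmoid_naive(x)
    except OverflowError:
        naive = "overflow"
    print(x, naive, log_sigmoid_stable(x))
```

For moderate inputs the two functions agree; at x = -1000 the naive version raises OverflowError while the stable version returns approximately x itself.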
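Relationship to BCEWithLogitsLoss: BCEWithLogitsLoss(logits, labels) ≡ $-\left[y \log\sigma(x) + (1 - y)\log\sigma(-x)\right]$, averaged over elements under the default reduction, where $x$ are the logits, $y$ are the labels, and $\log\sigma$ is LogSigmoid. The loss function and the activation are two sides of the same coin.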
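The relationship between the binary cross-entropy on logits and LogSigmoid can be verified numerically. The sketch below uses plain Python rather than PyTorch; `bce_with_logits` and `bce_from_probs` are hypothetical helper names, and the mean reduction is an assumption matching the usual default.

```python
import math

def log_sigmoid(x):
    # Stable log(sigmoid(x)); exp is only called on non-positive arguments.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def bce_with_logits(logits, labels):
    # Mean over elements of -[y * log sigma(x) + (1 - y) * log sigma(-x)]
    terms = [-(y * log_sigmoid(x) + (1.0 - y) * log_sigmoid(-x))
             for x, y in zip(logits, labels)]
    return sum(terms) / len(terms)

def bce_from_probs(logits, labels):
    # Reference: apply sigmoid first, then the textbook binary cross-entropy.
    total = 0.0
    for x, y in zip(logits, labels):
        p = 1.0 / (1.0 + math.exp(-x))
        total += -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))
    return total / len(logits)

logits = [-3.0, -1.0, 0.0, 1.0, 3.0]
labels = [0.0, 0.0, 1.0, 1.0, 1.0]
print(bce_with_logits(logits, labels), bce_from_probs(logits, labels))
```

Both values match up to floating-point error for moderate logits; only the LogSigmoid-based version remains safe when the logits are extreme.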
PyTorch:
import torch
import torch.nn as nn
x = torch.tensor([-3., -1., 0., 1., 3.])
print(nn.LogSigmoid()(x)) # tensor([-3.0486, -1.3133, -0.6931, -0.3133, -0.0486])
# Binary NLL from these values: -(y * logsigmoid(x) + (1 - y) * logsigmoid(-x)),
# which is what nn.BCEWithLogitsLoss computes
TensorFlow:
import tensorflow as tf
x = tf.constant([-3., -1., 0., 1., 3.])
print(tf.math.log_sigmoid(x)) # [-3.0486 -1.3133 -0.6931 -0.3133 -0.0486]
# Equivalent: -tf.math.softplus(-x)
AdaptiveLogSoftmaxWithLoss — Large Vocabulary NLP
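For standard Softmax, computing the partition function over a vocabulary of size $V$ (often tens of thousands of words or more) costs $O(V)$ per token at every forward pass. For a batch of 512 tokens, this output layer can dominate the compute.
Adaptive Softmax (Grave et al., 2017) brings the expected per-token cost well below $O(V)$ by arranging the vocabulary in a frequency-based hierarchy: most tokens are frequent words, and those are resolved by a small first-level softmax. Words are split into frequency clusters: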
- Head cluster: The most frequent words (e.g., top 2,000) are scored with a full softmax, because these are the words that appear in nearly every batch.
- Tail clusters: Rarer words are grouped into clusters with progressively smaller projection dimensions (controlled by div_value). Each tail cluster is represented by one extra token in the head vocabulary; if the head selects that cluster token, a second, smaller softmax is applied within the cluster.
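To see why this split saves work, the sketch below estimates how many output logits are evaluated per token on average. It is a rough illustration only: it assumes a Zipf-like word distribution (probability proportional to 1/rank), counts logits rather than FLOPs, and ignores the extra savings from the smaller tail projections. The function name `expected_logits_per_token` is hypothetical.

```python
def expected_logits_per_token(vocab_size, cutoffs):
    # Assumption: the word at frequency rank r has probability proportional to 1/r.
    weights = [1.0 / r for r in range(1, vocab_size + 1)]
    total = sum(weights)
    bounds = [0] + list(cutoffs) + [vocab_size]
    n_tail = len(bounds) - 2
    head_size = bounds[1] + n_tail   # head words plus one token per tail cluster
    expected = float(head_size)      # the head softmax is always evaluated
    for i in range(1, len(bounds) - 1):
        lo, hi = bounds[i], bounds[i + 1]
        p_cluster = sum(weights[lo:hi]) / total  # chance the target lies in this cluster
        expected += p_cluster * (hi - lo)        # its softmax is only needed for those targets
    return expected

full = 50000
adaptive = expected_logits_per_token(full, [2000, 10000])
print(f"full softmax: {full} logits/token, adaptive (expected): {adaptive:.0f}")
```

Under these assumptions the expected count is a small fraction of the full vocabulary size, because the large tail clusters are only touched for the minority of tokens that fall in them.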
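The probability is log-additive: for a word $w$ in tail cluster $c$, given hidden state $h$, $\log P(w \mid h) = \log P(c \mid h) + \log P(w \mid c, h)$. For a head word, $\log P(w \mid h)$ is read directly from the head softmax.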
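A toy two-level example makes the log-additive decomposition concrete. The scores below are made up for illustration, and the six-word vocabulary with a single tail cluster is an assumption chosen to keep the sketch short.

```python
import math

def log_softmax(scores):
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_z for s in scores]

# Toy vocabulary: words 0-2 live in the head, words 3-5 in a single tail cluster.
# The head softmax has 4 entries: 3 head words plus 1 token standing for the tail cluster.
head_scores = [2.0, 1.0, 0.5, 0.0]   # made-up scores
tail_scores = [0.3, -0.2, 1.1]       # made-up scores for the 3 tail words

head_lp = log_softmax(head_scores)
tail_lp = log_softmax(tail_scores)

# Head words: log-prob comes straight from the head softmax.
# Tail words: log P(cluster) + log P(word | cluster).
log_probs = head_lp[:3] + [head_lp[3] + t for t in tail_lp]
print(sum(math.exp(lp) for lp in log_probs))
```

The printed sum is 1 up to floating-point error, confirming that the two-level scheme still defines a proper distribution over all six words.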
PyTorch API:
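import torch
import torch.nn as nn
adaptive_sm = nn.AdaptiveLogSoftmaxWithLoss(
    in_features=512, # input embedding size
    n_classes=50000, # vocabulary size
    cutoffs=[2000, 10000], # cluster boundaries; word indices must be sorted by decreasing frequency
    div_value=4.0 # dimension reduction factor per cluster
)
embeddings = torch.randn(32, 512) # (batch, in_features) hidden states
targets = torch.randint(0, 50000, (32,)) # (batch,) target word indices
output, loss = adaptive_sm(embeddings, targets) # output: log-prob of each target; loss: mean NLL
log_probs = adaptive_sm.log_prob(embeddings) # inference: (batch, n_classes) log-probabilities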
TensorFlow:
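# No equivalent built-in; approximate with sampled softmax for large vocabularies
import tensorflow as tf
# Assumes embedding_matrix (vocab_size, embed_dim), bias (vocab_size,),
# embeddings (batch, embed_dim), and int64 targets of shape (batch, 1) already exist
per_example_loss = tf.nn.sampled_softmax_loss(
    weights=embedding_matrix, # (vocab_size, embed_dim)
    biases=bias,
    labels=targets, # (batch, num_true=1)
    inputs=embeddings,
    num_sampled=1000,
    num_classes=50000
)
loss = tf.reduce_mean(per_example_loss) # the op returns one loss per example
# For inference, use the full softmax:
# tf.nn.log_softmax(tf.matmul(embeddings, embedding_matrix, transpose_b=True) + bias)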
MultiheadAttention
Multi-head attention is the core of the Transformer. Rather than computing one attention pattern over the full $d_{\text{model}}$ dimension, it runs $h$ attention heads in parallel, each projecting queries, keys, and values into a lower-dimensional subspace ($d_k = d_{\text{model}} / h$). The heads can learn different relationship patterns (positional, syntactic, semantic), and their outputs are concatenated and projected back to $d_{\text{model}}$.
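For reference, the forward pass can be written out using the standard formulation from Vaswani et al. (2017). The symbols $W_i^Q$, $W_i^K$, $W_i^V$ (per-head projections), $W^O$ (output projection), $h$ (number of heads), and $d_k$ (per-head dimension) are notation introduced here for illustration.

```latex
\begin{aligned}
\text{Attention}(Q, K, V) &= \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V \\
\text{head}_i &= \text{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V), \quad i = 1, \dots, h \\
\text{MultiHead}(Q, K, V) &= \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O
\end{aligned}
```

Softmax and the $1/\sqrt{d_k}$ scaling both appear inside each head, applied to the score matrix before it is multiplied by $V$; the output projection $W^O$ is applied once, after the heads are concatenated.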
Key design decisions:
- $1/\sqrt{d_k}$ scaling: Without this, the dot product $QK^\top$ grows in magnitude with $d_k$, pushing Softmax into saturation where gradients vanish.
- Causal masking: For decoder/generation models, an attention mask sets future positions to $-\infty$ before Softmax, preventing the model from attending to future tokens.
- batch_first=True: Puts the batch dimension first; if False (the default in nn.MultiheadAttention), tensors are (seq, batch, dim), not (batch, seq, dim).
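The scaling and causal-masking steps can be sketched for a single head in plain Python. This is an illustrative toy, not how nn.MultiheadAttention is implemented; the function name `causal_attention` and the small Q, K, V matrices are made up for the example.

```python
import math

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]  # exp(-inf) evaluates to 0.0
    z = sum(exps)
    return [e / z for e in exps]

def causal_attention(Q, K, V):
    n, d_k, d_v = len(Q), len(Q[0]), len(V[0])
    weights = []
    for i in range(n):
        row = []
        for j in range(n):
            if j > i:
                row.append(-math.inf)  # future position: masked before softmax
            else:
                score = sum(q * k for q, k in zip(Q[i], K[j]))
                row.append(score / math.sqrt(d_k))  # scaled dot product
        weights.append(softmax(row))
    # Each output row is a weighted average of the value rows.
    output = [[sum(weights[i][j] * V[j][c] for j in range(n)) for c in range(d_v)]
              for i in range(n)]
    return output, weights

Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
output, weights = causal_attention(Q, K, V)
for row in weights:
    print([round(w, 3) for w in row])
```

Every weight above the diagonal comes out as exactly zero, so the first position can only attend to itself and its output equals the first value row.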
PyTorch API:
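import torch
import torch.nn as nn
attn = nn.MultiheadAttention(
    embed_dim=512,
    num_heads=8,
    dropout=0.1,
    batch_first=True
)
query = key = value = torch.randn(2, 10, 512) # (batch, seq_len, embed_dim), self-attention
# Boolean mask: True marks positions that may NOT be attended to (the future)
causal_mask = torch.triu(torch.ones(10, 10, dtype=torch.bool), diagonal=1)
output, weights = attn(query, key, value, attn_mask=causal_mask)
# output: (batch, seq_len, embed_dim); weights: (batch, seq_len, seq_len), averaged over heads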
Memory complexity: $O(n^2)$ in sequence length $n$. Quadratic attention is the scaling bottleneck that FlashAttention, sparse attention, and linear attention variants aim to address.
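A quick back-of-the-envelope calculation shows how fast the attention score matrix grows. The sketch assumes the scores are fully materialized, stored in 16-bit floats (2 bytes per element), with 8 heads and a batch of 1; `attention_matrix_bytes` is a hypothetical helper.

```python
def attention_matrix_bytes(seq_len, batch=1, num_heads=8, bytes_per_element=2):
    # One (seq_len x seq_len) score matrix per head and per batch element.
    return batch * num_heads * seq_len * seq_len * bytes_per_element

for n in (1024, 4096, 16384):
    gib = attention_matrix_bytes(n) / 2**30
    print(f"seq_len={n:>6}: {gib:.3f} GiB")
```

Each 4x increase in sequence length multiplies the memory by 16, which is why long-context models avoid materializing the full matrix.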
TensorFlow:
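import tensorflow as tf
attn = tf.keras.layers.MultiHeadAttention(num_heads=8, key_dim=64, dropout=0.1)
# query, value: (batch, seq_len, embed_dim); key optional (defaults to value)
# Keras mask convention is the opposite of PyTorch: True/1 marks positions that MAY be attended to
output = attn(query, value, key=key, attention_mask=causal_mask)
# Alternatively, pass use_causal_mask=True instead of building the mask by hand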
SwiGLU — The Modern FFN Standard
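While not a standalone PyTorch module, SwiGLU is the modern replacement for the transformer FFN's ReLU/GELU:
$$\text{FFN}_{\text{SwiGLU}}(x) = \left(\text{SiLU}(xW_1) \odot xW_3\right) W_2$$
It is GLU but with SiLU replacing Sigmoid as the gate. Used in PaLM, LLaMA, and Mistral (Gemma uses the closely related GeGLU). Because the block has three weight matrices instead of two, the hidden dimension is typically scaled to $\tfrac{2}{3} \cdot 4d_{\text{model}} = \tfrac{8}{3} d_{\text{model}}$ instead of $4d_{\text{model}}$ to keep the parameter count unchanged.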
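The parameter-count argument can be checked with simple arithmetic, assuming the common choice of a hidden size near (8/3)·d_model for SwiGLU versus 4·d_model for a standard FFN. Rounding the hidden size up to a multiple of 256 is an assumed convention seen in some open-source implementations, and all function names here are illustrative.

```python
def ffn_params(d_model, d_ff):
    # Standard FFN: up-projection and down-projection (biases ignored).
    return 2 * d_model * d_ff

def swiglu_params(d_model, d_ff):
    # SwiGLU FFN: gate, value, and output projections (biases ignored).
    return 3 * d_model * d_ff

def swiglu_hidden(d_model, multiple_of=256):
    # Assumed convention: take 2/3 of 4*d_model, then round up to a multiple of 256.
    hidden = int(2 * (4 * d_model) / 3)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

d_model = 4096
d_ff = swiglu_hidden(d_model)
standard = ffn_params(d_model, 4 * d_model)
gated = swiglu_params(d_model, d_ff)
print(d_ff, standard, gated, gated / standard)
```

With d_model = 4096 the two blocks end up within about one percent of each other in parameter count, the small difference coming from the rounding step.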
PyTorch (SwiGLU FFN block):
import torch.nn as nn
class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        # Three projections: gate, value, output
        self.w1 = nn.Linear(d_model, d_ff, bias=False) # gate
        self.w2 = nn.Linear(d_ff, d_model, bias=False) # output
        self.w3 = nn.Linear(d_model, d_ff, bias=False) # value
    def forward(self, x):
        # SiLU-activated gate multiplies the value path elementwise, then project back to d_model
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
TensorFlow (SwiGLU FFN block):
import tensorflow as tf
class SwiGLUFFN(tf.keras.layers.Layer):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = tf.keras.layers.Dense(d_ff, use_bias=False) # gate
        self.w2 = tf.keras.layers.Dense(d_model, use_bias=False) # output
        self.w3 = tf.keras.layers.Dense(d_ff, use_bias=False) # value
    def call(self, x):
        # tf.nn.swish is SiLU; the gate multiplies the value path elementwise
        return self.w2(tf.nn.swish(self.w1(x)) * self.w3(x))
Putting It Together: Transformer FFN Variants
| FFN Type | Activation | Used In |
|---|---|---|
| Original | ReLU | Vaswani et al. 2017 |
| BERT | GELU | BERT, GPT-2, RoBERTa |
| GLU-based | GeGLU | T5v1.1, Flan-T5 |
| SwiGLU | SiLU gate | LLaMA, Mistral, PaLM |
| MixGLU | Mish gate | Research |
The trend is clear: smooth, self-gating activations have replaced ReLU in large-scale NLP, with GLU-variant FFNs becoming the de facto standard.