Neural Network Architectures in PyTorch
Implement a vanilla RNN and an LSTM from scratch with nn.Module; compare their performance and gradient norms on a long-sequence task (sequence length ≥ 200).
Lab Overview
Hands-on implementations of the major neural network architectures covered in this supplement, from MLPs through transformers. Each section builds an architecture from scratch with nn.Module, validates it against the PyTorch built-in where one exists, and analyzes its key properties (gradient flow, parameter efficiency, attention patterns).
Sections
| Section | Topic | Key experiment |
|---|---|---|
| 1 | MLP depth and non-linearity | SVD rank collapse without activations |
| 2 | CNN and depthwise separable convolutions | Output-size formula; parameter counts |
| 3 | ResNet residual connections | Gradient magnitude vs. depth |
| 4 | Vanilla RNN vs. LSTM | Long-sequence accuracy and gradient norms |
| 5 | Attention mechanism | Softmax saturation without √d_k scaling |
| 6 | Mini transformer | Character LM; attention pattern visualization |
Section 1 — MLP: Depth, Width, and Non-linearity
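Build MLP(d_in, hidden_dims, d_out, activation), where hidden_dims is a list of hidden-layer widths. Write the layer widths as d_0 = d_in, d_1, ..., d_L = d_out. Then verify that the parameter count matches the theoretical formula Σ_{i=0}^{L−1} (d_i · d_{i+1} + d_{i+1}), which counts the weights plus biases of each nn.Linear: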
```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, d_in, hidden_dims, d_out, activation=nn.ReLU):
        super().__init__()
        dims = [d_in] + hidden_dims + [d_out]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:  # no activation on final layer
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
```
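A minimal check of the parameter formula against the MLP class. The class is repeated so the block runs on its own. The widths 784, [256, 128] and 10 are arbitrary example values.

```python
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, d_in, hidden_dims, d_out, activation=nn.ReLU):
        super().__init__()
        dims = [d_in] + hidden_dims + [d_out]
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i < len(dims) - 2:
                layers.append(activation())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def expected_params(d_in, hidden_dims, d_out):
    dims = [d_in] + hidden_dims + [d_out]
    # each Linear(a, b) holds a*b weights plus b biases
    return sum(a * b + b for a, b in zip(dims[:-1], dims[1:]))

model = MLP(784, [256, 128], 10)
actual = sum(p.numel() for p in model.parameters())
print(actual, expected_params(784, [256, 128], 10))  # 235146 235146
```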
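Collapse experiment: build a 4-layer MLP without activations (activation=nn.Identity), run a batch through it, and compute torch.linalg.matrix_rank of the output. Compare the result with a 4-layer MLP that uses ReLU.

Without activations, the stack of linear layers collapses to a single affine map xW + b. Its output therefore has rank at most d_in + 1 (the +1 comes from the bias), regardless of the hidden dimensions. Pick d_out > d_in + 1 and a batch larger than d_out, so the ReLU network has room to exceed that bound.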
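A sketch of the collapse experiment. It builds the same layer layout as MLP inline so the block is self-contained. The sizes d_in=8, three hidden layers of 256, d_out=64 and batch 512 are illustrative assumptions. matrix_rank uses its default SVD-based tolerance.

```python
import torch
import torch.nn as nn

def make_stack(d_in, hidden_dims, d_out, activation):
    # Same Linear/activation layout as MLP above
    dims = [d_in] + hidden_dims + [d_out]
    layers = []
    for i in range(len(dims) - 1):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:
            layers.append(activation())
    return nn.Sequential(*layers)

def output_rank(activation, d_in=8, hidden_dims=(256, 256, 256), d_out=64, batch=512):
    torch.manual_seed(0)
    net = make_stack(d_in, list(hidden_dims), d_out, activation)
    with torch.no_grad():
        y = net(torch.randn(batch, d_in))  # (batch, d_out)
    return int(torch.linalg.matrix_rank(y))

rank_linear = output_rank(nn.Identity)  # bounded by d_in + 1 = 9
rank_relu = output_rank(nn.ReLU)        # expected to be well above 9
print(rank_linear, rank_relu)
```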
Section 2 — CNN and Depthwise Separable Convolutions
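Implement ConvBlock(in_ch, out_ch, K, S, P) and DepthwiseSeparable(in_ch, out_ch, K). On several configurations, verify that the spatial output size follows H_out = ⌊(H_in + 2P − K) / S⌋ + 1, and likewise for the width. The depthwise separable block factors a K×K convolution into two steps: a per-channel K×K convolution (groups=in_ch), then a 1×1 pointwise convolution: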
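```python
import torch.nn as nn

class DepthwiseSeparable(nn.Module):
    def __init__(self, in_ch, out_ch, K=3):
        super().__init__()
        P = K // 2  # "same" padding for odd K at stride 1
        self.dw = nn.Conv2d(in_ch, in_ch, K, padding=P, groups=in_ch)  # one K x K filter per channel
        self.pw = nn.Conv2d(in_ch, out_ch, 1)  # 1 x 1 conv mixes channels

    def forward(self, x):
        return self.pw(self.dw(x))
```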
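The lab fixes only ConvBlock's signature. This sketch assumes the common Conv2d, then BatchNorm2d, then ReLU composition, and checks the output-size formula on a few arbitrary (size, K, S, P) configurations. The helper expected_size is hypothetical.

```python
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    # Assumed composition: Conv2d -> BatchNorm2d -> ReLU
    def __init__(self, in_ch, out_ch, K, S, P):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, K, stride=S, padding=P, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

def expected_size(n, K, S, P):
    # floor((n + 2P - K) / S) + 1, for dilation 1
    return (n + 2 * P - K) // S + 1

configs = [(32, 3, 1, 1), (32, 3, 2, 1), (32, 5, 1, 0), (28, 7, 2, 3), (17, 4, 3, 2)]
results = []
for n, K, S, P in configs:
    out = ConvBlock(3, 16, K, S, P)(torch.randn(2, 3, n, n))
    results.append((out.shape[-1], expected_size(n, K, S, P)))
print(results)
```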
Parameter count comparison: for in_ch=64, out_ch=128, K=3, count parameters in nn.Conv2d(64, 128, 3) vs DepthwiseSeparable(64, 128) and verify the ~8× reduction.
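The parameter count can be checked directly. The separable block is written inline with the same two layers as DepthwiseSeparable(64, 128). Ignoring biases, the reduction factor is 1 / (1/out_ch + 1/K²) ≈ 8.4. With biases it comes out at about 8.2×.

```python
import torch.nn as nn

def n_params(module):
    return sum(p.numel() for p in module.parameters())

separable = nn.Sequential(
    nn.Conv2d(64, 64, 3, padding=1, groups=64),  # 64*1*3*3 + 64 = 640
    nn.Conv2d(64, 128, 1),                       # 64*128 + 128 = 8,320
)
standard = nn.Conv2d(64, 128, 3)                 # 64*128*3*3 + 128 = 73,856

ratio = n_params(standard) / n_params(separable)
print(n_params(standard), n_params(separable), round(ratio, 2))  # 73856 8960 8.24
```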
Build a 5-block CNN with alternating conv and max-pool layers and train on CIFAR-10 for 10 epochs. Compare accuracy and training time against a DepthwiseSeparable variant.
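A possible model builder for this comparison. The channel widths (32, 64, 128, 256, 256) and the BatchNorm/ReLU after each conv are assumptions. A separable=True flag swaps each 3×3 conv for the depthwise + pointwise pair. Five 2×2 max-pools reduce a 32×32 CIFAR-10 image to 1×1. The data loading (for example torchvision.datasets.CIFAR10) and the 10-epoch training loop are omitted. Timing could be done with time.perf_counter around each epoch.

```python
import torch
import torch.nn as nn

def conv3x3(c_in, c_out, separable):
    if separable:
        # depthwise 3x3 (groups=c_in) followed by pointwise 1x1, as in DepthwiseSeparable
        return nn.Sequential(nn.Conv2d(c_in, c_in, 3, padding=1, groups=c_in),
                             nn.Conv2d(c_in, c_out, 1))
    return nn.Conv2d(c_in, c_out, 3, padding=1)

def make_cnn(separable=False, widths=(32, 64, 128, 256, 256), n_classes=10):
    layers, c_in = [], 3
    for c_out in widths:
        layers += [conv3x3(c_in, c_out, separable), nn.BatchNorm2d(c_out),
                   nn.ReLU(), nn.MaxPool2d(2)]
        c_in = c_out
    layers += [nn.Flatten(), nn.Linear(c_in, n_classes)]  # spatial size is 1x1 here
    return nn.Sequential(*layers)

standard_cnn, separable_cnn = make_cnn(False), make_cnn(True)
x = torch.randn(4, 3, 32, 32)
print(standard_cnn(x).shape, separable_cnn(x).shape)
print(sum(p.numel() for p in standard_cnn.parameters()),
      sum(p.numel() for p in separable_cnn.parameters()))
```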
Section 3 — ResNet Residual Connections and Gradient Flow
Implement basic and bottleneck residual blocks, then build a 20-layer plain CNN and a 20-layer ResNet from them. The basic block:
```python
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x)  # residual connection
```
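A sketch of the bottleneck block under the usual design: a 1×1 conv reduces the channels to a smaller width, a 3×3 conv works at that width, and a 1×1 conv expands back. This version assumes an identity shortcut, so input and output channels match. The (channels, width) signature is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Bottleneck(nn.Module):
    def __init__(self, channels, width):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, width, 1, bias=False)          # reduce
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(width, width, 3, padding=1, bias=False)  # spatial mixing
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, channels, 1, bias=False)          # expand
        self.bn3 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return F.relu(out + x)  # identity shortcut

block = Bottleneck(256, 64)
y = block(torch.randn(2, 256, 8, 8))
conv_weights = sum(m.weight.numel() for m in block.modules() if isinstance(m, nn.Conv2d))
# 256*64 + 64*64*9 + 64*256 = 69,632 conv weights,
# vs. 2 * 256*256*9 = 1,179,648 for a BasicBlock(256)
print(y.shape, conv_weights)
```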
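Gradient flow experiment: after one backward pass, collect conv1.weight.grad.norm() from each block's first conv layer and plot gradient norm against block index. In the plain network the norms typically change strongly with depth: they shrink or grow toward the input, or fluctuate from block to block. In the ResNet they usually stay within a much narrower range, because the identity term in out + x gives the gradient a path to earlier blocks that bypasses the block's convolutions.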
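A simplified sketch of this experiment. The 20 weight layers are a stem conv, nine blocks of two convs each, and a linear head. It uses a single 16-channel stage with no downsampling and random inputs and labels instead of CIFAR images. A residual flag on a copy of BasicBlock yields the plain counterpart. All of these choices are assumptions rather than the lab's prescribed setup.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    # BasicBlock with a switch: residual=False drops the "+ x" to give the plain block
    def __init__(self, channels, residual):
        super().__init__()
        self.residual = residual
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + x if self.residual else out)

def block_grad_norms(residual, n_blocks=9, channels=16, seed=0):
    torch.manual_seed(seed)
    blocks = [Block(channels, residual) for _ in range(n_blocks)]
    # stem conv + 9 blocks x 2 convs + linear head = 20 weight layers
    net = nn.Sequential(nn.Conv2d(3, channels, 3, padding=1), *blocks,
                        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(channels, 10))
    x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
    F.cross_entropy(net(x), y).backward()
    return [b.conv1.weight.grad.norm().item() for b in blocks]

plain = block_grad_norms(residual=False)
resnet = block_grad_norms(residual=True)
for i, (p, r) in enumerate(zip(plain, resnet)):
    print(f"block {i}: plain {p:.3e}  resnet {r:.3e}")
```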
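For dimension-changing transitions (more channels, or stride s > 1), replace the identity shortcut with a projection shortcut nn.Conv2d(in_ch, out_ch, 1, stride=s). In a forward pass, verify that the shortcut's output has exactly the same shape as the main path's output, so that the two can be added.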
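A sketch of a downsampling block with the projection shortcut from the text. Following the torchvision-style design, it adds a BatchNorm after the 1×1 conv; that is an assumption. DownBlock is a hypothetical name. The check covers both an even and an odd input size, since both paths must round the same way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # projection shortcut: 1x1 conv with the same stride as the main path
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False), nn.BatchNorm2d(out_ch))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))

block = DownBlock(64, 128, stride=2)
for size in (32, 33):
    x = torch.randn(2, 64, size, size)
    main = block.bn2(block.conv2(F.relu(block.bn1(block.conv1(x)))))
    assert main.shape == block.shortcut(x).shape
    print(size, tuple(block(x).shape))
```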
Section 4 — Vanilla RNN vs. LSTM
Implement both from scratch, using only nn.Linear layers plus elementwise nonlinearities (tanh, sigmoid). The vanilla RNN:
```python
import torch
import torch.nn as nn

class VanillaRNN(nn.Module):
    def __init__(self, d_in, d_h, d_out):
        super().__init__()
        self.Wh = nn.Linear(d_h, d_h, bias=False)  # hidden-to-hidden
        self.Wx = nn.Linear(d_in, d_h)             # input-to-hidden (carries the bias)
        self.Wy = nn.Linear(d_h, d_out)            # hidden-to-output

    def forward(self, x):  # x: (B, T, d_in)
        h = torch.zeros(x.size(0), self.Wh.in_features, device=x.device)
        for t in range(x.size(1)):
            h = torch.tanh(self.Wh(h) + self.Wx(x[:, t]))
        return self.Wy(h)  # predict from the final hidden state
```
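A matching from-scratch LSTM sketch in the same style. One fused nn.Linear produces the four gates [i, f, g, o]. The gate ordering and the forget-gate bias initialized to 1 are assumptions; the bias of 1 is a common heuristic that helps the cell state retain information early in training. The key line is the additive update c = f * c + i * g.

```python
import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, d_in, d_h, d_out):
        super().__init__()
        self.d_h = d_h
        self.Wx = nn.Linear(d_in, 4 * d_h)              # input -> [i, f, g, o]
        self.Wh = nn.Linear(d_h, 4 * d_h, bias=False)   # hidden -> [i, f, g, o]
        self.Wy = nn.Linear(d_h, d_out)
        with torch.no_grad():
            self.Wx.bias[d_h:2 * d_h].fill_(1.0)        # forget-gate bias = 1 (assumption)

    def forward(self, x):  # x: (B, T, d_in)
        B, T, _ = x.shape
        h = x.new_zeros(B, self.d_h)
        c = x.new_zeros(B, self.d_h)
        for t in range(T):
            i, f, g, o = (self.Wx(x[:, t]) + self.Wh(h)).chunk(4, dim=-1)
            i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
            c = f * c + i * torch.tanh(g)  # additive cell-state update
            h = o * torch.tanh(c)
        return self.Wy(h)

lstm = LSTM(d_in=5, d_h=16, d_out=3)
out = lstm(torch.randn(4, 10, 5))
print(out.shape)
```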
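Long-sequence task: generate a synthetic dataset in which the label depends only on the token at position 0, while the sequence length is 200. The model must therefore carry that information across 199 steps. Train both models and compare their final accuracy.

Then, on the trained models, compare the gradient norm at the first time step, ||∂L/∂x_0||, between the two architectures. The LSTM's additive cell-state update c_t = f_t ⊙ c_{t−1} + i_t ⊙ g_t should let noticeably more gradient reach position 0 than the vanilla RNN's repeated tanh recurrence.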
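A sketch of the dataset and the first-step gradient measurement. The task (label = identity of the one-hot token at position 0, vocabulary of 8) and the helper names make_first_token_dataset, Seq2Label and first_step_grad_norm are hypothetical. So the block runs on its own, it wraps PyTorch's nn.RNN and nn.LSTM. Substitute the from-scratch classes, and measure after training: at initialization both gradients can be vanishingly small over 200 steps.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_first_token_dataset(n, T=200, vocab=8, seed=0):
    # Label = token at position 0; positions 1..T-1 are distractors
    g = torch.Generator().manual_seed(seed)
    tokens = torch.randint(0, vocab, (n, T), generator=g)
    return F.one_hot(tokens, vocab).float(), tokens[:, 0]  # x: (n, T, vocab), y: (n,)

class Seq2Label(nn.Module):
    # Linear head on the last output of a batch_first nn.RNN / nn.LSTM
    def __init__(self, rnn, d_h, n_classes):
        super().__init__()
        self.rnn, self.head = rnn, nn.Linear(d_h, n_classes)

    def forward(self, x):
        out, _ = self.rnn(x)  # works for both: nn.LSTM returns (out, (h, c))
        return self.head(out[:, -1])

def first_step_grad_norm(model, x, y):
    x = x.clone().requires_grad_(True)
    model.zero_grad()
    F.cross_entropy(model(x), y).backward()
    return x.grad[:, 0].norm().item()  # ||dL/dx_0|| over the batch

torch.manual_seed(0)
x, y = make_first_token_dataset(64)
rnn = Seq2Label(nn.RNN(8, 32, batch_first=True), 32, 8)
lstm = Seq2Label(nn.LSTM(8, 32, batch_first=True), 32, 8)
# ... train both models here, then:
g_rnn, g_lstm = first_step_grad_norm(rnn, x, y), first_step_grad_norm(lstm, x, y)
print(f"RNN {g_rnn:.3e}  LSTM {g_lstm:.3e}")
```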
Section 5 — Scaled Dot-Product Attention
Implement attention from scratch:
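```python
import torch

def scaled_dot_product_attention(Q, K, V, mask=None, scale=True):
    # Q: (..., T_q, d_k), K: (..., T_k, d_k), V: (..., T_k, d_v)
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1)  # (..., T_q, T_k)
    if scale:
        scores = scores / d_k ** 0.5
    if mask is not None:
        # boolean mask: True marks positions that may NOT be attended to
        scores = scores.masked_fill(mask, float('-inf'))
    weights = torch.softmax(scores, dim=-1)
    return weights @ V, weights
```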
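Saturation experiment: generate random Q and K with d_k ranging from 4 to 512. For each d_k, compute the entropy of the attention weight distribution with and without scaling. With i.i.d. unit-variance entries, each unscaled score q·k is a sum of d_k products and has variance d_k. Without scaling, the entropy should therefore drop sharply as d_k grows, because one weight comes to dominate. Dividing by √d_k restores unit variance, so with scaling the entropy should stay roughly constant.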
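A sketch of the entropy measurement. It assumes standard-normal Q and K entries, 256 queries and 64 keys, and averages the per-query Shannon entropy in nats; the maximum possible value is ln 64 ≈ 4.16. The clamp only avoids log(0) when weights underflow.

```python
import torch

def attention_entropy(d_k, scale, n_q=256, n_k=64, seed=0):
    torch.manual_seed(seed)
    Q, K = torch.randn(n_q, d_k), torch.randn(n_k, d_k)
    scores = Q @ K.T
    if scale:
        scores = scores / d_k ** 0.5
    w = torch.softmax(scores, dim=-1)
    ent = -(w * w.clamp_min(1e-12).log()).sum(dim=-1)  # entropy per query row
    return ent.mean().item()

dims = [4, 16, 64, 256, 512]
scaled = {d: attention_entropy(d, scale=True) for d in dims}
unscaled = {d: attention_entropy(d, scale=False) for d in dims}
for d in dims:
    print(f"d_k={d:4d}  scaled {scaled[d]:.3f}  unscaled {unscaled[d]:.3f}")
```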
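Build a MultiHeadAttention(d_model, n_heads) module that projects Q/K/V, splits them into heads, applies attention per head, and concatenates the heads before a final output projection. Verify the output shape. Then load the same weights into nn.MultiheadAttention(d_model, n_heads, batch_first=True) and check that the two outputs match.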
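A sketch using a fused Q/K/V projection, so its weights map one-to-one onto nn.MultiheadAttention's in_proj_weight, in_proj_bias and out_proj. The forward(x, causal_mask=None) signature (self-attention only, returning one tensor) is an assumption chosen to match the TransformerBlock call in Section 6. The mask uses the same True-means-blocked convention as masked_fill above.

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)  # fused Q/K/V projection
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, causal_mask=None):
        B, T, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, D) -> (B, H, T, d_head)
        q, k, v = (t.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2) for t in (q, k, v))
        scores = q @ k.transpose(-2, -1) / self.d_head ** 0.5
        if causal_mask is not None:
            scores = scores.masked_fill(causal_mask, float('-inf'))
        ctx = torch.softmax(scores, dim=-1) @ v          # (B, H, T, d_head)
        return self.out(ctx.transpose(1, 2).reshape(B, T, D))  # concatenate heads

torch.manual_seed(0)
d_model, n_heads, B, T = 32, 4, 2, 10
mha = MultiHeadAttention(d_model, n_heads)
ref = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
with torch.no_grad():  # copy our weights into the built-in module
    ref.in_proj_weight.copy_(mha.qkv.weight)
    ref.in_proj_bias.copy_(mha.qkv.bias)
    ref.out_proj.weight.copy_(mha.out.weight)
    ref.out_proj.bias.copy_(mha.out.bias)

x = torch.randn(B, T, d_model)
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True = blocked
ours = mha(x, causal_mask=mask)
theirs, _ = ref(x, x, x, attn_mask=mask, need_weights=False)
print(ours.shape, (ours - theirs).abs().max().item())
```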
Section 6 — Mini Decoder-Only Transformer
Assemble a character-level language model:
```python
import torch.nn as nn

class TransformerBlock(nn.Module):
    # Pre-norm block: LayerNorm before attention and before the feed-forward network
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)  # from Section 5
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model)
        )
        self.drop = nn.Dropout(dropout)

    def forward(self, x, causal_mask):
        x = x + self.drop(self.attn(self.norm1(x), causal_mask=causal_mask))  # attention + residual
        x = x + self.drop(self.ffn(self.norm2(x)))                            # FFN + residual
        return x
```
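The causal_mask argument can be built with torch.triu. This sketch assumes the same convention as masked_fill in Section 5, where True marks a future position that must not be attended to. It checks that, after the softmax, every query puts zero weight on later positions.

```python
import torch

def causal_mask(T, device=None):
    # True strictly above the diagonal: key j > query i is in the future -> blocked
    return torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)

mask = causal_mask(5)
scores = torch.randn(5, 5)
weights = torch.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
print(mask.int())
print(weights)
```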
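Sinusoidal positional encoding: implement PE(pos, 2i) = sin(pos / 10000^{2i/d_model}) and PE(pos, 2i+1) = cos(pos / 10000^{2i/d_model}). Verify two properties:

- Periodicity: dimension pair i repeats with wavelength 2π · 10000^{2i/d_model}.
- Shift property: the dot product PE(pos) · PE(pos + k) depends only on the offset k, not on pos, because each (sin, cos) pair contributes cos(k / 10000^{2i/d_model}).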
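A sketch that computes the encoding at arbitrary, possibly non-integer positions in float64 and checks both properties. d_model=64, the offset k=7 and pair index i=3 are arbitrary choices. Non-integer positions are used only to test periodicity, since the wavelengths are not integers.

```python
import math
import torch

def sinusoidal_pe(positions, d_model):
    # positions: 1-D tensor; returns PE of shape (len(positions), d_model) and the frequencies
    pos = positions.to(torch.float64).unsqueeze(1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float64)  # even indices 2i
    omega = 10000.0 ** (-two_i / d_model)                      # 1 / 10000^{2i/d_model}
    pe = torch.zeros(len(positions), d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(pos * omega)
    pe[:, 1::2] = torch.cos(pos * omega)
    return pe, omega

d_model = 64
pe, omega = sinusoidal_pe(torch.arange(200), d_model)

# Shift property: PE(pos) . PE(pos + k) = sum_i cos(k * omega_i) for every pos
k = 7
dots = torch.stack([pe[p] @ pe[p + k] for p in (0, 30, 150)])
expected = torch.cos(k * omega).sum()

# Periodicity: pair i repeats after 2*pi / omega_i
i = 3
wavelength = 2 * math.pi / omega[i].item()
pe_pair, _ = sinusoidal_pe(torch.tensor([5.0, 5.0 + wavelength]), d_model)
print(dots, expected.item(), wavelength)
```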
Train on a text file (Shakespeare or similar) for 5 epochs with context length 256. After training, visualize the attention weight matrices of each head in the final layer; this requires MultiHeadAttention to also return its softmax weights. Different heads often show qualitatively different patterns, such as local vs. global or syntactic vs. positional.
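A minimal character-level data pipeline for this training run. The helper names are hypothetical. The sketch assumes random-window sampling, where one "epoch" would be roughly len(data) // (batch_size * context_len) batches. Targets are the inputs shifted one character to the left, which is the next-character prediction objective.

```python
import torch

def build_char_dataset(text):
    chars = sorted(set(text))
    stoi = {c: i for i, c in enumerate(chars)}  # char -> token id
    data = torch.tensor([stoi[c] for c in text], dtype=torch.long)
    return data, stoi, chars

def get_batch(data, batch_size, context_len=256, generator=None):
    # Random windows of length context_len; y[t] is the character after x[t]
    starts = torch.randint(0, len(data) - context_len, (batch_size,), generator=generator)
    x = torch.stack([data[s:s + context_len] for s in starts.tolist()])
    y = torch.stack([data[s + 1:s + context_len + 1] for s in starts.tolist()])
    return x, y

text = "to be or not to be " * 50  # stand-in for the training text file
data, stoi, chars = build_char_dataset(text)
x, y = get_batch(data, batch_size=4, context_len=256,
                 generator=torch.Generator().manual_seed(0))
print(len(chars), x.shape, y.shape)
```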