Normalization Techniques in PyTorch
Implement nn.LayerNorm and nn.RMSNorm from scratch; benchmark the FLOP reduction of RMSNorm on long sequences
Apply nn.utils.weight_norm and nn.utils.spectral_norm to a small discriminator; verify the Lipschitz constraint of spectral norm empirically using random perturbation tests
Lab Overview
This lab builds every normalization technique from the readings into runnable PyTorch code. Each section follows the pattern: implement from scratch → verify against the PyTorch built-in → run a targeted experiment.
Sections
| Section | Topic | Key experiment |
|---|---|---|
| 1 | BatchNorm from scratch | Eval-mode bug; running stats vs batch stats |
| 2 | LayerNorm & RMSNorm | FLOP comparison on sequence length 2048 |
| 3 | Weight norm & spectral norm | Lipschitz verification via random perturbations |
| 4 | GroupNorm unification | BN vs GN accuracy across batch sizes 1–64 |
| 5 | AdaIN style transfer | Channel statistics before/after style injection |
| 6 | adaLN-Zero (DiT) | Identity-at-init verification |
Section 1 — BatchNorm from Scratch
Implement MyBatchNorm1d(num_features) with:
- self.gamma, self.beta: learned parameters
- self.running_mean, self.running_var: buffers (not parameters)
- forward(x, training): uses batch stats when training, running stats at eval
- Momentum update:
running_mean = 0.9 * running_mean + 0.1 * batch_mean
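Verify numerical agreement with nn.BatchNorm1d. Then reproduce the eval-mode bug: keep one example fixed, run 100 forward passes with training=True while pairing it with different random batch-mates each time, and show that its prediction differs across calls purely due to batch composition. A batch containing only a single example is degenerate (the batch variance is zero, and nn.BatchNorm1d raises an error for it in training mode), so use a batch size of at least 2.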
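A minimal sketch of the module and both checks described above, not a reference solution. The defaults eps=1e-5 and momentum=0.1, the feature count 16, and the batch size 8 are assumptions chosen for illustration. One detail that often breaks the comparison: nn.BatchNorm1d normalizes with the biased batch variance but stores the unbiased variance in running_var.

```python
import torch
import torch.nn as nn


class MyBatchNorm1d(nn.Module):
    """BatchNorm over a (batch, num_features) input."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Buffers are saved in state_dict and moved by .to(), but get no gradients.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x, training):
        if training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)  # biased variance normalizes the batch
            with torch.no_grad():
                # The running estimate uses the unbiased variance, as nn.BatchNorm1d does.
                unbiased_var = x.var(dim=0, unbiased=True)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
        else:
            mean, var = self.running_mean, self.running_var
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta


torch.manual_seed(0)
mine, ref = MyBatchNorm1d(16), nn.BatchNorm1d(16)
x = torch.randn(32, 16)

ref.train()
train_match = torch.allclose(mine(x, training=True), ref(x), atol=1e-5)
ref.eval()
eval_match = torch.allclose(mine(x, training=False), ref(x), atol=1e-5)

# Eval-mode bug: the same example gets a different output with different batch-mates.
probe = torch.randn(1, 16)
outs = torch.stack([
    mine(torch.cat([probe, torch.randn(7, 16)]), training=True)[0]
    for _ in range(100)
])
spread = outs.std(dim=0).mean().item()
print(train_match, eval_match, spread)
```

A non-zero spread means the output for the fixed example varies from call to call even though the example itself never changes.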
Section 2 — LayerNorm and RMSNorm
Implement both as nn.Module subclasses operating on the last dimension. Key assertions:
# Post-LayerNorm: output has zero mean and unit variance per token
assert out.mean(dim=-1).abs().max() < 1e-5
# LayerNorm divides by the biased std, so compare against unbiased=False
assert (out.std(dim=-1, unbiased=False) - 1).abs().max() < 1e-4
# Post-RMSNorm: RMS is 1 per token (mean is not necessarily zero)
assert ((out**2).mean(dim=-1).sqrt() - 1).abs().max() < 1e-4
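The two modules can be sketched as follows. The class names MyLayerNorm and MyRMSNorm, the eps defaults, and the toy shape (4, 10, 64) are assumptions for illustration. The comparison with nn.RMSNorm is guarded because that class is only available in newer PyTorch releases.

```python
import torch
import torch.nn as nn


class MyLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.weight * (x - mean) / torch.sqrt(var + self.eps) + self.bias


class MyRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # No mean subtraction and no bias: only rescale by the root mean square.
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x / rms


torch.manual_seed(0)
x = torch.randn(4, 10, 64)
ln_out = MyLayerNorm(64)(x)
rms_out = MyRMSNorm(64)(x)

ln_matches_builtin = torch.allclose(ln_out, nn.LayerNorm(64)(x), atol=1e-5)
ln_mean_err = ln_out.mean(dim=-1).abs().max().item()
ln_std_err = (ln_out.std(dim=-1, unbiased=False) - 1).abs().max().item()
rms_err = (rms_out.pow(2).mean(dim=-1).sqrt() - 1).abs().max().item()

# nn.RMSNorm may be missing on older installs, so only compare when it exists.
rms_matches_builtin = True
if hasattr(nn, "RMSNorm"):
    rms_matches_builtin = torch.allclose(rms_out, nn.RMSNorm(64, eps=1e-6)(x), atol=1e-5)
```

The per-token statistics hold only while the affine parameters sit at their initial values (weight of ones, bias of zeros); after training, the checks should be run on the output before the affine step.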
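Benchmark: generate a (batch=8, seq=2048, dim=4096) tensor and time 100 forward passes for LayerNorm vs RMSNorm, after a few untimed warm-up passes. On a GPU, call torch.cuda.synchronize() before reading the clock, because kernel launches are asynchronous. Report the wall-clock speedup.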
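One possible timing harness is sketched below. The helper time_forward and the class SimpleRMSNorm are hypothetical names, and the tensor is deliberately much smaller than the lab shape so the sketch runs in a moment; swap in (8, 2048, 4096) for the real measurement.

```python
import time

import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """Hand-written RMSNorm used only for this timing sketch."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


def time_forward(module, x, n_iters=100, n_warmup=10):
    """Return mean seconds per forward pass, with warm-up and device sync."""

    def sync():
        if x.is_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        for _ in range(n_warmup):
            module(x)
        sync()
        start = time.perf_counter()
        for _ in range(n_iters):
            module(x)
        sync()
        elapsed = time.perf_counter() - start
    return elapsed / n_iters


device = "cuda" if torch.cuda.is_available() else "cpu"
# Small stand-in shape; the lab uses (8, 2048, 4096).
x = torch.randn(2, 128, 256, device=device)
t_ln = time_forward(nn.LayerNorm(256).to(device), x)
t_rms = time_forward(SimpleRMSNorm(256).to(device), x)
print(f"LayerNorm {t_ln * 1e6:.1f} us | RMSNorm {t_rms * 1e6:.1f} us | ratio {t_ln / t_rms:.2f}x")
```

Wall-clock time and FLOP count can disagree: nn.LayerNorm dispatches to a fused kernel, while an eager RMSNorm built from several elementwise ops may not, so a ratio below 1 does not by itself contradict the FLOP argument.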
Section 3 — Weight Norm and Spectral Norm
Apply nn.utils.weight_norm to a linear layer and inspect the weight_g (magnitude) and weight_v (direction) parameters that replace the original weight.
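A short inspection sketch, assuming a Linear(8, 4) layer purely for illustration. It uses the nn.utils.weight_norm function named above; recent PyTorch releases mark that function as deprecated in favor of torch.nn.utils.parametrizations.weight_norm, which exposes the two tensors under different names, so a deprecation warning here is expected.

```python
import torch
import torch.nn as nn

layer = nn.utils.weight_norm(nn.Linear(8, 4))

# The original `weight` parameter is replaced by two parameters.
names = sorted(name for name, _ in layer.named_parameters())

g, v = layer.weight_g, layer.weight_v  # magnitude (4, 1) and direction (4, 8)

# With the default dim=0, each output row is rescaled independently:
# weight = g * v / ||v||, where the norm is taken per row.
rebuilt = g * v / v.norm(dim=1, keepdim=True)
weights_match = torch.allclose(rebuilt, layer.weight, atol=1e-6)
print(names, tuple(g.shape), tuple(v.shape), weights_match)
```

Because the magnitude lives in weight_g alone, scaling weight_v by any positive constant leaves the effective weight unchanged.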
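For spectral norm: apply nn.utils.spectral_norm to every layer of a 4-layer MLP, train for 100 steps, and after each step compute torch.linalg.matrix_norm(layer.weight, ord=2). Assert it stays at or below 1.0 + 1e-4. nn.utils.spectral_norm estimates the largest singular value by power iteration (one step per training forward pass by default), and that estimate never exceeds the true value, so the normalized weight can sit slightly above 1 until the iteration converges. If the assertion fails in the first few steps, raise n_power_iterations or exclude those steps.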
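The tracking loop might look like the sketch below. The layer widths, the random regression target, and the SGD learning rate are assumptions for illustration; the sketch records the largest per-layer spectral norm at every step rather than asserting a fixed tolerance, so the convergence of the power iteration can be inspected directly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dims = [16, 32, 32, 32, 1]  # 4 linear layers (hypothetical sizes)
layers = [nn.utils.spectral_norm(nn.Linear(i, o)) for i, o in zip(dims[:-1], dims[1:])]

modules = []
for layer in layers:
    modules += [layer, nn.ReLU()]
mlp = nn.Sequential(*modules[:-1])  # drop the trailing ReLU

opt = torch.optim.SGD(mlp.parameters(), lr=1e-2)
history = []
for step in range(100):
    x, y = torch.randn(64, 16), torch.randn(64, 1)
    loss = (mlp(x) - y).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    # layer.weight is the normalized weight used in the most recent forward pass.
    with torch.no_grad():
        history.append(
            max(torch.linalg.matrix_norm(layer.weight, ord=2).item() for layer in layers)
        )

print(f"step 0: {history[0]:.4f}   step 99: {history[-1]:.4f}")
```

Early entries in history can be noticeably above 1 because the power-iteration vectors start from random values; later entries should settle close to 1.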
Lipschitz verification: for 1000 random pairs $(x_1, x_2)$, assert $\|f(x_1) - f(x_2)\|_2 \le K \, \|x_1 - x_2\|_2$, where $f$ is the MLP and $K$ is the product of per-layer Lipschitz bounds.
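A sketch of the random perturbation test, under a few stated assumptions: ReLU activations (which are 1-Lipschitz), perturbations of scale 0.1, and a bound computed from the measured spectral norm of each layer rather than assumed to be exactly 1. The layer sizes are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dims = [16, 32, 32, 32, 1]
modules = []
for i, o in zip(dims[:-1], dims[1:]):
    modules += [nn.utils.spectral_norm(nn.Linear(i, o)), nn.ReLU()]
mlp = nn.Sequential(*modules[:-1])

with torch.no_grad():
    # A few training-mode passes let the power iteration settle.
    mlp.train()
    for _ in range(20):
        mlp(torch.randn(4, 16))
    mlp.eval()  # freezes the power iteration, so weights stay fixed between calls
    mlp(torch.randn(1, 16))

    # Bound = product of measured per-layer spectral norms (ReLU contributes 1).
    bound = 1.0
    for m in mlp:
        if isinstance(m, nn.Linear):
            bound *= torch.linalg.matrix_norm(m.weight, ord=2).item()

    x1 = torch.randn(1000, 16)
    x2 = x1 + 0.1 * torch.randn(1000, 16)
    ratios = (mlp(x1) - mlp(x2)).norm(dim=1) / (x1 - x2).norm(dim=1)

worst = ratios.max().item()
print(f"worst observed ratio {worst:.4f} <= bound {bound:.4f}")
```

Random pairs only probe the bound from below: the worst observed ratio is usually well under the product of norms, which is an upper bound and not a tight estimate of the network's Lipschitz constant.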
Section 4 — GroupNorm Unification
Implement MyGroupNorm(G, C) that normalizes over groups of channels and spatial dimensions. Verify:
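# float32 reductions differ slightly between implementations, so allow a small atol
assert torch.allclose(MyGroupNorm(G=C, C=C)(x), nn.InstanceNorm2d(C, affine=False)(x), atol=1e-5)
assert torch.allclose(MyGroupNorm(G=1, C=C)(x), nn.LayerNorm([C, H, W])(x), atol=1e-5)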
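A minimal sketch of the module and its two limiting cases. The eps default of 1e-5 (chosen to match the built-ins), the per-channel affine parameters, and the toy shape are assumptions for illustration.

```python
import torch
import torch.nn as nn


class MyGroupNorm(nn.Module):
    def __init__(self, G, C, eps=1e-5):
        super().__init__()
        assert C % G == 0, "channels must divide evenly into groups"
        self.G, self.eps = G, eps
        self.gamma = nn.Parameter(torch.ones(1, C, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, C, 1, 1))

    def forward(self, x):
        N, C, H, W = x.shape
        # Flatten each group's channels and spatial positions into one axis.
        g = x.reshape(N, self.G, -1)
        mean = g.mean(dim=-1, keepdim=True)
        var = g.var(dim=-1, keepdim=True, unbiased=False)
        g = (g - mean) / torch.sqrt(var + self.eps)
        return g.reshape(N, C, H, W) * self.gamma + self.beta


torch.manual_seed(0)
N, C, H, W = 2, 8, 5, 5
x = torch.randn(N, C, H, W)

# G = C: one channel per group, which is InstanceNorm.
same_as_in = torch.allclose(
    MyGroupNorm(G=C, C=C)(x), nn.InstanceNorm2d(C, affine=False)(x), atol=1e-5
)
# G = 1: one group spanning (C, H, W), which is LayerNorm over those dims.
same_as_ln = torch.allclose(
    MyGroupNorm(G=1, C=C)(x), nn.LayerNorm([C, H, W])(x), atol=1e-5
)
# Intermediate G against the built-in.
same_as_gn = torch.allclose(MyGroupNorm(G=4, C=C)(x), nn.GroupNorm(4, C)(x), atol=1e-5)
print(same_as_in, same_as_ln, same_as_gn)
```

The LayerNorm comparison holds only at initialisation: nn.LayerNorm([C, H, W]) learns one scale and shift per element, while this sketch learns one per channel.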
Train a toy detection head (ResNet-50 backbone, FPN neck) on a small detection task at batch sizes {1, 2, 4, 8, 16, 64}. Plot validation loss at epoch 5 for BN vs GroupNorm (G=32). Observe that BN degrades sharply below batch size ~8 while GN is stable throughout.
Section 5 — AdaIN Style Transfer
Implement adain(content, style) (from the quiz code-completion). Extract VGG-19 features at relu3_1 and relu4_1 for a content/style pair. Apply AdaIN and verify:
out = adain(content_feat, style_feat)
assert torch.allclose(out.mean(dim=[2,3]), style_feat.mean(dim=[2,3]), atol=1e-4)
assert torch.allclose(out.std(dim=[2,3]), style_feat.std(dim=[2,3]), atol=1e-4)
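For reference, one way adain could be written is sketched below; the quiz version may differ in how it handles eps. Random tensors stand in for the VGG-19 feature maps here, and their shapes and statistics are assumptions chosen only to exercise the checks.

```python
import torch


def adain(content, style, eps=1e-5):
    """Re-normalize content features to the per-channel mean/std of style features."""
    c_mean = content.mean(dim=[2, 3], keepdim=True)
    c_std = content.std(dim=[2, 3], keepdim=True) + eps  # eps guards against flat channels
    s_mean = style.mean(dim=[2, 3], keepdim=True)
    s_std = style.std(dim=[2, 3], keepdim=True)
    return s_std * (content - c_mean) / c_std + s_mean


torch.manual_seed(0)
# Stand-ins for (N, C, H, W) feature maps with clearly different statistics.
content_feat = torch.randn(1, 64, 32, 32) * 2.0 + 1.0
style_feat = torch.randn(1, 64, 32, 32) * 0.5 - 3.0

out = adain(content_feat, style_feat)
mean_ok = torch.allclose(out.mean(dim=[2, 3]), style_feat.mean(dim=[2, 3]), atol=1e-4)
std_ok = torch.allclose(out.std(dim=[2, 3]), style_feat.std(dim=[2, 3]), atol=1e-4)
print(mean_ok, std_ok)
```

The content and style maps need the same channel count but not the same spatial size, since the statistics are reduced over the spatial dimensions before broadcasting.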
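Reconstruct the stylized image by passing the AdaIN output through a decoder trained to invert the VGG encoder (VGG itself has no decoder; the decoder is a separate network that mirrors the encoder). Visualise content image, style image, and result side by side.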
Section 6 — adaLN-Zero Conditioning (DiT)
Implement a single DiT block with adaLN-Zero conditioning:
class DiTBlock(nn.Module):
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 4*dim), nn.GELU(), nn.Linear(4*dim, dim))
        # Zero-init: all 6 conditioning vectors start at zero
        self.adaLN_proj = nn.Linear(cond_dim, 6 * dim)
        nn.init.zeros_(self.adaLN_proj.weight)
        nn.init.zeros_(self.adaLN_proj.bias)
Verify that at initialisation, block(x, c) returns x unchanged for any x and c (the identity property): the zero-initialised projection makes every gate zero, so both residual branches contribute nothing. Then train the block for 100 steps and observe that the conditioning vectors move away from zero as the block learns structure.
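A possible forward pass and identity check are sketched below. The class name DiTBlockSketch, the ordering of the six chunks (shift, scale, gate for each branch), and the single SGD step against a random target are assumptions for illustration; the constructor matches the skeleton above.

```python
import torch
import torch.nn as nn


class DiTBlockSketch(nn.Module):
    def __init__(self, dim, cond_dim, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.adaLN_proj = nn.Linear(cond_dim, 6 * dim)
        nn.init.zeros_(self.adaLN_proj.weight)
        nn.init.zeros_(self.adaLN_proj.bias)

    def forward(self, x, c):
        # c: (batch, cond_dim) -> six (batch, 1, dim) vectors that broadcast over tokens.
        shift1, scale1, gate1, shift2, scale2, gate2 = (
            self.adaLN_proj(c).unsqueeze(1).chunk(6, dim=-1)
        )
        h = self.norm1(x) * (1 + scale1) + shift1
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + gate1 * attn_out  # gate1 == 0 at init, so this branch is a no-op
        h = self.norm2(x) * (1 + scale2) + shift2
        return x + gate2 * self.ff(h)  # gate2 == 0 at init as well


torch.manual_seed(0)
block = DiTBlockSketch(dim=64, cond_dim=32)
x, c = torch.randn(2, 10, 64), torch.randn(2, 32)
identity_at_init = torch.equal(block(x, c), x)

# One optimisation step is enough for the gate parameters to leave zero.
opt = torch.optim.SGD(block.parameters(), lr=1e-2)
loss = (block(x, c) - torch.randn_like(x)).pow(2).mean()
loss.backward()
opt.step()
moved = block.adaLN_proj.weight.abs().max().item() > 0
print(identity_at_init, moved)
```

Only the gate slices receive gradient on the very first step, because the shift and scale vectors sit behind gates that are still zero; they start to move once the gates do.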