Supplement · Normalization in Deep Learning

Normalization Techniques in PyTorch

Colab Notebook · Python · ~50 min
Lab Objectives
1. Implement BatchNorm1d from scratch with correct train/eval mode switching and running-statistic updates; reproduce the eval-mode bug (leaving a model in training mode at inference) and measure its effect on prediction variance
2. Implement LayerNorm and RMSNorm from scratch; verify numerical parity against nn.LayerNorm and nn.RMSNorm; measure RMSNorm's wall-clock speedup over LayerNorm on long sequences
3. Apply nn.utils.weight_norm and nn.utils.spectral_norm to a small discriminator; verify the Lipschitz constraint of spectral norm empirically using random perturbation tests
4. Implement GroupNorm from scratch; verify that G=C reduces to InstanceNorm and G=1 reduces to LayerNorm; train a small detection head at batch sizes from 1 to 64 and compare BN vs GN stability
5. Implement AdaIN and apply it to VGG feature maps for neural style transfer; verify that the AdaIN output reproduces the style features' per-channel mean and standard deviation up to numerical tolerance
6. Implement adaLN-Zero conditioning as used in Diffusion Transformers (DiT): zero-initialize the conditioning projection layer and verify that each block is an identity function at initialization

Lab Overview

This lab builds every normalization technique from the readings into runnable PyTorch code. Each section follows the pattern: implement from scratch → verify against the PyTorch built-in → run a targeted experiment.

Sections

| Section | Topic | Key experiment |
|---|---|---|
| 1 | BatchNorm from scratch | Eval-mode bug; running stats vs batch stats |
| 2 | LayerNorm & RMSNorm | Wall-clock timing at sequence length 2048 |
| 3 | Weight norm & spectral norm | Lipschitz verification via random perturbations |
| 4 | GroupNorm unification | BN vs GN validation loss across batch sizes 1–64 |
| 5 | AdaIN style transfer | Channel statistics before/after style injection |
| 6 | adaLN-Zero (DiT) | Identity-at-init verification |

Section 1 — BatchNorm from Scratch

Implement MyBatchNorm1d(num_features) with:

  • self.gamma, self.beta: learned parameters
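  • self.running_mean, self.running_var: buffers registered with register_buffer (not parameters), initialized to zeros and ones
  • forward(x, training): uses batch statistics when training, running statistics at eval
  • Momentum update (momentum 0.1, as in nn.BatchNorm1d): running_mean = 0.9 * running_mean + 0.1 * batch_mean, and the same update for running_var using the unbiased batch variance (the normalization itself uses the biased variance)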
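A minimal sketch of the spec above, assuming (N, C) inputs and the default eps=1e-5 of nn.BatchNorm1d. It keeps the lab's explicit training argument rather than nn.Module's self.training flag, and ends with parity checks against nn.BatchNorm1d on a few synthetic batches.

```python
import torch
import torch.nn as nn


class MyBatchNorm1d(nn.Module):
    """Sketch of BatchNorm1d for (N, C) inputs with an explicit training flag."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps, self.momentum = eps, momentum
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Buffers follow .to(device) and state_dict() but receive no gradients
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x, training):
        if training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)  # biased variance normalizes the batch
            with torch.no_grad():
                m = self.momentum
                self.running_mean.mul_(1 - m).add_(m * mean)
                # The running estimate uses the unbiased variance, as nn.BatchNorm1d does
                self.running_var.mul_(1 - m).add_(m * x.var(dim=0, unbiased=True))
        else:
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


# Parity check against the built-in on a few synthetic batches
torch.manual_seed(0)
mine, ref = MyBatchNorm1d(16), nn.BatchNorm1d(16)
for _ in range(5):
    batch = torch.randn(32, 16) * 3 + 1
    assert torch.allclose(mine(batch, training=True), ref.train()(batch), atol=1e-5)
x = torch.randn(8, 16) * 3 + 1
assert torch.allclose(mine(x, training=False), ref.eval()(x), atol=1e-5)
```

If the running_var update used the biased variance instead, the train-mode outputs would still match but the eval-mode check would drift.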

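Verify numerical agreement with nn.BatchNorm1d in both modes. Then reproduce the eval-mode bug: run 100 forward passes with training=True in which the same test example is batched with different randomly drawn examples, and show that its prediction differs across calls purely because of batch composition. A batch containing only that one example cannot show this: its batch variance is zero, and nn.BatchNorm1d raises an error when it sees a single value per channel in training mode.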
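The batch-composition effect can be sketched with nn.BatchNorm1d directly. The tiny model, the pool of 1,000 random companions, and the batch size of 32 are illustrative assumptions; the test example always sits in row 0 and only its companions change between calls.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical tiny classifier; any model containing BatchNorm shows the effect
model = nn.Sequential(nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))
test_x = torch.randn(1, 10)   # the example we want a prediction for
pool = torch.randn(1000, 10)  # examples it might share a batch with


def predict(train_mode):
    model.train(train_mode)
    with torch.no_grad():
        companions = pool[torch.randint(0, len(pool), (31,))]
        batch = torch.cat([test_x, companions])  # test example always in row 0
        return model(batch)[0]


buggy = torch.stack([predict(True) for _ in range(100)])     # model.eval() forgotten
correct = torch.stack([predict(False) for _ in range(100)])  # uses running statistics
print("spread across calls, train mode:", buggy.std(dim=0))
print("spread across calls, eval mode: ", correct.std(dim=0))
```

Besides the spread in predictions, every buggy call also updates running_mean and running_var with inference data, a second side effect of forgetting model.eval().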

Section 2 — LayerNorm and RMSNorm

Implement both as nn.Module subclasses operating on the last dimension. Key assertions:

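# Post-LayerNorm (gamma=1, beta=0 at init): zero mean and unit variance per token.
# LayerNorm divides by D, not D - 1, so compare against the biased std; the default
# unbiased std is about 1 + 1/(2D) and fails the 1e-4 check at D = 4096
assert out.mean(dim=-1).abs().max() < 1e-5
assert (out.std(dim=-1, unbiased=False) - 1).abs().max() < 1e-4

# Post-RMSNorm (gamma=1 at init): RMS is 1 per token (mean is not necessarily zero)
assert ((out**2).mean(dim=-1).sqrt() - 1).abs().max() < 1e-4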
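One way to write both modules, assuming a per-feature scale for both, a bias only for LayerNorm (matching the built-ins' defaults), and eps=1e-6 for RMSNorm. The nn.RMSNorm comparison is guarded because that module only exists in PyTorch 2.4 and later.

```python
import torch
import torch.nn as nn


class MyLayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)  # biased, like nn.LayerNorm
        return (x - mean) / torch.sqrt(var + self.eps) * self.weight + self.bias


class MyRMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # scale only, no bias

    def forward(self, x):
        # No mean subtraction: divide by the root-mean-square over the last dim
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight


torch.manual_seed(0)
x = torch.randn(4, 16, 64) * 2 + 0.5
ln_out = MyLayerNorm(64)(x)
rms_out = MyRMSNorm(64)(x)
assert torch.allclose(ln_out, nn.LayerNorm(64)(x), atol=1e-5)
if hasattr(nn, "RMSNorm"):  # nn.RMSNorm is only available in PyTorch >= 2.4
    assert torch.allclose(rms_out, nn.RMSNorm(64, eps=1e-6)(x), atol=1e-5)
```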

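Benchmark: generate a (batch=8, seq=2048, dim=4096) float32 tensor (256 MiB) and time 100 forward passes of LayerNorm vs RMSNorm after a few warm-up passes. On GPU, call torch.cuda.synchronize() before reading the clock, since kernels run asynchronously. Report the wall-clock speedup.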
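A timing sketch under stated assumptions. Both norms are written as plain tensor ops so the comparison is like-for-like: timing a fused built-in such as nn.LayerNorm against a hand-written RMSNorm mostly measures kernel fusion, not the saved mean computation. On CPU the sketch substitutes a smaller, hypothetical shape so it finishes quickly.

```python
import time

import torch


def layer_norm(x, eps=1e-5):
    # Mean and variance reductions plus centering
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) * torch.rsqrt(var + eps)


def rms_norm(x, eps=1e-6):
    # One reduction (mean of squares), no centering
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


@torch.no_grad()
def time_per_call(fn, x, iters=100, warmup=10):
    for _ in range(warmup):  # exclude one-time allocation and kernel-selection costs
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
# Lab shape on GPU; an assumed smaller stand-in on CPU
shape = (8, 2048, 4096) if device == "cuda" else (2, 512, 1024)
x = torch.randn(*shape, device=device)
t_ln = time_per_call(layer_norm, x)
t_rms = time_per_call(rms_norm, x)
print(f"LayerNorm {t_ln * 1e3:.2f} ms, RMSNorm {t_rms * 1e3:.2f} ms, speedup {t_ln / t_rms:.2f}x")
```

Normalization over large tensors is usually memory-bound, so the measured speedup is often smaller than the reduction in arithmetic alone would suggest.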

Section 3 — Weight Norm and Spectral Norm

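Apply nn.utils.weight_norm to a linear layer and inspect the weight_g (magnitude) and weight_v (direction) parameters that replace the original weight. Recent PyTorch releases deprecate this function in favor of torch.nn.utils.parametrizations.weight_norm, which stores the same two tensors as parametrizations.weight.original0 (magnitude) and parametrizations.weight.original1 (direction).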
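A short inspection sketch using the legacy nn.utils.weight_norm API named in the lab; the 20-to-10 layer size is arbitrary. It checks that the layer's weight equals g · v / ‖v‖ with one magnitude per output row (the default dim=0).

```python
import warnings

import torch
import torch.nn as nn

torch.manual_seed(0)
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # newer PyTorch warns that this API is deprecated
    layer = nn.utils.weight_norm(nn.Linear(20, 10))  # dim=0: one magnitude per output unit

print(sorted(n for n, _ in layer.named_parameters()))  # 'weight' is no longer a parameter
g, v = layer.weight_g, layer.weight_v
print(g.shape, v.shape)
layer(torch.randn(3, 20))  # the forward pre-hook recomputes layer.weight from g and v
recon = g * v / v.norm(dim=1, keepdim=True)
assert torch.allclose(layer.weight, recon, atol=1e-6)
```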

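For spectral norm: apply nn.utils.spectral_norm to every layer of a 4-layer MLP, train for 100 steps, and after each step compute torch.linalg.matrix_norm(layer.weight, ord=2). Assert that it stays ≤ 1 + 1e-4 once training is underway. By default spectral_norm refines its estimate of σ with one power-iteration step per forward pass, starting from a random vector, and that estimate never exceeds the true σ; the normalized weight's norm therefore starts noticeably above 1 and approaches 1 from above. Skip the first steps, or raise n_power_iterations, before applying the tight check.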
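A training sketch of the spectral-norm experiment, with a hypothetical 16-64-64-64-1 MLP, synthetic regression data, and Adam at lr=1e-3 as assumptions. It records the largest per-layer ‖W‖₂ after each step instead of asserting, so you can see the values start above 1 and settle toward it as the power-iteration estimate converges.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)
# Hypothetical 4-layer MLP with spectral norm on every Linear layer
mlp = nn.Sequential(
    spectral_norm(nn.Linear(16, 64)), nn.ReLU(),
    spectral_norm(nn.Linear(64, 64)), nn.ReLU(),
    spectral_norm(nn.Linear(64, 64)), nn.ReLU(),
    spectral_norm(nn.Linear(64, 1)),
)
linears = [m for m in mlp if isinstance(m, nn.Linear)]
opt = torch.optim.Adam(mlp.parameters(), lr=1e-3)
x, y = torch.randn(256, 16), torch.randn(256, 1)  # synthetic regression data

max_sigma = []  # largest per-layer spectral norm after each step
for step in range(100):
    loss = nn.functional.mse_loss(mlp(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    with torch.no_grad():
        # layer.weight is W / sigma_hat from this step's forward pass
        max_sigma.append(max(torch.linalg.matrix_norm(m.weight, ord=2).item() for m in linears))

print(f"step 1: {max_sigma[0]:.4f}, step 100: {max_sigma[-1]:.4f}")
```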

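Lipschitz verification: switch the model to eval mode, so the power iteration stops updating and f is one fixed function, then for 1000 random pairs $(x_1, x_2)$ assert $\|f(x_1) - f(x_2)\| \leq \|x_1 - x_2\|$. The bound holds because the Lipschitz constant of f is at most the product of the per-layer spectral norms (each ≈ 1) and the activations' Lipschitz constants (1 for ReLU).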
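A sketch of the random-pair test. It assumes the same kind of 4-layer MLP, runs 50 power-iteration steps in one warm-up pass so the σ estimates are close to converged, and then switches to eval mode. Random pairs rarely probe the worst-case direction, so passing is evidence for the bound rather than a proof.

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm


def sn(layer):
    # 50 power-iteration steps per training-mode forward pass (the default is 1)
    return spectral_norm(layer, n_power_iterations=50)


torch.manual_seed(0)
f = nn.Sequential(
    sn(nn.Linear(16, 64)), nn.ReLU(),
    sn(nn.Linear(64, 64)), nn.ReLU(),
    sn(nn.Linear(64, 64)), nn.ReLU(),
    sn(nn.Linear(64, 1)),
)
with torch.no_grad():
    f(torch.randn(1, 16))  # warm-up pass refines each layer's sigma estimate
f.eval()  # stop updating u, v so every call evaluates the same function

x1, x2 = torch.randn(1000, 16), torch.randn(1000, 16)
with torch.no_grad():
    ratios = (f(x1) - f(x2)).norm(dim=1) / (x1 - x2).norm(dim=1)
print(f"max ratio over 1000 pairs: {ratios.max().item():.4f}")
assert (ratios <= 1 + 1e-4).all()
```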

Section 4 — GroupNorm Unification

Implement MyGroupNorm(G, C) that splits the C channels into G groups (C divisible by G) and normalizes each group jointly over its channels and the spatial dimensions. Verify:

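# x: (N, C, H, W); any affine parameters in MyGroupNorm must still be at init (gamma=1, beta=0).
# atol is needed: allclose's default atol=1e-8 is too strict for outputs near zero
assert torch.allclose(MyGroupNorm(G=C, C=C)(x), nn.InstanceNorm2d(C, affine=False)(x), atol=1e-5)
assert torch.allclose(MyGroupNorm(G=1, C=C)(x), nn.LayerNorm([C, H, W], elementwise_affine=False)(x), atol=1e-5)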
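A sketch of MyGroupNorm, assuming per-channel affine parameters and consecutive-channel grouping, as in nn.GroupNorm. Reshaping to (N, G, -1) puts each group's channels and spatial positions in one row, which is all the unification needs: G=C leaves one channel per group (InstanceNorm) and G=1 makes the whole sample one group (LayerNorm over C, H, W).

```python
import torch
import torch.nn as nn


class MyGroupNorm(nn.Module):
    """Sketch: normalize each group of C // G channels jointly with H and W."""

    def __init__(self, G, C, eps=1e-5):
        super().__init__()
        assert C % G == 0, "C must be divisible by G"
        self.G, self.eps = G, eps
        self.weight = nn.Parameter(torch.ones(C))  # per-channel affine, as in nn.GroupNorm
        self.bias = nn.Parameter(torch.zeros(C))

    def forward(self, x):
        N, C, H, W = x.shape
        g = x.reshape(N, self.G, -1)  # (N, G, C // G * H * W)
        mean = g.mean(dim=-1, keepdim=True)
        var = g.var(dim=-1, keepdim=True, unbiased=False)
        x = ((g - mean) / torch.sqrt(var + self.eps)).reshape(N, C, H, W)
        return x * self.weight.view(1, C, 1, 1) + self.bias.view(1, C, 1, 1)


torch.manual_seed(0)
N, C, H, W = 4, 8, 5, 5
x = torch.randn(N, C, H, W) * 2 + 1
assert torch.allclose(MyGroupNorm(G=C, C=C)(x), nn.InstanceNorm2d(C, affine=False)(x), atol=1e-5)
assert torch.allclose(MyGroupNorm(G=1, C=C)(x), nn.LayerNorm([C, H, W], elementwise_affine=False)(x), atol=1e-5)
assert torch.allclose(MyGroupNorm(G=4, C=C)(x), nn.GroupNorm(4, C)(x), atol=1e-5)
```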

Train a toy detection head (ResNet-50 backbone, FPN neck) on a small detection task at batch sizes {1, 2, 4, 8, 16, 64}. Plot validation loss at epoch 5 for BN vs GroupNorm (G=32). Expect BN to degrade sharply below a batch size of about 8, where per-batch statistics become noisy estimates, while GN, whose statistics never involve the batch dimension, stays stable throughout.
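Before the full experiment, a cheap check shows why GN should be stable across batch sizes. The sketch swaps BatchNorm2d for GroupNorm(32, C) with a hypothetical helper bn_to_gn (it would also work on a ResNet/FPN model, though it discards any pretrained BN statistics) and compares one sample's output computed alone and inside a batch of 64. The two-conv head is a stand-in, not a detection head.

```python
import torch
import torch.nn as nn


def bn_to_gn(module, groups=32):
    """Hypothetical helper: recursively replace nn.BatchNorm2d with nn.GroupNorm."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.GroupNorm(groups, child.num_features))
        else:
            bn_to_gn(child, groups)
    return module


def make_head():
    # Stand-in head: conv -> norm -> ReLU, twice
    return nn.Sequential(
        nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
        nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
    )


torch.manual_seed(0)
bn_head = make_head().train()
gn_head = bn_to_gn(make_head()).train()
x = torch.randn(64, 64, 16, 16)

# Output for sample 0 computed inside a batch of 64 vs. in a batch of 1
with torch.no_grad():
    bn_gap = (bn_head(x)[0] - bn_head(x[:1])[0]).abs().max().item()
    gn_gap = (gn_head(x)[0] - gn_head(x[:1])[0]).abs().max().item()
print(f"max output change, BN: {bn_gap:.3e}  GN: {gn_gap:.3e}")
```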

Section 5 — AdaIN Style Transfer

Implement adain(content, style) (from the quiz code-completion). Extract VGG-19 features at relu3_1 and relu4_1 for a content/style pair. Apply AdaIN and verify:

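out = adain(content_feat, style_feat)  # features: (N, C, H, W)
# Per-channel mean of the output matches the style features
assert torch.allclose(out.mean(dim=[2, 3]), style_feat.mean(dim=[2, 3]), atol=1e-4)
# Per-channel std matches only if adain uses the same (unbiased) estimator as torch.std;
# content channels with (near-)zero variance, e.g. dead ReLU channels, cannot be restyled
assert torch.allclose(out.std(dim=[2, 3]), style_feat.std(dim=[2, 3]), atol=1e-4)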
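A minimal adain consistent with the checks above; the quiz's version may differ. Assumptions: (N, C, H, W) features, the unbiased std throughout to match torch.std, and eps added only to the content variance so the style std is used as is. Random tensors stand in for real VGG features here.

```python
import torch


def adain(content, style, eps=1e-5):
    """AdaIN sketch for (N, C, H, W) feature maps."""
    c_mean = content.mean(dim=[2, 3], keepdim=True)
    c_std = torch.sqrt(content.var(dim=[2, 3], keepdim=True) + eps)  # eps guards flat channels
    s_mean = style.mean(dim=[2, 3], keepdim=True)
    s_std = style.std(dim=[2, 3], keepdim=True)  # unbiased, same as the check
    # Remove the content's channel statistics, then impose the style's
    return (content - c_mean) / c_std * s_std + s_mean


# Usage with random stand-ins for relu3_1 / relu4_1 features
torch.manual_seed(0)
content_feat = torch.relu(torch.randn(2, 8, 16, 16))
style_feat = torch.relu(torch.randn(2, 8, 16, 16) * 3 + 1)
out = adain(content_feat, style_feat)
```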

Reconstruct the stylized image by passing the AdaIN output through a decoder trained to invert the VGG-19 encoder (VGG itself has no decoder). Visualise the content image, style image, and result side by side.

Section 6 — adaLN-Zero Conditioning (DiT)

Implement a single DiT block with adaLN-Zero conditioning:

import torch
import torch.nn as nn

class DiTBlock(nn.Module):
    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # adaLN-Zero: the projection emits 6 vectors (shift, scale, gate for attention
        # and for the MLP); zero-initializing it makes all six start at zero
        self.adaLN_proj = nn.Linear(cond_dim, 6 * dim)
        nn.init.zeros_(self.adaLN_proj.weight)
        nn.init.zeros_(self.adaLN_proj.bias)

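Verify that at initialisation, block(x, c) returns x unchanged for any x and c (the identity property): both gate vectors are zero, so each residual branch adds nothing. Then train the block for 100 steps and observe that the adaLN_proj weights, and with them the shift, scale, and gate vectors, move away from zero as the block learns to use the conditioning.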
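One possible forward pass for the block, written as a self-contained sketch that repeats the constructor. It follows the adaLN-Zero layout described in the DiT paper (modulate with x · (1 + scale) + shift, then gate each residual branch); the reference DiT code also applies SiLU to c before the projection, which this sketch omits to match the skeleton above. Sizes, the synthetic target, and the optimizer settings are assumptions.

```python
import torch
import torch.nn as nn


class DiTBlock(nn.Module):
    def __init__(self, dim, cond_dim, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(dim, num_heads=num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.adaLN_proj = nn.Linear(cond_dim, 6 * dim)
        nn.init.zeros_(self.adaLN_proj.weight)  # adaLN-Zero: all six vectors start at 0
        nn.init.zeros_(self.adaLN_proj.bias)

    def forward(self, x, c):
        # x: (B, T, dim) tokens; c: (B, cond_dim) conditioning vector
        mods = self.adaLN_proj(c).unsqueeze(1).chunk(6, dim=-1)  # six (B, 1, dim) tensors
        shift1, scale1, gate1, shift2, scale2, gate2 = mods
        h = self.norm1(x) * (1 + scale1) + shift1
        x = x + gate1 * self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + scale2) + shift2
        return x + gate2 * self.ff(h)


torch.manual_seed(0)
block = DiTBlock(dim=64, cond_dim=16)
x, c = torch.randn(4, 10, 64), torch.randn(4, 16)
assert torch.equal(block(x, c), x)  # zero gates: both residual branches vanish

# Short run on a synthetic target; the projection weights should leave zero
target = x + 0.1 * torch.randn_like(x)
opt = torch.optim.Adam(block.parameters(), lr=1e-3)
for step in range(100):
    loss = nn.functional.mse_loss(block(x, c), target)
    opt.zero_grad()
    loss.backward()
    opt.step()
print("max |adaLN_proj.weight| after 100 steps:", block.adaLN_proj.weight.abs().max().item())
```

At the first step only the gate rows of adaLN_proj receive nonzero gradients, since everything else in both branches is multiplied by a zero gate; the shift and scale rows, and the attention and MLP weights, start moving once the gates are nonzero.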