Normalization Layers
- Explain the internal covariate shift problem and how batch normalization addresses it by normalizing pre-activations within each mini-batch
- Trace the BatchNorm algorithm through training and inference, explaining what changes between the two modes and why running statistics must be maintained
- Distinguish Batch Norm, Layer Norm, Group Norm, Instance Norm, and RMS Norm by their normalization dimension and state which architecture each is standard in
- Explain how batch normalization acts as an implicit regularizer — reducing the need for dropout — and state the conditions under which it breaks down
Internal Covariate Shift
As training progresses, the parameters of every layer change simultaneously. From the perspective of layer $\ell$, its input (the output of layer $\ell - 1$) keeps shifting as the earlier layers update. This is called internal covariate shift: the distribution of each layer's inputs changes throughout training.
Shifting inputs force each layer to continually re-adapt to the new distribution rather than just minimizing the loss. This slows training, requires careful initialization, and restricts the learning rate (large steps may destabilize earlier layers).
Batch normalization (Ioffe & Szegedy, 2015) addresses this by normalizing each layer's pre-activations to have zero mean and unit variance — keeping the distribution stable regardless of upstream parameter changes.
Batch Normalization
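For a mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$ of pre-activations (one scalar per example, for a single feature):
$$\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2$$
$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \epsilon}}, \qquad y_i = \gamma\,\hat{x}_i + \beta$$
- $\mu_\mathcal{B}$, $\sigma_\mathcal{B}^2$: batch mean and variance (not learned; computed from the batch)
- $\epsilon$: small constant for numerical stability (typically $10^{-5}$, PyTorch's default)
- $\gamma$, $\beta$: learned scale and shift parameters (one pair per feature)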
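The $\gamma$/$\beta$ parameters allow the network to undo the normalization if the task requires a non-zero mean or non-unit variance: setting $\gamma = \sqrt{\sigma_\mathcal{B}^2 + \epsilon}$ and $\beta = \mu_\mathcal{B}$ recovers the original $x_i$ exactly. They restore expressive power that pure normalization would remove.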
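A minimal NumPy sketch of the training-mode forward pass, assuming a (batch, features) input and $\epsilon = 10^{-5}$; `batchnorm_train` is an illustrative name, not a library function. The first call shows each feature coming out with zero mean and (almost exactly) unit variance; the second sets $\gamma = \sqrt{\sigma_\mathcal{B}^2 + \epsilon}$ and $\beta = \mu_\mathcal{B}$ to confirm that the learned parameters can reproduce the input.

```python
import numpy as np


def batchnorm_train(x, gamma, beta, eps=1e-5):
    """Training-mode BatchNorm for x of shape (m, d): m examples, d features.

    Statistics are computed per feature over the batch axis (axis 0).
    gamma and beta are the per-feature scale and shift, shape (d,).
    """
    mu = x.mean(axis=0)                    # batch mean per feature
    var = x.var(axis=0)                    # biased batch variance (divides by m)
    x_hat = (x - mu) / np.sqrt(var + eps)  # zero mean, ~unit variance per feature
    return gamma * x_hat + beta, mu, var


rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 4))  # far from zero mean / unit variance

# gamma = 1, beta = 0: plain normalized activations
y, mu, var = batchnorm_train(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0).round(6))  # approximately 0 for every feature
print(y.std(axis=0).round(3))   # approximately 1 for every feature

# gamma = sqrt(var + eps), beta = mu: the layer undoes its own normalization
y_id, _, _ = batchnorm_train(x, gamma=np.sqrt(var + 1e-5), beta=mu)
print(np.allclose(y_id, x))
```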
Training vs. Inference
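Training: compute $\mu_\mathcal{B}$ and $\sigma_\mathcal{B}^2$ from the current mini-batch, and also maintain running statistics $\mu_{\text{run}}$, $\sigma_{\text{run}}^2$: an exponential moving average of the batch statistics across iterations, e.g. $\mu_{\text{run}} \leftarrow (1 - \alpha)\,\mu_{\text{run}} + \alpha\,\mu_\mathcal{B}$ for a momentum $\alpha$ (PyTorch's default is $\alpha = 0.1$).
Inference: do not use batch statistics (the batch may be a single example, or its statistics may come from a different distribution, and each prediction would depend on which other examples happen to share its batch). Use the running statistics $\mu_{\text{run}}$, $\sigma_{\text{run}}^2$ accumulated during training instead. This is a separate mode; model.eval() in PyTorch switches to it.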
Failing to switch to eval mode (or forgetting model.train() before resuming training) is one of the most common PyTorch bugs.
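To make the two modes concrete, here is a minimal NumPy sketch of a BatchNorm layer for (batch, features) inputs. The class `BatchNorm1dSketch` and its `train()`/`eval()` methods are hypothetical, named only to echo PyTorch's convention; the running statistics use an exponential moving average with momentum 0.1. One simplification: PyTorch updates the running variance with the unbiased ($m - 1$) estimate, while this sketch uses the biased one.

```python
import numpy as np


class BatchNorm1dSketch:
    """Hypothetical minimal BatchNorm for (batch, features) inputs, with train/eval modes."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_features)          # learned scale (left at init here)
        self.beta = np.zeros(num_features)          # learned shift (left at init here)
        self.running_mean = np.zeros(num_features)  # EMA of batch means
        self.running_var = np.ones(num_features)    # EMA of batch variances
        self.momentum = momentum
        self.eps = eps
        self.training = True

    def train(self):
        self.training = True

    def eval(self):
        self.training = False

    def __call__(self, x):
        if self.training:
            # Normalize with this mini-batch's statistics and fold them into the EMA.
            mu, var = x.mean(axis=0), x.var(axis=0)
            a = self.momentum
            self.running_mean = (1 - a) * self.running_mean + a * mu
            self.running_var = (1 - a) * self.running_var + a * var
        else:
            # Normalize with the fixed running statistics: no dependence on the batch.
            mu, var = self.running_mean, self.running_var
        x_hat = (x - mu) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


rng = np.random.default_rng(0)
bn = BatchNorm1dSketch(num_features=3)
bn.train()
for _ in range(200):  # simulated training steps on data with mean 2, std 4
    bn(rng.normal(loc=2.0, scale=4.0, size=(32, 3)))

bn.eval()
single = np.array([[2.0, 6.0, -2.0]])  # a batch containing one example
out = bn(single)
print(out.round(2))  # close to [[0, 1, -1]]
```

In eval mode the single example is normalized with the accumulated statistics (mean near 2, standard deviation near 4), so the result is the same whatever else shares the batch. In training mode the same one-example batch would collapse to $\beta$, because $x - \mu_\mathcal{B}$ is exactly zero.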
Placement
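Original paper: BN goes immediately before the activation, $z = g(\text{BN}(Wu))$, i.e. Linear/Conv → BN → nonlinearity. In practice, placing BN after the activation often works about as well and is sometimes preferred. The original ResNet uses Conv → BN → ReLU; pre-activation ResNets (He et al., 2016) reorder each residual unit to BN → ReLU → Conv.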
Implicit Regularization
Batch normalization acts as a regularizer in two ways:
- Noise injection: the batch mean and variance are computed from a random subset of examples, introducing stochasticity analogous to dropout noise
- Larger learning rates: BN stabilizes the gradient scale, allowing larger learning rates, which are thought to regularize implicitly by steering the optimizer toward broad, flat regions of the loss landscape
In practice, adding BN to a network often lets you reduce or eliminate dropout entirely.
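A small NumPy illustration of the noise-injection effect, as a sketch assuming a single feature, batch size 16, and no $\gamma$/$\beta$ (the helper `normalize_batch` is illustrative). One fixed example is normalized inside five independently drawn mini-batches; because $\mu_\mathcal{B}$ and $\sigma_\mathcal{B}^2$ depend on its 15 companions, its output shifts from step to step even though neither the input nor any parameter changes.

```python
import numpy as np


def normalize_batch(x, eps=1e-5):
    """Training-mode BatchNorm without gamma/beta: per feature, over the batch axis."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


rng = np.random.default_rng(1)
target = np.array([[1.0]])  # one fixed example with a single feature

outputs = []
for _ in range(5):
    companions = rng.normal(size=(15, 1))     # a different random mini-batch each step
    batch = np.vstack([target, companions])   # the target example is row 0
    outputs.append(normalize_batch(batch)[0, 0])

print(np.round(outputs, 3))  # the same input maps to a different value each step
```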
Limitations
- Small batch sizes make the batch statistics noisy, destabilizing normalization; BN needs a reasonably large batch ($B \geq 32$, typically 64–256)
- RNNs: applying BN across the time dimension is awkward; different positions in a sequence have different statistics
- Online / streaming inference: cannot use batch statistics when processing examples one at a time
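The small-batch and one-at-a-time failure modes can be seen directly in a NumPy sketch (training-mode normalization without $\gamma$/$\beta$; `normalize_batch` is an illustrative helper). With one example per batch, $x_i - \mu_\mathcal{B}$ is exactly zero; with two, each feature is squashed to roughly $\pm 1$ whatever the actual values; and the spread of the batch mean across random mini-batches drawn from $\mathcal{N}(0, 1)$ shrinks only like $1/\sqrt{B}$.

```python
import numpy as np


def normalize_batch(x, eps=1e-5):
    """Training-mode BatchNorm without gamma/beta: per feature, over the batch axis."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


# B = 1: every input collapses to zero (i.e. to beta once the shift is applied)
one = normalize_batch(np.array([[3.7, -120.0]]))
print(one)

# B = 2: each feature becomes about -1 / +1; only the ordering survives
two = normalize_batch(np.array([[0.0, 5.0], [1.0, 500.0]]))
print(two.round(2))

# Noise in the batch mean across 1,000 random mini-batches, for several batch sizes
rng = np.random.default_rng(0)
spread = {}
for B in (2, 8, 32, 128):
    means = rng.normal(size=(1000, B)).mean(axis=1)
    spread[B] = means.std()
    print(B, round(float(spread[B]), 3))  # falls roughly like 1 / sqrt(B)
```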
Layer Normalization
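Layer normalization (Ba et al., 2016) normalizes across the feature dimension of a single example, independent of the batch. For one example with $H$ features $x_1, \dots, x_H$:
$$\mu = \frac{1}{H}\sum_{j=1}^{H} x_j, \qquad \sigma^2 = \frac{1}{H}\sum_{j=1}^{H} (x_j - \mu)^2, \qquad y_j = \gamma_j \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_j$$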
- Train and inference are identical (no batch dependency, no running statistics needed)
- Works with batch size 1
- Standard in transformers (both encoder and decoder blocks) and RNNs/LSTMs
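A minimal NumPy sketch of LayerNorm over the last axis ($\epsilon = 10^{-5}$ is an assumption; `layernorm` is an illustrative name). Normalizing the first example on its own gives exactly the same result as normalizing it inside a batch of four, which is why no running statistics or train/eval switch are needed.

```python
import numpy as np


def layernorm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last (feature) axis: statistics are per example, never per batch."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta


rng = np.random.default_rng(0)
d = 8
gamma, beta = np.ones(d), np.zeros(d)
batch = rng.normal(loc=3.0, scale=2.0, size=(4, d))

full = layernorm(batch, gamma, beta)        # four examples at once
alone = layernorm(batch[:1], gamma, beta)   # first example with batch size 1
print(np.allclose(full[:1], alone))         # identical: no batch dependency
print(full.mean(axis=-1).round(6))          # approximately 0 for each example (row)
```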
Group Normalization
Group normalization (Wu & He, 2018) divides the $C$ channels into $G$ groups of $C/G$ channels and, separately for each example, normalizes within each group (over the group's channels and all spatial positions):
- $G = C$ (one channel per group): Instance Normalization, which normalizes each channel independently; standard in style transfer
- $G = 1$ (all channels in one group): Layer Normalization
- $1 < G < C$ (commonly $G = 32$): Group Normalization, stable across batch sizes and a good fit for detection/segmentation, where per-device batch sizes are often small
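A NumPy sketch of GroupNorm for an (N, C, H, W) array, omitting $\gamma$/$\beta$ for brevity; all function names here are illustrative. It checks the two special cases in the list: with $G = C$ it matches a per-channel instance norm, and with $G = 1$ it matches LayerNorm applied to an image tensor, i.e. taken over all of (C, H, W).

```python
import numpy as np


def groupnorm(x, num_groups, eps=1e-5):
    """GroupNorm for an (N, C, H, W) array, without the per-channel gamma/beta.

    Channels are split into num_groups contiguous groups; each group is normalized
    over its C // num_groups channels and all H*W positions, separately per example.
    """
    n, c, h, w = x.shape
    assert c % num_groups == 0, "C must be divisible by the number of groups"
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mu = g.mean(axis=(2, 3, 4), keepdims=True)
    var = g.var(axis=(2, 3, 4), keepdims=True)
    return ((g - mu) / np.sqrt(var + eps)).reshape(n, c, h, w)


def instance_norm(x, eps=1e-5):
    # Per example, per channel, over spatial positions
    mu = x.mean(axis=(2, 3), keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=(2, 3), keepdims=True) + eps)


def layer_norm_chw(x, eps=1e-5):
    # Per example, over all channels and spatial positions
    mu = x.mean(axis=(1, 2, 3), keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=(1, 2, 3), keepdims=True) + eps)


rng = np.random.default_rng(0)
x = rng.normal(loc=1.0, scale=3.0, size=(2, 8, 4, 4))  # N=2, C=8, H=W=4

print(np.allclose(groupnorm(x, num_groups=8), instance_norm(x)))   # G = C
print(np.allclose(groupnorm(x, num_groups=1), layer_norm_chw(x)))  # G = 1
```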
RMS Normalization
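RMS Norm (Zhang & Sennrich, 2019) simplifies layer norm by removing the mean-centering step:
$$\text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{H}\sum_{j=1}^{H} x_j^2 + \epsilon}, \qquad y_j = \gamma_j \frac{x_j}{\text{RMS}(\mathbf{x})}$$
There is no mean subtraction and no shift parameter, so it is cheaper to compute than LayerNorm (the paper reports 7–64% lower running time, depending on the model). It is used in LLaMA, Mistral, and Gemma, and is the most common normalization in modern decoder-only LLMs.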
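A NumPy sketch of RMSNorm over the last axis. Placing $\epsilon$ inside the square root and defaulting it to $10^{-6}$ are assumptions, since implementations differ on both. The output has unit RMS but keeps a non-zero mean; on input that is already zero-mean, the RMS equals the standard deviation, so RMSNorm coincides with LayerNorm (with $\gamma = 1$, $\beta = 0$).

```python
import numpy as np


def rmsnorm(x, gamma, eps=1e-6):
    """RMSNorm over the last axis: divide by the root mean square; no centering, no beta."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return gamma * x / rms


rng = np.random.default_rng(0)
d = 16
x = rng.normal(loc=4.0, size=(3, d))
y = rmsnorm(x, gamma=np.ones(d))
print(np.sqrt((y ** 2).mean(axis=-1)).round(4))  # approximately 1 per row: unit RMS
print(y.mean(axis=-1).round(2))                  # clearly non-zero: the mean is not removed

# On zero-mean input, RMS equals the standard deviation, so RMSNorm matches LayerNorm
xc = x - x.mean(axis=-1, keepdims=True)
ln = xc / np.sqrt(xc.var(axis=-1, keepdims=True) + 1e-6)
print(np.allclose(rmsnorm(xc, gamma=np.ones(d)), ln))
```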
Which Normalization Where
| Norm | Normalizes over | Needs large batch? | Standard in |
|---|---|---|---|
| BatchNorm | Batch (and spatial) dims, per feature/channel | Yes (≥ 32) | CNNs (ResNet, EfficientNet) |
| LayerNorm | Feature dimension, per example | No | Transformers (encoder/decoder) |
| InstanceNorm | Spatial, per channel per example | No | Style transfer |
| GroupNorm | Groups of channels, per example | No | Detection/segmentation |
| RMSNorm | Feature dimension (no centering) | No | Modern LLMs (LLaMA, Mistral) |
PyTorch and TensorFlow
This reading is an overview. Full code is in the Normalization in Deep Learning supplement. Quick reference:
PyTorch:
```python
import torch
import torch.nn as nn

# BatchNorm2d for CNNs: requires model.eval() at inference
bn = nn.BatchNorm2d(64)  # 64 channels

# LayerNorm for transformers: no train/eval switch needed
ln = nn.LayerNorm(512)  # normalize over the last dim (size 512)

# GroupNorm for detection / small batches
gn = nn.GroupNorm(num_groups=32, num_channels=256)

# InstanceNorm for style transfer (affine=True adds learnable gamma/beta)
in_ = nn.InstanceNorm2d(64, affine=True)

# RMSNorm: available in PyTorch 2.4+
rms = nn.RMSNorm(512)  # no mean centering; cheaper than LayerNorm

# CRITICAL: always switch BatchNorm mode
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.ReLU())
model.train()  # batch statistics during training
model.eval()   # running statistics at inference; forgetting this is a common bug
```
TensorFlow / Keras:
```python
import tensorflow as tf

batch_norm = tf.keras.layers.BatchNormalization()  # pass training=True/False when calling
layer_norm = tf.keras.layers.LayerNormalization(axis=-1)
group_norm = tf.keras.layers.GroupNormalization(groups=32)

# The training flag controls BatchNorm mode:
# model(x, training=True)   -> batch statistics
# model(x, training=False)  -> moving-average statistics
```