Supplement · Regularization

Normalization Layers

By the end of this reading you will be able to:
  • Explain the internal covariate shift problem and how batch normalization addresses it by normalizing pre-activations within each mini-batch
  • Trace the BatchNorm algorithm through training and inference, explaining what changes between the two modes and why running statistics must be maintained
  • Distinguish Batch Norm, Layer Norm, Group Norm, Instance Norm, and RMS Norm by their normalization dimension and state which architecture each is standard in
  • Explain how batch normalization acts as an implicit regularizer — reducing the need for dropout — and state the conditions under which it breaks down

Internal Covariate Shift

As training progresses, the parameters of every layer change simultaneously. From the perspective of layer $\ell$, its input (the output of layer $\ell-1$) keeps shifting as the earlier layers update. This is called internal covariate shift: the distribution of each layer's inputs changes throughout training.

Shifting inputs force each layer to continually re-adapt to the new distribution rather than just minimizing the loss. This slows training, requires careful initialization, and restricts the learning rate (large steps may destabilize earlier layers).

Batch normalization (Ioffe & Szegedy, 2015) addresses this by normalizing each layer's pre-activations to have zero mean and unit variance — keeping the distribution stable regardless of upstream parameter changes.


Batch Normalization

For a mini-batch of pre-activations $\{z_1, \ldots, z_B\}$ (one scalar per example, for a single feature):

$$\mu_\mathcal{B} = \frac{1}{B}\sum_{i=1}^B z_i \qquad \sigma^2_\mathcal{B} = \frac{1}{B}\sum_{i=1}^B (z_i - \mu_\mathcal{B})^2$$

$$\hat{z}_i = \frac{z_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}} \qquad \tilde{z}_i = \gamma\hat{z}_i + \beta$$

  • $\mu_\mathcal{B}$, $\sigma^2_\mathcal{B}$: batch mean and variance (not learned; computed from the batch)
  • $\epsilon$: small constant for numerical stability (typically $10^{-5}$)
  • $\gamma$, $\beta$: learned scale and shift parameters (one per feature)

The $\gamma$/$\beta$ parameters allow the network to undo the normalization if the task requires a non-zero mean or non-unit variance. They restore the expressive power that pure normalization would remove.
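The equations above can be traced by hand in a few lines of plain Python. This is a minimal sketch for a single feature, not a library implementation; the function name `batch_norm_forward` and the example values are illustrative assumptions.

```python
import math

def batch_norm_forward(z, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize one feature across a mini-batch using batch statistics."""
    B = len(z)
    mu = sum(z) / B                                  # batch mean
    var = sum((zi - mu) ** 2 for zi in z) / B        # biased variance (divide by B)
    z_hat = [(zi - mu) / math.sqrt(var + eps) for zi in z]
    z_tilde = [gamma * zh + beta for zh in z_hat]    # learned scale and shift
    return z_tilde, mu, var

z = [2.0, 4.0, 6.0, 8.0]
out, mu, var = batch_norm_forward(z)
print(mu, var)   # 5.0 5.0

# Choosing gamma = sqrt(var + eps) and beta = mu undoes the normalization.
restored, _, _ = batch_norm_forward(z, gamma=math.sqrt(var + 1e-5), beta=mu)
```

The last call shows the point about expressive power: with suitable scale and shift values the layer reproduces its input, so normalization does not prevent the network from representing the identity.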

Training vs. Inference

Training: compute $\mu_\mathcal{B}$ and $\sigma^2_\mathcal{B}$ from the current mini-batch; also maintain running statistics (an exponential moving average of the batch statistics across iterations).

Inference: do not use batch statistics (the batch may be a single example, or may come from a different distribution). Use the running statistics $\mu_{\text{run}}$, $\sigma^2_{\text{run}}$ accumulated during training. This is a separate mode; in PyTorch, model.eval() switches to it.

Failing to switch to eval mode (or forgetting model.train() before resuming training) is one of the most common PyTorch bugs.
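The two modes can be sketched with a small single-feature class. The class name `RunningBatchNorm`, the `training` flag, and the momentum value of 0.1 are assumptions made for illustration; library implementations differ in details such as which variance estimator feeds the running average.

```python
import math

class RunningBatchNorm:
    """Single-feature BatchNorm with exponential-moving-average statistics."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum   # weight given to the newest batch (assumed value)
        self.eps = eps
        self.run_mean = 0.0
        self.run_var = 1.0
        self.training = True

    def __call__(self, z):
        if self.training:
            B = len(z)
            mu = sum(z) / B
            var = sum((zi - mu) ** 2 for zi in z) / B
            # Update the running estimates; they are only read in inference mode.
            m = self.momentum
            self.run_mean = (1 - m) * self.run_mean + m * mu
            self.run_var = (1 - m) * self.run_var + m * var
        else:
            mu, var = self.run_mean, self.run_var
        return [(zi - mu) / math.sqrt(var + self.eps) for zi in z]

bn = RunningBatchNorm()
for _ in range(200):               # repeated batches with mean 5 and variance 5
    bn([2.0, 4.0, 6.0, 8.0])

bn.training = False                # analogous to calling model.eval()
single = bn([5.0])                 # a batch of one is fine at inference
```

In training mode the same single-example batch would have zero variance and always normalize to zero, which is why inference relies on the stored statistics instead.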

Placement

Original paper: before the activation ($W\mathbf{x} \to \text{BN} \to \phi$). In practice, placing BN after the activation often works comparably and is sometimes preferred. The original ResNet follows the Conv → BN → ReLU order, while the pre-activation ResNet variant moves normalization ahead of the convolution (BN → ReLU → Conv).

Implicit Regularization

Batch normalization acts as a regularizer in two ways:

  1. Noise injection: the batch mean and variance are computed from a random subset of examples, introducing stochasticity analogous to dropout noise
  2. Larger learning rates: BN stabilizes the gradient scale, allowing larger learning rates, which are often associated with convergence to broader, flatter regions of the loss landscape that tend to generalize better

In practice, adding BN to a network often lets you reduce or eliminate dropout entirely.
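The noise-injection effect is easy to see in a small sketch: the same example is normalized to a different value depending on which other examples share its mini-batch. The helper `normalize_in_batch` and the batch values are illustrative assumptions.

```python
import math

def normalize_in_batch(batch, eps=1e-5):
    """Normalize every value using the statistics of its own mini-batch."""
    mu = sum(batch) / len(batch)
    var = sum((z - mu) ** 2 for z in batch) / len(batch)
    return [(z - mu) / math.sqrt(var + eps) for z in batch]

# The example z = 3.0 appears first in both batches.
a = normalize_in_batch([3.0, 1.0, 2.0, 4.0])[0]   # above this batch's mean (2.5)
b = normalize_in_batch([3.0, 7.0, 8.0, 9.0])[0]   # below this batch's mean (6.75)
print(a > 0, b < 0)   # True True
```

Because mini-batches are resampled every step, each example sees a slightly different normalization each time it is visited, which is the source of the stochasticity.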

Limitations

  • Small batch sizes make the batch statistics noisy, destabilizing normalization. As a rule of thumb, BN works best with $B \geq 32$ (typically 64–256)
  • RNNs: applying BN across the time dimension is awkward; different positions in a sequence have different statistics
  • Online / streaming inference: cannot use batch statistics when processing examples one at a time
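The small-batch limitation can be illustrated by measuring how much the batch mean fluctuates from one random batch to the next. This is a rough simulation under assumed settings (a standard normal population, 500 trials, a fixed seed); the helper name `spread_of_batch_means` is hypothetical.

```python
import random
import statistics

random.seed(0)
population = [random.gauss(0.0, 1.0) for _ in range(10_000)]

def spread_of_batch_means(batch_size, trials=500):
    """Standard deviation of the batch mean across randomly drawn batches."""
    means = [
        statistics.fmean(random.sample(population, batch_size))
        for _ in range(trials)
    ]
    return statistics.pstdev(means)

small = spread_of_batch_means(2)    # very noisy estimate of the true mean
large = spread_of_batch_means(64)   # much tighter estimate
print(small > large)
```

The spread shrinks roughly with the square root of the batch size, so the statistics used for normalization become considerably less reliable as batches get small.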

Layer Normalization

Layer normalization (Ba et al., 2016) normalizes across the feature dimension of a single example, independent of the batch:

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \quad \mu = \frac{1}{d}\sum_{j=1}^d x_j, \quad \sigma^2 = \frac{1}{d}\sum_{j=1}^d (x_j - \mu)^2$$

  • Train and inference are identical (no batch dependency, no running statistics needed)
  • Works with batch size 1
  • Standard in transformers (both encoder and decoder blocks) and RNNs/LSTMs
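A minimal sketch of the formula above shows the batch independence directly: each example is normalized using only its own features, so the result is the same whether it is processed alone or alongside others. The function name `layer_norm` is an assumption, and the learned scale and shift are omitted for brevity.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one example across its d features."""
    d = len(x)
    mu = sum(x) / d
    var = sum((xj - mu) ** 2 for xj in x) / d
    return [(xi - mu) / math.sqrt(var + eps) for xi in x]

batch = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
together = [layer_norm(x) for x in batch]   # each row uses only its own statistics
alone = layer_norm(batch[0])                # same example, batch size 1
print(together[0] == alone)                 # True
```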

Group Normalization

Group normalization (Wu & He, 2018) divides the $C$ channels into $G$ groups and normalizes within each group:

  • $G = C$ (one channel per group): Instance Normalization, which normalizes each channel independently; standard in style transfer
  • $G = 1$ (all channels in one group): Layer Normalization over the channels and spatial positions of each example
  • $G = 32$ or $16$: Group Normalization, which is stable across batch sizes and well suited to detection/segmentation, where batch sizes are necessarily small
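The relationship between the three variants can be sketched with one function whose only knob is the number of groups. With one channel per group it behaves like instance normalization, and with all channels in a single group it behaves like layer normalization over the whole example. The function name `group_norm` and the input layout (a list of channels, each a list of spatial values, for one example) are assumptions; learned scale and shift are omitted.

```python
import math

def group_norm(x, num_groups, eps=1e-5):
    """x: list of C channels, each a list of spatial values, for ONE example."""
    C = len(x)
    assert C % num_groups == 0, "C must be divisible by the number of groups"
    size = C // num_groups
    out = []
    for g in range(num_groups):
        group = x[g * size:(g + 1) * size]
        values = [v for channel in group for v in channel]
        mu = sum(values) / len(values)
        var = sum((v - mu) ** 2 for v in values) / len(values)
        # Every channel in the group shares the same mean and variance.
        out.extend(
            [[(v - mu) / math.sqrt(var + eps) for v in channel] for channel in group]
        )
    return out

x = [[1.0, 2.0], [3.0, 5.0], [10.0, 20.0], [30.0, 50.0]]   # C = 4, two positions each
instance_like = group_norm(x, num_groups=4)   # one channel per group
layer_like = group_norm(x, num_groups=1)      # all channels in one group
grouped = group_norm(x, num_groups=2)         # two channels per group
```

No batch dimension appears anywhere in the function, which is why the result does not depend on batch size.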

RMS Normalization

RMS Norm (Zhang & Sennrich, 2019) simplifies layer norm by removing the mean-centering step:

$$\hat{x}_i = \frac{x_i}{\text{RMS}(\mathbf{x})} \cdot \gamma_i, \quad \text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{d}\sum_{j=1}^d x_j^2 + \epsilon}$$

No mean subtraction, no $\beta$ shift parameter. Roughly 15–20% faster than LayerNorm, although the actual speedup depends on the model and hardware. Used in LLaMA, Mistral, and Gemma; it is the standard normalization for modern decoder-only LLMs.
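A minimal sketch of the formula above, with the epsilon placed inside the square root as written. The function name `rms_norm` and the default of an all-ones $\gamma$ are assumptions for illustration.

```python
import math

def rms_norm(x, gamma=None, eps=1e-5):
    """Rescale x to unit root-mean-square; the mean is NOT subtracted."""
    d = len(x)
    gamma = gamma if gamma is not None else [1.0] * d
    rms = math.sqrt(sum(xj * xj for xj in x) / d + eps)
    return [xi / rms * gi for xi, gi in zip(x, gamma)]

x = [3.0, 4.0, 3.0, 4.0]
y = rms_norm(x)
mean_square = sum(v * v for v in y) / len(y)   # close to 1
mean = sum(y) / len(y)                         # stays far from 0
```

Unlike layer normalization, the output keeps a non-zero mean: only the scale of the vector is controlled.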


Which Normalization Where

| Norm | Normalizes over | Needs large batch? | Standard in |
|---|---|---|---|
| BatchNorm | Batch dimension, per feature | Yes ($\geq 32$) | CNNs (ResNet, EfficientNet) |
| LayerNorm | Feature dimension, per example | No | Transformers (encoder/decoder) |
| InstanceNorm | Spatial, per channel per example | No | Style transfer |
| GroupNorm | Groups of channels, per example | No | Detection/segmentation |
| RMSNorm | Feature dimension (no centering) | No | Modern LLMs (LLaMA, Mistral) |

PyTorch and TensorFlow

This reading is an overview. Full code is in the Normalization in Deep Learning supplement. Quick reference:

PyTorch:

import torch
import torch.nn as nn

# BatchNorm2d for CNNs — requires model.eval() at inference
bn = nn.BatchNorm2d(64)             # 64 channels

# LayerNorm for transformers — no train/eval switch needed
ln = nn.LayerNorm(512)              # normalize last dim

# GroupNorm for detection / small batches
gn = nn.GroupNorm(num_groups=32, num_channels=256)

# InstanceNorm for style transfer
in_ = nn.InstanceNorm2d(64, affine=True)

# RMSNorm — PyTorch 2.4+
rms = nn.RMSNorm(512)               # no mean centering; ~15% faster than LayerNorm

# CRITICAL: always switch BatchNorm mode
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.ReLU())
model.train()   # batch statistics during training
model.eval()    # running statistics at inference — forgetting this is a common bug

TensorFlow / Keras:

import tensorflow as tf

batch_norm  = tf.keras.layers.BatchNormalization()          # pass training=True/False
layer_norm  = tf.keras.layers.LayerNormalization(axis=-1)
group_norm  = tf.keras.layers.GroupNormalization(groups=32)
# training flag controls BatchNorm mode:
# model(x, training=True)  — batch stats
# model(x, training=False) — moving average stats