Supplement · Normalization in Deep Learning

Batch Normalization — Algorithm, Placement, and Multi-GPU

By the end of this reading you will be able to:
  • Trace the full BatchNorm forward pass — computing batch mean and variance, normalizing, then applying learned scale and shift — and identify what changes between training and inference
  • Explain why BatchNorm maintains running statistics during training and uses them at inference, and state the consequence of failing to call model.eval() before inference
  • Compare pre-activation BatchNorm (original) with post-activation placement, state the empirical finding from pre-activation ResNets, and explain what the learned γ and β parameters recover
  • Explain Synchronized BatchNorm — why per-device BatchNorm fails for small per-GPU batches, and how SyncBN gathers statistics across devices to restore accurate normalization

The BatchNorm Algorithm

For a mini-batch $\mathcal{B} = \{z_1, \ldots, z_B\}$ of scalar pre-activations for one feature (the same operation is applied independently to each feature across the batch):

Step 1. Compute batch statistics:
$$\mu_\mathcal{B} = \frac{1}{B}\sum_{i=1}^{B} z_i \qquad \sigma^2_\mathcal{B} = \frac{1}{B}\sum_{i=1}^{B} (z_i - \mu_\mathcal{B})^2$$

Step 2. Normalize:
$$\hat{z}_i = \frac{z_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}}$$

Step 3. Scale and shift:
$$y_i = \gamma\hat{z}_i + \beta$$

$\epsilon$ is a numerical stability constant (typically $10^{-5}$, PyTorch's default; Keras defaults to $10^{-3}$). $\gamma$ and $\beta$ are learned parameters, one scalar each per feature, updated by the optimizer like any other weight.
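The three steps can be sketched in plain Python for a single feature. The helper name `batchnorm_train_forward` is illustrative (it is not a library function), and plain lists stand in for tensors so the example has no dependencies.

```python
import math

def batchnorm_train_forward(z, gamma, beta, eps=1e-5):
    """BatchNorm forward pass for ONE feature in training mode.

    z: list of scalar pre-activations z_1..z_B for a single feature.
    Returns the outputs y_i plus the batch mean and (biased) variance.
    """
    B = len(z)
    # Step 1: batch mean and biased variance (divide by B, not B - 1)
    mu = sum(z) / B
    var = sum((zi - mu) ** 2 for zi in z) / B
    # Step 2: normalize, with epsilon inside the square root
    inv_std = 1.0 / math.sqrt(var + eps)
    z_hat = [(zi - mu) * inv_std for zi in z]
    # Step 3: learned scale and shift
    y = [gamma * zh + beta for zh in z_hat]
    return y, mu, var

z = [2.0, 4.0, 6.0, 8.0]
y, mu, var = batchnorm_train_forward(z, gamma=1.0, beta=0.0)
print(mu, var)                      # 5.0 5.0
print([round(v, 4) for v in y])     # zero mean, variance just under 1
```

With `gamma=1.0, beta=0.0` the output is the plain normalized batch; any other values shift and rescale it after normalization.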

What γ and β Recover

Pure normalization forces every feature to have zero mean and (up to $\epsilon$) unit variance over the batch. On its own, this would prevent the network from representing, for example, a layer that should output consistently large activations. The $\gamma$ and $\beta$ parameters restore that capacity: the network can learn any mean and variance, but in a parameterized form (two scalars per feature) that the optimizer can control directly and stably.

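If the optimal transformation is the identity, the network can learn $\gamma = \sqrt{\sigma^2_\mathcal{B} + \epsilon}$ and $\beta = \mu_\mathcal{B}$, which exactly undoes Step 2. Because $\gamma$ and $\beta$ are shared across batches while the batch statistics fluctuate, in practice this recovers the identity only approximately. BatchNorm never forces any particular scale; it only stabilizes training by starting from a normalized baseline.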
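A quick numerical check of the identity-recovery argument, under the assumption that $\gamma$ and $\beta$ are set from the statistics of the very batch being normalized (a trained network can only approximate this, since its $\gamma$ and $\beta$ are fixed across batches). `batchnorm` is a hypothetical plain-Python helper implementing Steps 1 to 3.

```python
import math

def batchnorm(z, gamma, beta, eps=1e-5):
    # Steps 1-3 for a single feature, using the current batch's statistics
    B = len(z)
    mu = sum(z) / B
    var = sum((zi - mu) ** 2 for zi in z) / B
    return [gamma * (zi - mu) / math.sqrt(var + eps) + beta for zi in z]

z = [0.5, 3.0, -1.0, 7.5]
B = len(z)
mu = sum(z) / B
var = sum((zi - mu) ** 2 for zi in z) / B

# gamma = sqrt(var + eps) and beta = mu undo the normalization exactly
y = batchnorm(z, gamma=math.sqrt(var + 1e-5), beta=mu)
print(all(abs(a - b) < 1e-9 for a, b in zip(y, z)))   # True
```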


Training vs. Inference

During training, the batch statistics $\mu_\mathcal{B}$ and $\sigma^2_\mathcal{B}$ are computed from the current mini-batch, so they are stochastic (different for each batch). BN also maintains running statistics via an exponential moving average:

$$\mu_{\text{run}} \leftarrow (1-m)\,\mu_{\text{run}} + m\,\mu_\mathcal{B} \qquad \sigma^2_{\text{run}} \leftarrow (1-m)\,\sigma^2_{\text{run}} + m\,\sigma^2_\mathcal{B}$$

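where $m$ is the momentum (0.1 by default in PyTorch). Confusingly, PyTorch's momentum is the weight on the new batch statistic, not on the running average; Keras uses the opposite convention, where `momentum` (default 0.99) weights the old moving average.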
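The sketch below applies the running-mean update to a hypothetical sequence of batch means, once with the PyTorch convention ($m$ weights the new value) and once with the Keras convention (`momentum` weights the old value). It assumes the running mean starts at 0, PyTorch's initial `running_mean`. It also ignores one PyTorch detail: `running_var` is updated with the unbiased batch variance (divided by $B-1$), while normalization during training uses the biased one.

```python
def update_running_pytorch(running, batch_value, m=0.1):
    # PyTorch convention: m is the weight on the NEW batch statistic
    return (1 - m) * running + m * batch_value

def update_running_keras(moving, batch_value, momentum=0.9):
    # Keras convention: momentum is the weight on the OLD moving average
    return momentum * moving + (1 - momentum) * batch_value

# Hypothetical per-batch means for one feature over 50 training steps
batch_means = [4.8, 5.1, 5.3, 4.9, 5.0] * 10

run_pt = 0.0      # running mean starts at 0 (PyTorch's initial running_mean)
run_keras = 0.0   # same starting value, for a like-for-like comparison
for bm in batch_means:
    run_pt = update_running_pytorch(run_pt, bm, m=0.1)
    run_keras = update_running_keras(run_keras, bm, momentum=0.9)

print(run_pt, run_keras)   # the two conventions agree when momentum = 1 - m
```

After 50 steps the running mean sits close to 5, the typical batch mean; the leftover effect of the zero initialization has decayed by a factor of $0.9^{50} \approx 0.005$.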

During inference, the running statistics replace the batch statistics:
$$y = \gamma \cdot \frac{z - \mu_{\text{run}}}{\sqrt{\sigma^2_{\text{run}} + \epsilon}} + \beta$$

This is deterministic — the same input always produces the same output.

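Why not use batch statistics at inference? (1) The inference batch may be a single example, giving unreliable (for one example, degenerate) statistics. (2) Inference batches need not be representative of the training distribution (for example, a batch of near-identical video frames), so their statistics can be badly skewed. (3) Each example's output would depend on which other examples share its batch, whereas production systems need deterministic, per-example outputs.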
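To make reasons (1) and (3) concrete, the following plain-Python sketch normalizes the same value `x` inside two different batches. `bn_train` and `bn_eval` are hypothetical helpers, and the running statistics (mean 2.0, variance 4.0) are made-up values standing in for ones accumulated during training.

```python
import math

EPS = 1e-5

def bn_train(z, gamma=1.0, beta=0.0):
    # Training mode: statistics come from the current batch
    mu = sum(z) / len(z)
    var = sum((zi - mu) ** 2 for zi in z) / len(z)
    return [gamma * (zi - mu) / math.sqrt(var + EPS) + beta for zi in z]

def bn_eval(z, running_mean, running_var, gamma=1.0, beta=0.0):
    # Inference mode: fixed running statistics, applied element-wise
    scale = gamma / math.sqrt(running_var + EPS)
    return [scale * (zi - running_mean) + beta for zi in z]

x = 3.0
batch_a = [x, 1.0, 2.0, 0.0]
batch_b = [x, 10.0, 12.0, 11.0]

# Training mode: the same x gets different outputs depending on its batch-mates
print(bn_train(batch_a)[0] == bn_train(batch_b)[0])      # False
# A single-example batch has zero variance, so every input maps to beta
print(bn_train([x]))                                     # [0.0]

# Inference mode: identical outputs, whatever the batch (even batch size 1)
out_a = bn_eval(batch_a, running_mean=2.0, running_var=4.0)[0]
out_b = bn_eval(batch_b, running_mean=2.0, running_var=4.0)[0]
out_1 = bn_eval([x], running_mean=2.0, running_var=4.0)[0]
print(out_a == out_b == out_1)                           # True
```

PyTorch's `BatchNorm1d` refuses a (1, C) input in training mode with a "more than 1 value per channel" error rather than returning the degenerate output above.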

The model.eval() Bug

In PyTorch, model.eval() switches BatchNorm layers to use running statistics. model.train() switches back to batch statistics. Forgetting model.eval() before inference — or forgetting model.train() before resuming training — is among the most common bugs in PyTorch code:

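  • Inference in train() mode: outputs depend on the other examples in the batch and change if the batch size changes; every forward pass also updates the running statistics with test-data statistics, corrupting them for later eval() use
  • Training in eval() mode: each layer normalizes with the frozen running statistics, which are no longer updated, so BN becomes a fixed affine transform; the batch-statistic noise (BN's regularization effect) disappears and the statistics stop tracking the changing weights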
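The first bug can be reproduced without PyTorch using a toy layer. `ToyBatchNorm` is a hypothetical single-feature class, not PyTorch's implementation; it only mimics the `train()`/`eval()` flag, the momentum convention, and the initial running statistics (mean 0, variance 1) of `nn.BatchNorm1d`. It shows the side effect that is easy to miss: a training-mode forward pass on test data silently moves the running statistics. In PyTorch this also happens inside `torch.no_grad()`, because the running statistics are buffers, not gradient-tracked values.

```python
import math

class ToyBatchNorm:
    """Hypothetical single-feature BatchNorm with a PyTorch-like train/eval flag."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.gamma, self.beta = 1.0, 0.0
        self.running_mean, self.running_var = 0.0, 1.0   # PyTorch's initial values
        self.momentum, self.eps = momentum, eps
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, z):
        if self.training:
            mu = sum(z) / len(z)
            var = sum((zi - mu) ** 2 for zi in z) / len(z)
            # Side effect: every training-mode forward pass moves the running stats
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mu
            self.running_var = (1 - m) * self.running_var + m * var
        else:
            mu, var = self.running_mean, self.running_var
        return [self.gamma * (zi - mu) / math.sqrt(var + self.eps) + self.beta for zi in z]

bn = ToyBatchNorm()
for _ in range(200):                      # "training" on inputs centred at 0
    bn([-1.0, 0.0, 1.0, 0.5, -0.5])
stats_after_training = (bn.running_mean, bn.running_var)

bn([100.0, 101.0])                        # BUG: inference without bn.eval()
print(bn.running_mean != stats_after_training[0])   # True: test data leaked in

bn.eval()
frozen = bn.running_mean
bn([100.0, 101.0])                        # correct: eval mode leaves the stats alone
print(bn.running_mean == frozen)          # True
```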

Placement: Pre-Activation vs. Post-Activation

Original paper (Ioffe & Szegedy): BN is placed before the activation:
$$x \to \text{Linear/Conv} \to \text{BN} \to \phi \to \text{next layer}$$

Pre-activation ResNets (He et al., 2016, "Identity Mappings"): within each residual branch, BN and the activation come before each convolution:
$$x \to \text{BN} \to \phi \to \text{Conv} \to \text{BN} \to \phi \to \text{Conv} \to (+x)$$

Key advantage: the residual shortcut is a clean identity path. No activation or BN sits on the skip connection (in the original ResNet, the ReLU applied after the addition lies on that path), so the gradient flows back through the shortcut without any transformation. Pre-activation ResNets train more stably at very large depths (the original pre-activation paper trains ResNet-1001).
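Following the argument in the "Identity Mappings" paper, write each pre-activation block as $x_{l+1} = x_l + \mathcal{F}(x_l)$, where $\mathcal{F}$ is the BN → φ → Conv → BN → φ → Conv branch and $\mathcal{L}$ is the loss (notation chosen here). Unrolling from block $l$ to any deeper block $L$ gives:

$$x_L = x_l + \sum_{i=l}^{L-1} \mathcal{F}(x_i)$$

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_L}\left(I + \frac{\partial}{\partial x_l}\sum_{i=l}^{L-1} \mathcal{F}(x_i)\right)$$

The identity term $I$ delivers $\partial\mathcal{L}/\partial x_L$ to every earlier block unchanged, whatever the branch Jacobians are. A ReLU after the addition, or BN on the shortcut, would break the unrolled sum and replace $I$ with a layer-dependent Jacobian.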

Post-activation (common alternative):
$$x \to \text{Linear/Conv} \to \phi \to \text{BN} \to \text{next layer}$$

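In practice, the difference between placing BN before or after the activation is often small. For new architectures, a reasonable default is pre-activation placement (BN before the activation, as in the original paper); post-activation placement is worth trying empirically if training diverges.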
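One practical consequence of the original ordering (Linear/Conv → BN → φ): any bias added by the preceding layer is removed by the mean subtraction in Step 2, and $\beta$ takes over its role. This is why such layers are often built without a bias (for example `nn.Conv2d(..., bias=False)` in PyTorch). The plain-Python check below uses a hypothetical one-input linear unit with weight 0.8. The argument does not carry over to post-activation placement, where φ sits between the bias and BN.

```python
import math

def batchnorm(z, gamma=1.0, beta=0.0, eps=1e-5):
    # Steps 1-3 for a single feature, training mode
    mu = sum(z) / len(z)
    var = sum((zi - mu) ** 2 for zi in z) / len(z)
    return [gamma * (zi - mu) / math.sqrt(var + eps) + beta for zi in z]

w = 0.8                                   # hypothetical weight of a 1-input linear unit
xs = [1.0, -2.0, 0.5, 3.0]

no_bias = batchnorm([w * x for x in xs])
with_bias = batchnorm([w * x + 7.0 for x in xs])    # same unit plus a bias of 7

# The constant bias shifts mu_B by exactly 7, so (z - mu_B) and sigma_B are unchanged
print(all(abs(a - b) < 1e-9 for a, b in zip(no_bias, with_bias)))   # True
```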


BatchNorm in Convolutional Layers

For a convolutional feature map with shape $(B, C, H, W)$ (batch, channels, height, width), BatchNorm normalizes over the $(B, H, W)$ dimensions separately for each channel:

$$\mu_c = \frac{1}{B \cdot H \cdot W}\sum_{b,i,j} x_{b,c,i,j} \qquad \sigma^2_c = \frac{1}{B \cdot H \cdot W}\sum_{b,i,j} \left(x_{b,c,i,j} - \mu_c\right)^2$$

One $\mu_c$, $\sigma^2_c$, $\gamma_c$, $\beta_c$ per channel. This is sometimes called Spatial BatchNorm: the spatial dimensions are treated as extra batch dimensions.

A BN layer over 64 channels therefore has 64 $(\gamma_c, \beta_c)$ pairs (128 learned scalars) plus 64 running means and 64 running variances, regardless of spatial resolution.
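A plain-Python sketch of Spatial BatchNorm over nested lists shaped (B, C, H, W). `batchnorm2d_train` is a hypothetical helper: it pools each channel's $B \cdot H \cdot W$ values into one mean and variance and applies that channel's $\gamma_c$ and $\beta_c$, which mirrors the training-mode math of `nn.BatchNorm2d` without its implementation details. The input is synthetic Gaussian data with a different mean and spread per channel.

```python
import math
import random

def batchnorm2d_train(x, gamma, beta, eps=1e-5):
    """Spatial BatchNorm on nested lists shaped (B, C, H, W).

    gamma, beta: lists of length C (one learned scale/shift per channel).
    Returns the output plus per-channel batch means and variances.
    """
    B, C = len(x), len(x[0])
    H, W = len(x[0][0]), len(x[0][0][0])
    n = B * H * W                        # values pooled into each channel's statistics
    means, variances = [], []
    for c in range(C):
        vals = [x[b][c][i][j] for b in range(B) for i in range(H) for j in range(W)]
        mu = sum(vals) / n
        means.append(mu)
        variances.append(sum((v - mu) ** 2 for v in vals) / n)
    out = [[[[gamma[c] * (x[b][c][i][j] - means[c]) / math.sqrt(variances[c] + eps) + beta[c]
              for j in range(W)] for i in range(H)] for c in range(C)] for b in range(B)]
    return out, means, variances

random.seed(0)
B, C, H, W = 2, 3, 4, 4
x = [[[[random.gauss(5.0 * c, 1.0 + c) for _ in range(W)] for _ in range(H)]
      for c in range(C)] for _ in range(B)]
out, means, variances = batchnorm2d_train(x, gamma=[1.0] * C, beta=[0.0] * C)
print(len(means), len(variances))   # 3 3: one statistic per channel, not per pixel
```

Changing H and W changes how many values feed each channel's statistics, not how many statistics (or $\gamma_c, \beta_c$ pairs) there are.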


Limitations

| Limitation | Cause | Fix |
|---|---|---|
| Small batch sizes | Noisy batch statistics | Use GroupNorm or LayerNorm |
| Recurrent networks | Different sequence positions have different statistics | Use LayerNorm |
| Online / single-sample inference | Batch statistics cannot be computed from one example | Use running statistics (eval mode) |
| Very small per-GPU batches | Each device computes statistics from only its own few samples | Synchronized BatchNorm (below) |

Synchronized BatchNorm (SyncBN)

In multi-GPU training, each device typically processes a sub-batch. For small total batch sizes (e.g., 32 images across 8 GPUs → 4 images per GPU), per-device batch statistics are extremely noisy.

Synchronized BatchNorm gathers statistics across all devices before normalizing:

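  1. Each GPU computes its local sample count, sum $\sum z_i$, and sum of squares $\sum z_i^2$
  2. These are all-reduced (summed) across devices to give the global count $N$ and sums, from which $\mu = \frac{1}{N}\sum z_i$ and $\sigma^2 = \frac{1}{N}\sum z_i^2 - \mu^2$
  3. Each GPU normalizes its own samples using the global statistics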

All-reduce adds a communication cost at every BN layer, but it makes the statistics identical to those computed if the entire batch were on one device. SyncBN is widely used in object detection and segmentation, where high-resolution inputs limit each GPU to roughly 1–4 images.
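The three SyncBN steps can be simulated on one machine. The sketch splits a synthetic batch of 32 values across 8 hypothetical "devices", combines per-device count, sum, and sum of squares with a stand-in for all-reduce (`all_reduce_sum` is a hypothetical helper, not a `torch.distributed` call), and checks the result against the statistics of the full batch. The variance uses $E[z^2] - \mu^2$, which can lose precision when the mean is large relative to the spread; this sketch ignores that.

```python
import math
import random

def local_stats(shard):
    # Step 1 (per device): count, sum and sum of squares of the local sub-batch
    return len(shard), sum(shard), sum(z * z for z in shard)

def all_reduce_sum(per_device):
    # Stand-in for a real all-reduce: element-wise sum over devices
    return tuple(sum(vals) for vals in zip(*per_device))

random.seed(1)
full_batch = [random.gauss(3.0, 2.0) for _ in range(32)]
shards = [full_batch[k * 4:(k + 1) * 4] for k in range(8)]   # 8 devices x 4 samples

# Step 2: combine the local sums into global statistics
n, s, sq = all_reduce_sum([local_stats(sh) for sh in shards])
mu_global = s / n
var_global = sq / n - mu_global ** 2                          # E[z^2] - mu^2

# Reference: statistics of the whole batch on a single device
mu_ref = sum(full_batch) / len(full_batch)
var_ref = sum((z - mu_ref) ** 2 for z in full_batch) / len(full_batch)
print(abs(mu_global - mu_ref) < 1e-9, abs(var_global - var_ref) < 1e-9)   # True True

# Per-device means, by contrast, scatter around the global mean
print([round(sum(sh) / len(sh), 2) for sh in shards])

# Step 3: each device normalizes its own shard with the shared global statistics
inv_std = 1.0 / math.sqrt(var_global + 1e-5)
normalized = [[(z - mu_global) * inv_std for z in sh] for sh in shards]
```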


PyTorch and TensorFlow

PyTorch (nn.BatchNorm1d / nn.BatchNorm2d / nn.BatchNorm3d):

import torch
import torch.nn as nn

# BatchNorm1d: input of shape (batch, features) or (batch, features, length)
bn1d = nn.BatchNorm1d(num_features=128)
x = torch.randn(32, 128)          # (batch, features)
out = bn1d(x)                     # normalized per feature across batch

# BatchNorm2d: convolutional maps (batch, channels, H, W)
bn2d = nn.BatchNorm2d(num_features=64)
x = torch.randn(8, 64, 32, 32)
out = bn2d(x)                     # statistics over (B, H, W) per channel

# Inspect stored state
print(bn2d.running_mean.shape)    # torch.Size([64]) — one per channel
print(bn2d.weight.shape)          # torch.Size([64]) — learned gamma
print(bn2d.bias.shape)            # torch.Size([64]) — learned beta

# CRITICAL: mode switching
model = nn.Sequential(nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU())

model.train()                     # batch stats, updates running stats
out_train = model(torch.randn(16, 128))

model.eval()                      # running stats, deterministic
out_eval  = model(torch.randn(1, 128))   # works at batch size 1

# Synchronized BatchNorm for multi-GPU (convert before wrapping with DDP)
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
# then: model = nn.parallel.DistributedDataParallel(sync_model, device_ids=[local_rank])

TensorFlow / Keras:

import tensorflow as tf

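# Keras momentum weights the OLD moving average (defaults: momentum=0.99, epsilon=1e-3);
# momentum=0.9 and epsilon=1e-5 mirror PyTorch's defaults (momentum=0.1, eps=1e-5)
bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)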
x  = tf.random.normal((32, 64))

out_train = bn(x, training=True)   # batch stats; updates moving_mean / moving_variance
out_infer = bn(x, training=False)  # uses stored moving averages

print(bn.moving_mean.shape)        # (64,)
print(bn.gamma.shape)              # (64,) — learned scale
print(bn.beta.shape)               # (64,) — learned shift

# training flag is propagated automatically when calling model(x, training=...)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
])
model(x, training=True)            # train mode
model(x, training=False)           # inference mode