Batch Normalization — Algorithm, Placement, and Multi-GPU
- Trace the full BatchNorm forward pass — computing batch mean and variance, normalizing, then applying learned scale and shift — and identify what changes between training and inference
- Explain why BatchNorm maintains running statistics during training and uses them at inference, and state the consequence of failing to call model.eval() before inference
- Compare pre-activation BatchNorm (original) with post-activation placement, state the empirical finding from pre-activation ResNets, and explain what the learned γ and β parameters recover
- Explain Synchronized BatchNorm — why per-device BatchNorm fails for small per-GPU batches, and how SyncBN gathers statistics across devices to restore accurate normalization
The BatchNorm Algorithm
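For a mini-batch $\{x_1, \dots, x_m\}$ of scalar pre-activations for one feature (the same operation is applied independently per feature across the batch):
Step 1. Compute batch statistics:
$$\mu_B = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad \sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2$$
Step 2. Normalize:
$$\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
Step 3. Scale and shift:
$$y_i = \gamma\,\hat{x}_i + \beta$$
$\epsilon$ is a numerical stability constant (typically $10^{-5}$). $\gamma$ and $\beta$ are learned parameters, one scalar each per feature, updated by the optimizer like any other weight.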
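The three steps above can be sketched in plain Python for a single feature. This is a minimal illustration, not a library implementation: the function name `batchnorm_forward` and its default values are assumptions made for the example, and it covers only the training-mode computation.

```python
import math

def batchnorm_forward(xs, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode BatchNorm for one feature over a mini-batch (list of floats)."""
    m = len(xs)
    mean = sum(xs) / m                              # Step 1: batch mean
    var = sum((x - mean) ** 2 for x in xs) / m      # Step 1: biased batch variance
    inv_std = 1.0 / math.sqrt(var + eps)
    x_hat = [(x - mean) * inv_std for x in xs]      # Step 2: normalize
    ys = [gamma * xh + beta for xh in x_hat]        # Step 3: scale and shift
    return ys, mean, var

batch = [2.0, 4.0, 6.0, 8.0]
ys, mean, var = batchnorm_forward(batch)
print(mean, var)  # 5.0 5.0
print(ys)         # roughly zero mean and unit variance
```

With the default `gamma=1.0` and `beta=0.0` the output is the normalized batch itself; other values rescale and shift it.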
What γ and β Recover
Pure normalization forces every feature to have exactly zero mean and unit variance. This would prevent the network from representing, for example, a layer that should output large activations. The $\gamma$ and $\beta$ parameters restore that capacity: the network can learn any mean and variance, but now in a parameterized way that the optimizer can control stably.
If the optimal transformation is the identity, the network can learn $\gamma = \sqrt{\sigma_B^2 + \epsilon}$ and $\beta = \mu_B$ (the batch standard deviation and batch mean), which undoes the normalization. BatchNorm never forces any particular scale; it only stabilizes training by starting from a normalized baseline.
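A small numeric check of the identity case may help. The sketch below assumes a fixed batch and sets the scale and shift by hand from that batch's statistics; in a real network they are learned constants while batch statistics vary from step to step, so the recovery is only approximate there. The helper name `normalize` is made up for the example.

```python
import math

def normalize(xs, eps=1e-5):
    """Return the normalized batch together with its mean and biased variance."""
    m = len(xs)
    mean = sum(xs) / m
    var = sum((x - mean) ** 2 for x in xs) / m
    return [(x - mean) / math.sqrt(var + eps) for x in xs], mean, var

xs = [10.0, 12.0, 20.0, 30.0]
eps = 1e-5
x_hat, mean, var = normalize(xs, eps)

# Choosing gamma = sqrt(var + eps) and beta = mean undoes the normalization.
gamma = math.sqrt(var + eps)
beta = mean
recovered = [gamma * xh + beta for xh in x_hat]

max_error = max(abs(r - x) for r, x in zip(recovered, xs))
print(max_error)  # only floating-point residue remains
```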
Training vs. Inference
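During training, the batch statistics $\mu_B$ (mean) and $\sigma_B^2$ (variance) are computed from the current mini-batch, so they are stochastic (different for each batch). BN also maintains running statistics via an exponential moving average:
$$\mu_{\text{run}} \leftarrow (1 - \alpha)\,\mu_{\text{run}} + \alpha\,\mu_B, \qquad \sigma_{\text{run}}^2 \leftarrow (1 - \alpha)\,\sigma_{\text{run}}^2 + \alpha\,\sigma_B^2$$
where $\alpha$ is the momentum (typically 0.1 in PyTorch; confusingly, this is the weight on the new batch, not on the running average).
During inference, use the running statistics:
$$\hat{x} = \frac{x - \mu_{\text{run}}}{\sqrt{\sigma_{\text{run}}^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta$$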
This is deterministic — the same input always produces the same output.
Why not use batch statistics at inference? (1) The inference batch may be a single example, giving unreliable statistics. (2) The test distribution may differ from training, giving wrong statistics. (3) Determinism is required for production systems.
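The training and inference paths can be put side by side in a toy single-feature layer. `RunningBatchNorm` is a hypothetical class written for this illustration, not a library API. It uses the biased batch variance in the running update for simplicity, whereas PyTorch updates `running_var` with the unbiased estimate.

```python
import math

class RunningBatchNorm:
    """Single-feature BatchNorm with running statistics (illustrative only)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum   # weight on the NEW batch statistic
        self.eps = eps
        self.running_mean = 0.0    # initial values match PyTorch's defaults
        self.running_var = 1.0
        self.gamma = 1.0
        self.beta = 0.0
        self.training = True

    def __call__(self, xs):
        if self.training:
            m = len(xs)
            mean = sum(xs) / m
            var = sum((x - mean) ** 2 for x in xs) / m
            a = self.momentum
            # Exponential moving average of the batch statistics.
            self.running_mean = (1 - a) * self.running_mean + a * mean
            self.running_var = (1 - a) * self.running_var + a * var
        else:
            mean, var = self.running_mean, self.running_var
        scale = self.gamma / math.sqrt(var + self.eps)
        return [scale * (x - mean) + self.beta for x in xs]

bn = RunningBatchNorm()
for batch in ([1.0, 3.0], [2.0, 6.0], [0.0, 4.0]):
    bn(batch)                  # training calls update the running statistics

bn.training = False
alone = bn([3.0])[0]                      # a batch of one works at inference
with_others = bn([3.0, 100.0, -50.0])[0]  # same input, different batch mates
print(alone == with_others)               # True
```

At inference the output for an input no longer depends on the rest of the batch, which is the determinism property described above.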
The model.eval() Bug
In PyTorch, model.eval() switches BatchNorm layers to use running statistics. model.train() switches back to batch statistics. Forgetting model.eval() before inference — or forgetting model.train() before resuming training — is among the most common bugs in PyTorch code:
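- Inference in train() mode: non-deterministic outputs that depend on what other examples are in the batch; outputs change if the batch size changes, and the running statistics keep being updated by the inference data
- Training in eval() mode: the layer normalizes with fixed running statistics rather than the batch, so gradients never flow through batch statistics, the running statistics stop updating, and the regularization effect of BatchNorm is effectively lost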
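The batch dependence of train-mode outputs is easy to reproduce without any framework. The following sketch normalizes one feature with current-batch statistics, which is what a BatchNorm layer left in train() mode does; the function name `bn_train_mode` and the sample values are invented for the illustration.

```python
import math

def bn_train_mode(xs, eps=1e-5):
    """Normalize one feature with the statistics of the current batch."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

sample = 1.0
out_a = bn_train_mode([sample, 0.0, 2.0])[0]   # neighbors centered on the sample
out_b = bn_train_mode([sample, 9.0, 11.0])[0]  # same sample, different neighbors
out_single = bn_train_mode([sample])[0]        # batch of one: mean equals the sample

print(out_a, out_b, out_single)
```

The same input yields different outputs depending on its batch mates, and a batch of one collapses to zero regardless of the input value.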
Placement: Pre-Activation vs. Post-Activation
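Original paper (Ioffe & Szegedy): BN is placed before the activation:
$$y = \phi\big(\mathrm{BN}(Wx)\big)$$
where $W$ is the layer's weight matrix (or convolution) and $\phi$ is the activation function.
Pre-activation ResNets (He et al., 2016, "Identity Mappings"): BN placed before both the activation and the convolution:
$$x_{l+1} = x_l + \mathcal{F}(x_l), \qquad \mathcal{F}:\ \mathrm{BN} \to \mathrm{ReLU} \to \mathrm{Conv} \to \mathrm{BN} \to \mathrm{ReLU} \to \mathrm{Conv}$$
Key advantage: the residual shortcut is a clean identity path, with no activation or BN on the skip connection, so the gradient flows back through the shortcut without any transformation. Pre-activation ResNets train more stably at very large depths (ResNet-1001 in the original pre-activation paper).
Post-activation (common alternative):
$$y = \mathrm{BN}\big(\phi(Wx)\big)$$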
In practice, the difference is often small. For new architectures: default to pre-activation (BN before activation); switch to post-activation if training diverges.
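To make the two single-layer orderings concrete, here is a framework-free sketch that applies BN and ReLU in both orders to the same values. The input list stands in for hypothetical linear-layer outputs, and `bn` uses a scale of 1 and shift of 0; the names are illustrative only.

```python
import math

def bn(xs, eps=1e-5):
    """Batch-normalize one feature (scale 1, shift 0)."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / len(xs)
    return [(x - mean) / math.sqrt(var + eps) for x in xs]

def relu(xs):
    return [max(0.0, x) for x in xs]

pre_acts = [-3.0, -1.0, 0.5, 2.0, 4.0]  # hypothetical outputs of a linear layer

bn_then_relu = relu(bn(pre_acts))  # original ordering: Linear -> BN -> activation
relu_then_bn = bn(relu(pre_acts))  # alternative ordering: Linear -> activation -> BN

print(min(bn_then_relu) >= 0.0)                            # True: ReLU output is non-negative
print(abs(sum(relu_then_bn) / len(relu_then_bn)) < 1e-9)   # True: next layer sees zero-mean input
```

The orderings differ in what the next layer receives: rectified values in the first case, normalized values in the second.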
BatchNorm in Convolutional Layers
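For a convolutional feature map $x$ with shape (batch, channels, height, width) $= (B, C, H, W)$, BatchNorm normalizes over the $(B, H, W)$ dimensions per channel:
$$\mu_c = \frac{1}{BHW}\sum_{b,h,w} x_{b,c,h,w}, \qquad \sigma_c^2 = \frac{1}{BHW}\sum_{b,h,w} \big(x_{b,c,h,w} - \mu_c\big)^2$$
One $\gamma$, $\beta$, running mean, and running variance per channel. This is sometimes called Spatial BatchNorm: the spatial dimensions are treated as extra batch dimensions.
A layer with 64 channels has 64 sets of BN parameters, regardless of spatial resolution.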
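The per-channel reduction can be sketched with nested Python lists indexed as `[b][c][h][w]`. This is an illustrative helper (the name `batchnorm2d` and the tiny input are assumptions for the example), written with explicit loops so the reduction axes are visible; it omits the learned scale and shift.

```python
import math

def batchnorm2d(x, eps=1e-5):
    """Normalize a [b][c][h][w] nested list per channel over (B, H, W)."""
    B, C, H, W = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    n = B * H * W
    means, variances = [], []
    for c in range(C):
        # Gather every value of channel c across batch and spatial positions.
        values = [x[b][c][h][w] for b in range(B) for h in range(H) for w in range(W)]
        mean = sum(values) / n
        var = sum((v - mean) ** 2 for v in values) / n
        means.append(mean)
        variances.append(var)
    out = [[[[(x[b][c][h][w] - means[c]) / math.sqrt(variances[c] + eps)
              for w in range(W)] for h in range(H)] for c in range(C)] for b in range(B)]
    return out, means, variances

# Batch of 2 images, 2 channels, 2x2 spatial; channel 1 holds much larger values.
x = [
    [[[0.0, 1.0], [2.0, 3.0]], [[10.0, 20.0], [30.0, 40.0]]],
    [[[4.0, 5.0], [6.0, 7.0]], [[50.0, 60.0], [70.0, 80.0]]],
]
out, means, variances = batchnorm2d(x)
print(means)  # [3.5, 45.0]
```

Each channel gets exactly one mean and one variance, however large the spatial grid is.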
Limitations
| Limitation | Cause | Fix |
|---|---|---|
| Small batch sizes | Noisy batch statistics | Use GroupNorm or LayerNorm |
| Recurrent networks | Different sequence positions have different statistics | Use LayerNorm |
| Online / single-sample inference | Can't compute batch stats | Use running stats (eval mode) |
| Very small per-GPU batches | Each device normalizes with statistics from only a few samples | Synchronized BatchNorm |
Synchronized BatchNorm (SyncBN)
In multi-GPU training, each device typically processes a sub-batch. For small total batch sizes (e.g., 32 images across 8 GPUs → 4 images per GPU), per-device batch statistics are extremely noisy.
Synchronized BatchNorm gathers statistics across all devices before normalizing:
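- Each GPU computes its local sum $\sum_i x_i$ and sum of squares $\sum_i x_i^2$, along with its sample count
- These are all-reduced across devices to compute the global $\mu = \frac{1}{N}\sum_i x_i$ and $\sigma^2 = \frac{1}{N}\sum_i x_i^2 - \mu^2$, where $N$ is the total number of samples across devices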
- Each GPU normalizes using the global statistics
All-reduce adds a communication cost but ensures BN statistics are as accurate as if the entire batch were on one device. SyncBN is standard in object detection and segmentation, where per-image processing limits the per-GPU batch to 1–4 images.
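The gathering scheme can be simulated on one machine by treating each sub-batch as a device. In this sketch `all_reduce_sum` is a stand-in for a real collective operation and the data are made up. Production implementations may exchange different quantities for numerical stability, so take this as an illustration of the idea rather than of any particular library.

```python
def local_stats(xs):
    """What each device contributes: count, sum, and sum of squares."""
    return len(xs), sum(xs), sum(x * x for x in xs)

def all_reduce_sum(per_device):
    """Stand-in for a collective all-reduce: element-wise sum over devices."""
    return tuple(sum(vals) for vals in zip(*per_device))

# Hypothetical sub-batches on 4 devices, 2 samples each.
device_batches = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

n, total, total_sq = all_reduce_sum([local_stats(b) for b in device_batches])
global_mean = total / n
global_var = total_sq / n - global_mean ** 2   # E[x^2] - (E[x])^2

local_means = [sum(b) / len(b) for b in device_batches]
print(local_means)              # [1.5, 3.5, 5.5, 7.5]
print(global_mean, global_var)  # 4.5 5.25
```

The per-device means differ widely, while the reduced statistics equal those of the full batch of eight samples.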
PyTorch and TensorFlow
PyTorch — nn.BatchNorm1d / nn.BatchNorm2d / nn.BatchNorm3d:
import torch
import torch.nn as nn
# BatchNorm1d: (batch, features) or (batch, features, length)
bn1d = nn.BatchNorm1d(num_features=128)
x = torch.randn(32, 128) # (batch, features)
out = bn1d(x) # normalized per feature across batch
# BatchNorm2d: convolutional maps (batch, channels, H, W)
bn2d = nn.BatchNorm2d(num_features=64)
x = torch.randn(8, 64, 32, 32)
out = bn2d(x) # statistics over (B, H, W) per channel
# Inspect stored state
print(bn2d.running_mean.shape) # torch.Size([64]) — one per channel
print(bn2d.weight.shape) # torch.Size([64]) — learned gamma
print(bn2d.bias.shape) # torch.Size([64]) — learned beta
# CRITICAL: mode switching
model = nn.Sequential(nn.Linear(128, 64), nn.BatchNorm1d(64), nn.ReLU())
model.train() # batch stats, updates running stats
out_train = model(torch.randn(16, 128))
model.eval() # running stats, deterministic
out_eval = model(torch.randn(1, 128)) # works at batch size 1
# Synchronized BatchNorm for multi-GPU (convert before wrapping with DDP)
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
# then: model = nn.parallel.DistributedDataParallel(sync_model, device_ids=[local_rank])
TensorFlow / Keras:
import tensorflow as tf
bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
x = tf.random.normal((32, 64))
out_train = bn(x, training=True) # batch stats; updates moving_mean / moving_variance
out_infer = bn(x, training=False) # uses stored moving averages
print(bn.moving_mean.shape) # (64,)
print(bn.gamma.shape) # (64,) — learned scale
print(bn.beta.shape) # (64,) — learned shift
# training flag is propagated automatically when calling model(x, training=...)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
])
model(x, training=True) # train mode
model(x, training=False) # inference mode