Supplement · Normalization in Deep Learning

Group Normalization, Instance Normalization, and Spatial Variants

14 min read
By the end of this reading you will be able to:
  • Explain why BatchNorm degrades with small per-GPU batch sizes and identify the exact failure mechanism — noisy batch statistics — that GroupNorm and InstanceNorm avoid
  • Describe GroupNorm's normalization axis — computing statistics over G groups of C/G channels within each example — and state the relationship to LayerNorm (G=1) and InstanceNorm (G=C)
  • Identify the domains where InstanceNorm is preferred over BatchNorm or GroupNorm — specifically style transfer and image generation — and explain why per-example per-channel statistics capture style information
  • Compare BatchNorm, LayerNorm, InstanceNorm, and GroupNorm across the axes of batch-size sensitivity, spatial sensitivity, and typical use domain, and state which to reach for first in object detection, language modeling, and style transfer

The Small-Batch Problem

BatchNorm's performance degrades sharply when the per-device batch size drops below roughly 16–32 examples. The cause is not the batch size per se: it is that the batch statistics $\mu_\mathcal{B}$ and $\sigma^2_\mathcal{B}$ become noisy estimates of the true feature distribution.

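Consider object detection. A standard configuration is 2 images per GPU × 8 GPUs. Without synchronization, each GPU's BatchNorm computes statistics over 2 examples × (H × W) spatial locations per channel. At a low-resolution stage that may be as few as 2 × 7 × 7 = 98 values per channel, and those values are far from independent: they come from only two images, whose feature distributions can differ substantially. SyncBN pools statistics across devices and raises the effective batch to 16 images, at the cost of cross-GPU communication in every normalization layer. With purely local statistics, the normalization step injects noise that can drown the signal.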
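The effect of correlated samples can be seen in a small simulation. The sketch below rests on a toy model that is an assumption for illustration, not a description of any real detector: each image contributes one offset shared by all of its 7 × 7 positions, plus independent noise per position. The helper name `batch_mean_spread` is hypothetical. It measures how much the batch mean of a single channel fluctuates from one batch to the next.

```python
import random
import statistics

def batch_mean_spread(batch_size, hw=7, trials=500, seed=0):
    """Std. dev. of one channel's batch mean across many simulated batches.

    Toy model (an assumption): every image has its own offset ~ N(0, 1)
    shared by all hw*hw positions, plus N(0, 1) noise per position.
    The true channel mean is 0.
    """
    rng = random.Random(seed)
    batch_means = []
    for _ in range(trials):
        values = []
        for _ in range(batch_size):
            offset = rng.gauss(0.0, 1.0)  # image-level shift
            values.extend(offset + rng.gauss(0.0, 1.0) for _ in range(hw * hw))
        batch_means.append(statistics.fmean(values))
    return statistics.pstdev(batch_means)

spread_small = batch_mean_spread(batch_size=2)    # 2 * 7 * 7 = 98 values
spread_large = batch_mean_spread(batch_size=32)   # 32 * 7 * 7 = 1568 values
print(f"batch size  2: spread of batch mean = {spread_small:.3f}")
print(f"batch size 32: spread of batch mean = {spread_large:.3f}")
```

Under this toy model the spread shrinks roughly with the square root of the number of images, not the number of spatial positions, so 98 values drawn from two images behave much more like 2 samples than like 98.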

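Wu & He (2018) showed that GroupNorm's accuracy stays nearly constant as the per-GPU batch size shrinks from 32 to 2 on ImageNet, while BatchNorm's error rises sharply at the small end; at batch size 32, BatchNorm remains slightly ahead. On COCO detection and segmentation, where small per-GPU batches are the norm, GroupNorm outperformed the frozen-BatchNorm baselines. Because its statistics never involve the batch dimension, GroupNorm sidesteps the failure rather than merely softening it.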


Instance Normalization

Instance normalization (Ulyanov et al., 2016) normalizes over the spatial dimensions of each example and each channel independently.

For a feature map $x \in \mathbb{R}^{B \times C \times H \times W}$, the statistics for example $b$, channel $c$ are:

$$\mu_{b,c} = \frac{1}{HW}\sum_{h,w} x_{b,c,h,w} \qquad \sigma^2_{b,c} = \frac{1}{HW}\sum_{h,w}(x_{b,c,h,w} - \mu_{b,c})^2$$

$$\hat{x}_{b,c,h,w} = \frac{x_{b,c,h,w} - \mu_{b,c}}{\sqrt{\sigma^2_{b,c} + \epsilon}} \qquad y_{b,c,h,w} = \gamma_c\hat{x}_{b,c,h,w} + \beta_c$$

Every $(b, c)$ pair is normalized independently. This produces $B \times C$ separate mean-and-variance computations, one per example per channel.

Key property: the normalization is completely independent across the batch dimension and across the channel dimension. InstanceNorm is as batch-size-agnostic as LayerNorm.
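The InstanceNorm formulas and the batch-independence property can be checked directly. The following is a minimal sketch, assuming PyTorch is available; `instance_norm_manual` is a hypothetical helper written for this reading, and it is compared against `torch.nn.functional.instance_norm` as a reference.

```python
import torch
import torch.nn.functional as F

def instance_norm_manual(x, gamma, beta, eps=1e-5):
    """InstanceNorm on a (B, C, H, W) tensor, written out from the formulas."""
    mu = x.mean(dim=(2, 3), keepdim=True)                    # one mean per (b, c)
    var = x.var(dim=(2, 3), unbiased=False, keepdim=True)    # biased variance, 1/(HW)
    x_hat = (x - mu) / torch.sqrt(var + eps)
    return gamma.view(1, -1, 1, 1) * x_hat + beta.view(1, -1, 1, 1)

torch.manual_seed(0)
x = torch.randn(4, 3, 8, 8)
gamma = torch.rand(3) + 0.5      # arbitrary per-channel scale
beta = torch.randn(3)            # arbitrary per-channel shift

y = instance_norm_manual(x, gamma, beta)
y_ref = F.instance_norm(x, weight=gamma, bias=beta, eps=1e-5)
matches_reference = torch.allclose(y, y_ref, atol=1e-5)

# Normalizing the first example alone gives the same result as inside the batch.
y_single = instance_norm_manual(x[:1], gamma, beta)
batch_independent = torch.allclose(y[:1], y_single, atol=1e-6)

print("matches F.instance_norm:", matches_reference)
print("independent of other examples:", batch_independent)
```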

Why InstanceNorm Works for Style Transfer

In neural style transfer and feed-forward style networks (Johnson et al., 2016), the "style" of an image can be characterized by the first and second moments of feature maps — the channel means and variances encode things like color palette and texture statistics.

By normalizing to zero mean and unit variance per channel per image, InstanceNorm removes the style of the content image from the intermediate features, leaving a style-agnostic representation that can then be re-styled by the learned $\gamma$ and $\beta$. This is why InstanceNorm is the default normalization in neural style transfer, image-to-image translation (pix2pix, CycleGAN), and many image generation architectures.
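A small experiment makes the "style removal" concrete. In the sketch below, a random tensor stands in for a feature map, and "restyling" is modeled as a positive per-channel gain and offset; both are assumptions for illustration only, since real style differences are richer than an affine change. After InstanceNorm the two versions should be nearly identical, up to the effect of $\epsilon$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
content = torch.randn(1, 3, 16, 16)                  # stand-in for a feature map

# "Restyle" the same content: a different positive gain and offset per channel.
gain = torch.tensor([0.5, 1.5, 2.0]).view(1, 3, 1, 1)
shift = torch.tensor([1.0, -2.0, 0.3]).view(1, 3, 1, 1)
restyled = gain * content + shift

out_a = F.instance_norm(content)
out_b = F.instance_norm(restyled)

diff_before = (content - restyled).abs().max().item()
diff_after = (out_a - out_b).abs().max().item()
print(f"max difference before InstanceNorm: {diff_before:.2e}")
print(f"max difference after InstanceNorm:  {diff_after:.2e}")
```

The same invariance explains why the layer is a poor fit when per-channel means and variances themselves carry the information a task needs.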

For classification tasks, this per-channel, per-instance normalization can destroy discriminative information: how strongly a channel responds across an image is itself a cue, and InstanceNorm erases it, so two images that differ only in their per-channel means and variances become indistinguishable. InstanceNorm is therefore rarely used in supervised classification.


Group Normalization

Group normalization (Wu & He, 2018) is a middle ground: it divides the $C$ channels into $G$ groups and normalizes over the spatial and channel dimensions within each group, per example.

For example $b$, group $g$ (containing the $C/G$ channels $\{c : \lfloor cG/C \rfloor = g\}$), and all spatial positions:

$$\mu_{b,g} = \frac{1}{(C/G) \cdot HW}\sum_{c \in g}\sum_{h,w} x_{b,c,h,w}$$

$$\sigma^2_{b,g} = \frac{1}{(C/G) \cdot HW}\sum_{c \in g}\sum_{h,w}\left(x_{b,c,h,w} - \mu_{b,g}\right)^2$$

$$\hat{x}_{b,c,h,w} = \frac{x_{b,c,h,w} - \mu_{b,g}}{\sqrt{\sigma^2_{b,g} + \epsilon}} \qquad y_{b,c,h,w} = \gamma_c\hat{x}_{b,c,h,w} + \beta_c$$

Learned parameters $\gamma_c, \beta_c$ remain per-channel (not per-group): the grouping affects only the statistics computation, not the scale/shift.
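The GroupNorm computation maps onto a reshape followed by a reduction. The sketch below is one way to write it, assuming PyTorch and contiguous channel groups (channels 0 to C/G − 1 form group 0, and so on); `group_norm_manual` is a hypothetical helper, checked against `torch.nn.functional.group_norm`.

```python
import torch
import torch.nn.functional as F

def group_norm_manual(x, num_groups, gamma, beta, eps=1e-5):
    """GroupNorm on a (B, C, H, W) tensor: statistics per (example, group)."""
    B, C, H, W = x.shape
    assert C % num_groups == 0, "C must be divisible by G"
    # Split the channel axis into (G, C/G) so each group is its own slice.
    xg = x.reshape(B, num_groups, C // num_groups, H, W)
    mu = xg.mean(dim=(2, 3, 4), keepdim=True)
    var = xg.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
    x_hat = ((xg - mu) / torch.sqrt(var + eps)).reshape(B, C, H, W)
    # Scale and shift stay per-channel, not per-group.
    return gamma.view(1, C, 1, 1) * x_hat + beta.view(1, C, 1, 1)

torch.manual_seed(0)
x = torch.randn(2, 8, 6, 6)
gamma = torch.rand(8) + 0.5
beta = torch.randn(8)

y_manual = group_norm_manual(x, 4, gamma, beta)
y_ref = F.group_norm(x, 4, weight=gamma, bias=beta, eps=1e-5)
matches_reference = torch.allclose(y_manual, y_ref, atol=1e-5)
print("matches F.group_norm:", matches_reference)
```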

GroupNorm as a Unifying Framework

GroupNorm subsumes InstanceNorm and LayerNorm as special cases:

| $G$ value | Equivalent to | Statistics computed over |
|---|---|---|
| $G = C$ | InstanceNorm | Spatial dims only, per channel |
| $G = 1$ | LayerNorm | All channels + spatial dims |
| $1 < G < C$ | GroupNorm proper | $C/G$ channels + spatial dims |

BatchNorm does not fit this framework — it requires statistics across the batch dimension, which GroupNorm never touches.
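The two special cases can be verified numerically. The sketch below assumes PyTorch and compares the functional forms without affine parameters, because the equivalence concerns the statistics: LayerNorm's learned scale and shift are per element, whereas GroupNorm's are per channel.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 8, 5, 5)          # (B, C, H, W) with C = 8
C = x.shape[1]

gn_as_in = F.group_norm(x, num_groups=C)     # G = C: one channel per group
gn_as_ln = F.group_norm(x, num_groups=1)     # G = 1: all channels in one group

inst = F.instance_norm(x)                    # statistics over (H, W)
layer = F.layer_norm(x, x.shape[1:])         # statistics over (C, H, W)

in_matches = torch.allclose(gn_as_in, inst, atol=1e-5)
ln_matches = torch.allclose(gn_as_ln, layer, atol=1e-5)
print("G = C matches InstanceNorm:", in_matches)
print("G = 1 matches LayerNorm:   ", ln_matches)
```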

Choosing $G$

Wu & He recommend $G = 32$ as a robust default for image models with large channel counts (e.g., 256 or 512 channels). For smaller channel counts, a smaller $G$ that divides $C$ is the usual choice, so that each group still holds several channels. Sensitivity to $G$ is relatively low in practice; the key is that $C/G$ should be large enough (typically ≥ 8) to give stable statistics.
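One way to turn this guidance into code is a small selection rule. The helper below is hypothetical and the rule is a heuristic assumed for this reading, not something prescribed by Wu & He: it picks the largest group count up to a preferred value that divides the channel count and keeps at least a minimum number of channels per group.

```python
def pick_num_groups(channels, preferred=32, min_channels_per_group=8):
    """Largest G <= preferred that divides `channels` and keeps
    channels // G >= min_channels_per_group. Falls back to G = 1."""
    for g in range(min(preferred, channels), 0, -1):
        if channels % g == 0 and channels // g >= min_channels_per_group:
            return g
    return 1  # too few channels to satisfy the minimum: use a single group

for c in (512, 256, 64, 48, 4):
    g = pick_num_groups(c)
    print(f"C = {c:>3}: G = {g:>2}, channels per group = {c // g}")
```

The result can be passed straight to `nn.GroupNorm(num_groups=..., num_channels=...)`, which requires the channel count to be divisible by the group count.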


Visualization: What Each Method Normalizes

For a 4D feature tensor $(B, C, H, W)$, the colored region represents the values averaged together for one normalization computation:

  • BatchNorm: one rectangle per channel = all $(B, H, W)$ values for channel $c$
  • LayerNorm: one rectangle per example = all $(C, H, W)$ values for example $b$
  • InstanceNorm: one rectangle per (example, channel) = all $(H, W)$ values for $(b, c)$
  • GroupNorm: one rectangle per (example, group) = all $(C/G, H, W)$ values for $(b, g)$

This makes the batch-size dependence clear: only BatchNorm's normalization region spans the batch dimension.
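The four normalization regions correspond to four choices of reduction axes. The sketch below, assuming PyTorch, computes only the means to show how many statistics each method produces and how many values feed each one; the tensor sizes are arbitrary.

```python
import torch

B, C, H, W, G = 4, 64, 8, 8, 32
x = torch.randn(B, C, H, W)

means = {
    "BatchNorm":    x.mean(dim=(0, 2, 3)),                              # per channel
    "LayerNorm":    x.mean(dim=(1, 2, 3)),                              # per example
    "InstanceNorm": x.mean(dim=(2, 3)),                                 # per (b, c)
    "GroupNorm":    x.reshape(B, G, C // G, H, W).mean(dim=(2, 3, 4)),  # per (b, g)
}

for name, mean in means.items():
    values_per_stat = x.numel() // mean.numel()
    print(f"{name:<13} stats shape {str(tuple(mean.shape)):<9} "
          f"values per statistic: {values_per_stat}")
```

Only the BatchNorm reduction includes axis 0, which is why it is the only method whose statistics change when the batch changes.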


Practical Comparison

| Property | BatchNorm | LayerNorm | InstanceNorm | GroupNorm |
|---|---|---|---|---|
| Batch-size sensitive | Yes (degrades < ~16) | No | No | No |
| Spatial-resolution sensitive | No (pools over H, W) | Pools over all dims | Yes (per H, W) | Yes (per H, W) |
| Train/eval difference | Yes | No | No | No |
| Works at batch size 1 | With running stats only | Yes | Yes | Yes |
| Captures style (moments) | No (across batch) | No | Yes | Partially |
| Object detection / segmentation | Weak (small batch) | Rarely | Rarely | Standard |
| Language modeling / NLP | Rarely | Standard | No | Occasionally |
| Style transfer / image synthesis | Weak | Weak | Standard | Used |
| Supervised vision (large batch) | Standard | Occasionally | No | Alternative |
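The "Train/eval difference" row can be observed directly. The sketch below assumes PyTorch with default layer settings (in particular, `nn.InstanceNorm2d` does not track running statistics by default); `train_eval_gap` is a hypothetical helper that runs the same input through a layer in both modes and reports the largest output difference.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = 3.0 * torch.randn(8, 16, 4, 4) + 2.0   # features far from zero mean / unit variance

def train_eval_gap(layer, x):
    """Max absolute difference between training-mode and eval-mode outputs."""
    with torch.no_grad():
        layer.train()
        y_train = layer(x)      # BatchNorm: batch statistics (and updates running stats)
        layer.eval()
        y_eval = layer(x)       # BatchNorm: running statistics
    return (y_train - y_eval).abs().max().item()

bn_gap = train_eval_gap(nn.BatchNorm2d(16), x)
gn_gap = train_eval_gap(nn.GroupNorm(4, 16), x)
in_gap = train_eval_gap(nn.InstanceNorm2d(16), x)

print(f"BatchNorm    train/eval gap: {bn_gap:.3f}")
print(f"GroupNorm    train/eval gap: {gn_gap:.3f}")
print(f"InstanceNorm train/eval gap: {in_gap:.3f}")
```

The BatchNorm gap here is exaggerated because the running statistics have seen only one batch; after longer training it narrows, but it is still driven by how well the running estimates track the data.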

Layer Norm in Non-Spatial Domains

For 1D sequence models (transformers, RNNs), the tensor is $(B, T, d)$: batch, sequence length, features. In this context:

  • LayerNorm normalizes over the $d$-dimensional feature vector at each token, independently. This is the transformer default.
  • BatchNorm would normalize over $(B, T)$ per feature, mixing tokens from different positions and different examples, which rarely makes semantic sense for sequences.
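The difference between the two bullets can be demonstrated by perturbing every token except the first and checking whether the first token's output moves. This is a minimal sketch assuming PyTorch, with BatchNorm in training mode so that it uses batch statistics; the sizes and the perturbation are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d = 2, 5, 16
x = torch.randn(B, T, d)

# Perturb every token except the first one in each sequence.
x_perturbed = x.clone()
x_perturbed[:, 1:] += 10.0

ln = nn.LayerNorm(d)       # statistics over d, separately for each token
bn = nn.BatchNorm1d(d)     # statistics over (B, T), separately for each feature

def apply_bn(t):
    # BatchNorm1d expects (B, d, T), so move the feature axis to position 1.
    return bn(t.transpose(1, 2)).transpose(1, 2)

with torch.no_grad():
    ln_first_token_unchanged = torch.allclose(ln(x)[:, 0], ln(x_perturbed)[:, 0])
    bn_first_token_unchanged = torch.allclose(apply_bn(x)[:, 0], apply_bn(x_perturbed)[:, 0])

print("LayerNorm first token unchanged:", ln_first_token_unchanged)
print("BatchNorm first token unchanged:", bn_first_token_unchanged)
```

LayerNorm's output for the first token depends only on that token, so it stays the same; BatchNorm's output shifts because the other tokens changed the per-feature statistics.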

For 3D volumetric data (e.g., medical imaging, video) with shape $(B, C, D, H, W)$, GroupNorm extends naturally: statistics are computed over the group's channels and all spatial dimensions $D, H, W$.
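In PyTorch the same `nn.GroupNorm` layer accepts inputs of shape (B, C, ...) with any number of trailing dimensions, so no separate 3D variant is needed. The sketch below uses arbitrary sizes and checks that each (example, group) slice of a 5D input comes out with approximately zero mean and unit standard deviation at initialization, when the scale is 1 and the shift is 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, D, H, W, G = 2, 16, 6, 10, 10, 4
x = 2.0 * torch.randn(B, C, D, H, W) + 1.0     # volumetric features

gn = nn.GroupNorm(num_groups=G, num_channels=C)
with torch.no_grad():
    y = gn(x)

# Flatten each group's C/G channels and all of D, H, W into one axis.
per_group = y.reshape(B, G, -1)
group_means = per_group.mean(dim=2)
group_stds = per_group.std(dim=2, unbiased=False)

print("output shape:", tuple(y.shape))
print("max |mean| per (b, g):", group_means.abs().max().item())
print("max |std - 1| per (b, g):", (group_stds - 1).abs().max().item())
```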


PyTorch and TensorFlow

PyTorch (nn.GroupNorm and nn.InstanceNorm2d):

import torch
import torch.nn as nn

# GroupNorm: (num_groups, num_channels) — num_channels must be divisible by num_groups
gn = nn.GroupNorm(num_groups=32, num_channels=256)
x  = torch.randn(4, 256, 14, 14)   # (B, C, H, W)
out = gn(x)                         # stats over (C/G=8 channels, H, W) per example

# Special cases (the statistics match; the affine parameters stay per-channel)
ln_equiv = nn.GroupNorm(num_groups=1,   num_channels=64)  # G=1  → LayerNorm over (C, H, W)
in_equiv = nn.GroupNorm(num_groups=64,  num_channels=64)  # G=C  → InstanceNorm

# InstanceNorm2d: per-(example, channel) stats over spatial dims
in2d = nn.InstanceNorm2d(num_features=64, affine=True)
x    = torch.randn(2, 64, 32, 32)
out  = in2d(x)   # stats over (H, W) for each (b, c); no batch dependence

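# Check what each method normalizes on the same (B=4, C=64, H=8, W=8) tensor:
# the mean over each method's own normalization region should be ~0
x = torch.randn(4, 64, 8, 8)
bn_out = nn.BatchNorm2d(64)(x)      # training mode: uses batch statistics
gn_out = nn.GroupNorm(32, 64)(x)
in_out = nn.InstanceNorm2d(64)(x)
print('BN  max |mean| per channel:   ', bn_out.mean(dim=(0, 2, 3)).abs().max().item())
print('GN  max |mean| per (b, group):', gn_out.reshape(4, 32, -1).mean(dim=2).abs().max().item())
print('IN  max |mean| per (b, c):    ', in_out.mean(dim=(2, 3)).abs().max().item())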

# Typical GroupNorm usage in a ResNet block for detection
class GNResBlock(nn.Module):
    def __init__(self, channels: int, groups: int = 32):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.gn1   = nn.GroupNorm(groups, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.gn2   = nn.GroupNorm(groups, channels)

    def forward(self, x):
        h = torch.relu(self.gn1(self.conv1(x)))
        return torch.relu(self.gn2(self.conv2(h)) + x)

TensorFlow / Keras:

import tensorflow as tf

# GroupNormalization — built-in from TF 2.11+ (channels-last format)
gn = tf.keras.layers.GroupNormalization(groups=32, axis=-1)
x  = tf.random.normal((4, 14, 14, 256))   # (B, H, W, C)
out = gn(x)

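# InstanceNorm equivalent: GroupNormalization with groups == number of channels.
# (LayerNormalization(axis=[1, 2]) computes the same statistics, but its gamma/beta
# have shape (H, W) instead of one value per channel.)
in_norm = tf.keras.layers.GroupNormalization(groups=64, axis=-1)
x       = tf.random.normal((4, 32, 32, 64))
out     = in_norm(x)   # stats over (H, W) per (batch, channel)

# Or via tensorflow-addons (a separate package that is no longer actively developed):
# import tensorflow_addons as tfa
# in_norm = tfa.layers.InstanceNormalization(axis=-1)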