Adaptive and Conditional Normalization: AdaIN, SPADE, and DiT
- Explain Adaptive Instance Normalization (AdaIN) — how it transfers style by replacing the normalized content features' mean and variance with those of the style image — and identify what this implies about where style information is encoded
- Describe Conditional Batch Normalization — replacing fixed γ and β with vectors predicted from a conditioning signal — and state why this is more parameter-efficient than concatenating the condition to every feature map
- Explain SPADE's spatially-adaptive denormalization — how a segmentation mask is convolved to produce pixel-wise γ and β maps — and state why spatial variation in the affine parameters is necessary for semantic image synthesis
- Identify adaLN-Zero as the normalization mechanism in Diffusion Transformers (DiT), describe how the timestep and class embeddings are projected to produce γ and β, and explain the 'zero' initialization of the final projection
From Fixed to Predicted Parameters
Every normalization technique seen so far has fixed learned parameters: $\gamma$ and $\beta$ are vectors updated by the optimizer, shared across all inputs at inference time. The affine step is unconditional: the same scale and shift is applied regardless of what you are generating.
Adaptive and conditional normalization breaks this assumption. The affine parameters $\gamma$ and $\beta$ are instead predicted at runtime from a conditioning signal: a style image, a class label, a diffusion timestep, or a segmentation mask. This turns the normalization layer into a conditioning mechanism, a structured way to inject external information into the feature representations.
Adaptive Instance Normalization (AdaIN)
Huang & Belongie (2017) introduced AdaIN for arbitrary neural style transfer — the ability to apply any artistic style to any content image in a single forward pass, without retraining.
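The operation is:

$$\mathrm{AdaIN}(x_c, x_s) = \sigma(x_s)\,\frac{x_c - \mu(x_c)}{\sigma(x_c)} + \mu(x_s)$$

where:
- $x_c$ is the content feature map (shape $C \times H \times W$)
- $x_s$ is the style feature map (same encoder, different image)
- $\mu(\cdot)$ and $\sigma(\cdot)$ denote per-channel spatial means and standard deviations (the InstanceNorm statistics)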
Step 1: Instance-normalize the content features — remove the content image's per-channel statistics (mean and variance).
Step 2: Re-scale and re-shift using the style image's statistics — inject the style image's mean and variance as the new affine parameters.
The result: the content structure (spatial arrangement of features) is preserved, but the statistics that encode style (color palette, texture density) are replaced with those of the style image.
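The two steps can be checked numerically. The following is a minimal NumPy sketch, not the authors' implementation; the random arrays stand in for encoder features, and the function name `adain` and the shapes are illustrative assumptions.

```python
import numpy as np

def adain(content, style, eps=1e-5):
    """AdaIN on (C, H, W) arrays: normalize content per channel, then apply style stats."""
    mu_c = content.mean(axis=(1, 2), keepdims=True)
    sigma_c = np.sqrt(content.var(axis=(1, 2), keepdims=True) + eps)
    mu_s = style.mean(axis=(1, 2), keepdims=True)
    sigma_s = np.sqrt(style.var(axis=(1, 2), keepdims=True) + eps)
    return sigma_s * (content - mu_c) / sigma_c + mu_s

rng = np.random.default_rng(0)
content = rng.normal(loc=2.0, scale=0.5, size=(3, 8, 8))   # hypothetical content features
style = rng.normal(loc=-1.0, scale=3.0, size=(3, 8, 8))    # hypothetical style features
out = adain(content, style)

# Output statistics now match the style features, channel by channel.
mean_matches = np.allclose(out.mean(axis=(1, 2)), style.mean(axis=(1, 2)), atol=1e-3)
std_matches = np.allclose(out.std(axis=(1, 2)), style.std(axis=(1, 2)), atol=1e-3)

# The spatial pattern is still the content's: each output channel is a
# positive affine function of the corresponding content channel.
corr = np.corrcoef(out[0].ravel(), content[0].ravel())[0, 1]
print(mean_matches, std_matches, corr > 0.999)
```

Because each channel is only shifted and rescaled, the relative ordering of activations across spatial positions is untouched, which is the sense in which content structure is preserved.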
Why Channel Statistics Encode Style
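This is not arbitrary. Gram matrix-based style representations (Gatys et al., 2015), the original method for neural style transfer, compute correlations between feature channels, summed over spatial positions. The diagonal of the Gram matrix is each channel's second moment, which is fully determined by that channel's mean and variance, so the per-channel statistics capture a compact summary of texture. AdaIN matches only these statistics and ignores the cross-channel terms, which makes it a coarse approximation of Gram matrix matching that needs two scalars per channel instead of a full $C \times C$ matrix. Because it runs in a single forward pass rather than an iterative optimization, it is also orders of magnitude faster.
Architecture: a shared, fixed encoder (a pretrained VGG network in Huang & Belongie) maps both content and style images to feature maps. AdaIN is applied once to the encoder outputs, and a learned decoder maps the result back to image space; later generators such as StyleGAN apply AdaIN at multiple scales. There are no learned $\gamma$ or $\beta$: the statistics come entirely from the style input.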
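The relationship between the Gram matrix and the per-channel statistics can be verified directly. This is a small sketch under assumptions: the feature map is random, and the Gram matrix is normalized by the number of spatial positions, which is one of several conventions in use.

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W = 4, 16, 16
feats = rng.normal(loc=1.5, scale=2.0, size=(C, H, W))  # hypothetical feature map

# Gram matrix: flatten space, then take inner products between channels.
F = feats.reshape(C, H * W)
gram = F @ F.T / (H * W)          # (C, C), normalized by the number of positions

# Per-channel statistics used by AdaIN.
mean = F.mean(axis=1)
var = F.var(axis=1)

# Diagonal of the Gram matrix = per-channel second moment = variance + mean^2.
diag_matches = np.allclose(np.diag(gram), var + mean ** 2)
print(diag_matches)
```

The off-diagonal entries of `gram` (cross-channel correlations) have no counterpart in the AdaIN statistics, which is what AdaIN gives up in exchange for speed.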
Conditional Batch Normalization
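Conditional Batch Normalization (De Vries et al., 2017; Dumoulin et al., 2017) keeps the normalization step identical to standard BatchNorm but replaces the fixed $\gamma$ and $\beta$ vectors with ones predicted from a conditioning signal:

$$\mathrm{CBN}(x \mid c) = \gamma(c) \odot \frac{x - \mu_B}{\sigma_B} + \beta(c)$$

where $\gamma(c) = W_\gamma c + b_\gamma$ and $\beta(c) = W_\beta c + b_\beta$ for condition vector $c$, and $\mu_B$, $\sigma_B$ are the usual per-channel batch statistics.
The network learns the projection matrices $W_\gamma$ and $W_\beta$ that map the conditioning vector to per-channel scale and shift.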
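Parameter efficiency: for a network with $L$ normalization layers and $C$ channels per layer, keeping a separate set of BN affine parameters for each class needs $2LC$ parameters per class (one $\gamma$ and one $\beta$ per class per layer), or $2LCK$ for $K$ classes. Conditional BN with a class embedding of dimension $d$ needs roughly $2LCd$ projection weights in total, plus a $K \times d$ embedding table, shared across all classes. When the number of classes is much larger than $d$, this is far more efficient.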
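A quick count makes the comparison concrete. The sketch below assumes `L` normalization layers with `C` channels each, `K` classes, and an embedding dimension `d`; the specific sizes are hypothetical and chosen only for illustration.

```python
def per_class_bn_params(num_layers, channels, num_classes):
    # One gamma and one beta vector (2 * C values) per layer, per class.
    return 2 * num_layers * channels * num_classes

def conditional_bn_params(num_layers, channels, num_classes, embed_dim):
    # Shared embedding table plus, per layer, two Linear(embed_dim -> C) maps with bias.
    embedding = num_classes * embed_dim
    projections = num_layers * 2 * (embed_dim * channels + channels)
    return embedding + projections

L, C, K, d = 20, 256, 1000, 128   # hypothetical sizes
print(per_class_bn_params(L, C, K))        # 10240000
print(conditional_bn_params(L, C, K, d))   # 1448960
```

The advantage depends on the number of classes: with only a handful of classes the per-class table is the smaller of the two, and the projected form wins once the class count clearly exceeds the embedding dimension.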
Applications:
- Visual question answering (FiLM): Perez et al. (2018) generalized this to Feature-wise Linear Modulation (FiLM), in which $\gamma$ and $\beta$ are predicted from a language question and used to modulate image features.
- Class-conditional image generation (BigGAN): the class embedding is projected to $\gamma$ and $\beta$ at each ResNet block in the generator.
- Diffusion models: the timestep and class conditioning are projected to $\gamma$ and $\beta$ at each block.
SPADE: Spatially-Adaptive Denormalization
SPADE (Park et al., 2019) extends conditional normalization to handle spatially varying conditioning — specifically for semantic image synthesis from segmentation maps.
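The standard approach of feeding the segmentation mask only at the input, as channels concatenated to the first layer, tends to lose the semantic information as features pass through the network; Park et al. argue that normalization layers in particular wash it away. SPADE instead injects the mask at every layer via spatially-varying affine parameters:

$$\mathrm{SPADE}(x, m)_{c,h,w} = \gamma_{c,h,w}(m)\,\frac{x_{c,h,w} - \mu_c}{\sigma_c} + \beta_{c,h,w}(m)$$

where $m$ is the segmentation mask, $\mu_c$ and $\sigma_c$ are per-channel normalization statistics, and $\gamma(m)$ is a spatial map produced by two convolutions applied to the downsampled mask:

$$\gamma(m) = \mathrm{Conv}_\gamma\big(\mathrm{ReLU}(\mathrm{Conv}_{\text{shared}}(\mathrm{resize}(m)))\big)$$

The same holds for $\beta(m)$, with its own output convolution $\mathrm{Conv}_\beta$. The mask is resized to match the current feature resolution, passed through a shared convolutional layer, then projected to $C$-channel $\gamma$ and $\beta$ maps.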
Why spatial variation matters: a segmentation mask has different semantic labels at different spatial locations: sky, grass, building. If you want the generator to produce the right texture and color for sky pixels versus grass pixels, the normalization parameters must vary spatially to communicate this. A single global $\gamma$ and $\beta$ vector cannot distinguish sky from grass at the feature map level.
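The effect of spatially varying parameters can be seen in a toy example. The sketch below is a simplified SPADE-style modulation, not the published layer: it assumes 1×1 convolutions (written as `einsum`) where SPADE uses 3×3 convolutions, nearest-neighbour mask resizing, and random weights. The function name `spade_modulate` and all sizes are hypothetical.

```python
import numpy as np

def spade_modulate(x, mask_onehot, w_shared, w_gamma, w_beta, eps=1e-5):
    """x: (C, H, W) features; mask_onehot: (K, Hm, Wm) one-hot labels at mask resolution."""
    C, H, W = x.shape
    # Nearest-neighbour resize of the mask to the feature resolution.
    rows = np.arange(H) * mask_onehot.shape[1] // H
    cols = np.arange(W) * mask_onehot.shape[2] // W
    m = mask_onehot[:, rows][:, :, cols]                       # (K, H, W)
    # Shared 1x1 conv + ReLU, then separate 1x1 convs for gamma and beta.
    hidden = np.maximum(np.einsum("hk,kyx->hyx", w_shared, m), 0.0)
    gamma = np.einsum("ch,hyx->cyx", w_gamma, hidden)          # (C, H, W)
    beta = np.einsum("ch,hyx->cyx", w_beta, hidden)            # (C, H, W)
    # Parameter-free per-channel normalization, then pixel-wise affine.
    mu = x.mean(axis=(1, 2), keepdims=True)
    sigma = np.sqrt(x.var(axis=(1, 2), keepdims=True) + eps)
    return gamma * (x - mu) / sigma + beta, gamma, beta

rng = np.random.default_rng(0)
K, C, hidden_dim = 2, 3, 5
labels = np.zeros((4, 4), dtype=int)
labels[2:, :] = 1                                    # top half "sky" (0), bottom half "grass" (1)
mask_onehot = np.eye(K)[labels].transpose(2, 0, 1)   # (K, 4, 4)
x = rng.normal(size=(C, 8, 8))
# Positive shared weights so the ReLU keeps both labels active in this toy setup.
w_shared = np.abs(rng.normal(size=(hidden_dim, K)))
w_gamma = rng.normal(size=(C, hidden_dim))
w_beta = rng.normal(size=(C, hidden_dim))

out, gamma, beta = spade_modulate(x, mask_onehot, w_shared, w_gamma, w_beta)
print(out.shape, gamma.shape)                        # (3, 8, 8) (3, 8, 8)
same_region = np.allclose(gamma[:, 0, 0], gamma[:, 3, 7])   # two "sky" pixels
diff_region = np.allclose(gamma[:, 0, 0], gamma[:, 7, 0])   # "sky" vs "grass"
print(same_region, diff_region)
```

With 1×1 convolutions every pixel of a given label receives the same scale and shift, while pixels of different labels receive different ones. Larger kernels, as in the published layer, additionally let the parameters vary near region boundaries.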
SPADE (the basis of GauGAN) reported state-of-the-art results on semantic image synthesis benchmarks at the time of publication, and spatially-adaptive normalization remains a widely used component in conditional image generation architectures.
adaLN-Zero: Normalization in Diffusion Transformers
Diffusion Transformers (DiT, Peebles & Xie, 2023) apply the transformer architecture to image generation via diffusion. The conditioning signal consists of:
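- The diffusion timestep $t$ (a scalar indicating how much noise is in the image)
- The class label $y$ (for class-conditional generation)
Both are embedded and summed: $c = \mathrm{embed}(t) + \mathrm{embed}(y)$.
DiT tested several ways to inject this conditioning into the transformer blocks. The best-performing variant was adaLN-Zero (adaptive LayerNorm with zero initialization).
For each transformer block, a single linear layer projects $c$ to six vectors: scale $\gamma_1$ and shift $\beta_1$ for the pre-attention LayerNorm, scale $\gamma_2$ and shift $\beta_2$ for the pre-FFN LayerNorm, and scaling gates $\alpha_1$ and $\alpha_2$ applied to the residual branches:

$$(\gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2) = \mathrm{Linear}(c)$$

$$h = x + \alpha_1 \odot \mathrm{Attn}\big((1 + \gamma_1) \odot \mathrm{LN}(x) + \beta_1\big)$$

$$x_{\text{out}} = h + \alpha_2 \odot \mathrm{FFN}\big((1 + \gamma_2) \odot \mathrm{LN}(h) + \beta_2\big)$$

Here $\mathrm{LN}$ is LayerNorm without learned affine parameters, and the scale is written as $(1 + \gamma)$ so that $\gamma = 0$ leaves the normalized features unchanged.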
The 'Zero' Initialization
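The linear projection layer that outputs $(\gamma_1, \beta_1, \alpha_1, \gamma_2, \beta_2, \alpha_2)$ is initialized to zero weights and zero bias. At the start of training, this produces $\gamma_i = 0$, $\beta_i = 0$, $\alpha_i = 0$.
- $\gamma_i = 0$, $\beta_i = 0$: with the scale applied as $(1 + \gamma_i)$, the adaLN outputs zero-mean, unit-variance features (LayerNorm with no affine correction)
- $\alpha_i = 0$: both residual branches contribute zero, so the entire block is an identity function at initialization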
This is an extension of the zero-init residual trick: every DiT block starts as an identity, so the whole stack of blocks initially passes its input tokens through unchanged. Training can then incrementally add structure. Peebles & Xie report that adaLN-Zero reached a lower FID than the other conditioning mechanisms they compared, including plain adaLN without the zero-initialized gates.
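The identity-at-initialization property is easy to confirm on a single residual branch. This is a minimal NumPy sketch under assumptions: `branch_fn` is a hypothetical stand-in for the attention or FFN sublayer, and the scale is applied as `(1 + gamma)`.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LayerNorm over the last (feature) axis, with no learned affine parameters.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def adaln_zero_branch(x, cond, w_proj, b_proj, branch_fn):
    """One residual branch: x + alpha * branch((1 + gamma) * LN(x) + beta)."""
    params = cond @ w_proj + b_proj                 # (B, 3 * D)
    gamma, beta, alpha = np.split(params, 3, axis=-1)
    h = (1.0 + gamma[:, None, :]) * layer_norm(x) + beta[:, None, :]
    return x + alpha[:, None, :] * branch_fn(h)

rng = np.random.default_rng(0)
B, T, D, cond_dim = 2, 5, 8, 4
x = rng.normal(size=(B, T, D))                      # (batch, tokens, features)
cond = rng.normal(size=(B, cond_dim))
w_branch = rng.normal(size=(D, D))
branch_fn = lambda h: np.tanh(h @ w_branch)         # hypothetical stand-in for attention or FFN

# Zero-initialized projection: gamma = beta = alpha = 0, so the branch is gated off.
w0, b0 = np.zeros((cond_dim, 3 * D)), np.zeros(3 * D)
at_init = adaln_zero_branch(x, cond, w0, b0, branch_fn)
print(np.array_equal(at_init, x))                   # True

# Once an update makes alpha nonzero, the branch starts to contribute.
w1 = w0 + 0.01 * rng.normal(size=w0.shape)
after_update = adaln_zero_branch(x, cond, w1, b0, branch_fn)
print(np.array_equal(after_update, x))              # False
```

The gate `alpha` is what makes the block an exact identity; zeroing `gamma` and `beta` alone would still leave a nonzero residual contribution from the branch.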
Summary: The Normalization Design Space
The readings in this module have traced a progression:
| Technique | γ, β source | Spatial variation | Primary domain |
|---|---|---|---|
| BatchNorm | Fixed, learned | None | Supervised vision (large batch) |
| LayerNorm / RMSNorm | Fixed, learned | None | Language models, transformers |
| InstanceNorm | Fixed, learned | Per-channel statistics vary | Style transfer |
| GroupNorm | Fixed, learned | Per-group statistics vary | Detection, small-batch vision |
| AdaIN | Style image statistics | Per-channel statistics vary | Arbitrary style transfer |
| Cond. BN / FiLM | Projected from condition | None (global γ/β) | Conditional generation, VQA |
| SPADE | Conv over mask, spatially | Full spatial γ/β maps | Semantic image synthesis |
| adaLN-Zero | Projected from timestep + class | None (global γ/β) | Diffusion transformers |
The normalization layer has evolved from a training stabilizer into a conditioning interface — the primary mechanism by which external signals (style, class, timestep, text) are injected into generative models.
PyTorch and TensorFlow
PyTorch — AdaIN, FiLM, and adaLN-Zero:
import torch
import torch.nn as nn
import torch.nn.functional as F


# AdaIN: replace content statistics with style statistics
class AdaIN(nn.Module):
    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x_c, x_s):
        # x_c / x_s: (B, C, H, W); statistics are taken over the spatial dims
        mu_c = x_c.mean(dim=(2, 3), keepdim=True)
        mu_s = x_s.mean(dim=(2, 3), keepdim=True)
        # eps goes inside the sqrt so constant channels stay finite in the backward pass;
        # unbiased=False matches the statistics InstanceNorm uses
        sigma_c = (x_c.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt()
        sigma_s = (x_s.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt()
        return sigma_s * (x_c - mu_c) / sigma_c + mu_s


# FiLM / Conditional BatchNorm: predict gamma and beta from a condition vector
class FiLM(nn.Module):
    def __init__(self, cond_dim: int, num_channels: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, num_channels * 2)

    def forward(self, x, cond):
        # x: (B, C, H, W)   cond: (B, cond_dim)
        params = self.proj(cond)                    # (B, 2C)
        gamma, beta = params.chunk(2, dim=-1)       # (B, C) each
        gamma = gamma[:, :, None, None]             # (B, C, 1, 1)
        beta = beta[:, :, None, None]
        # The normalizer here is parameter-free InstanceNorm; for Conditional BatchNorm
        # proper, use nn.BatchNorm2d(num_channels, affine=False) instead.
        return gamma * F.instance_norm(x) + beta


# adaLN-Zero: timestep+class conditioned LayerNorm with zero-initialized projection
class AdaLNZeroBlock(nn.Module):
    def __init__(self, d_model: int, nhead: int, cond_dim: int):
        super().__init__()
        # LayerNorms carry no learned affine; scale and shift come from the condition
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        # Projects condition to (gamma1, beta1, alpha1, gamma2, beta2, alpha2)
        self.adaLN_proj = nn.Linear(cond_dim, 6 * d_model)
        nn.init.zeros_(self.adaLN_proj.weight)      # zero-init: all blocks start as identity
        nn.init.zeros_(self.adaLN_proj.bias)

    def forward(self, x, cond):
        # x: (B, T, d_model)   cond: embed(timestep) + embed(class), shape (B, cond_dim)
        g1, b1, a1, g2, b2, a2 = self.adaLN_proj(cond).chunk(6, dim=-1)  # (B, d_model) each
        # Pre-attention adaLN; unsqueeze(1) broadcasts over the token dimension
        x1 = (1 + g1.unsqueeze(1)) * self.norm1(x) + b1.unsqueeze(1)
        x = x + a1.unsqueeze(1) * self.attn(x1, x1, x1)[0]
        # Pre-FFN adaLN
        x2 = (1 + g2.unsqueeze(1)) * self.norm2(x) + b2.unsqueeze(1)
        x = x + a2.unsqueeze(1) * self.ffn(x2)
        return x
TensorFlow / Keras:
import tensorflow as tf


# AdaIN
class AdaIN(tf.keras.layers.Layer):
    def __init__(self, eps=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def call(self, x_c, x_s):
        axes = [1, 2]  # spatial dims for channels-last (B, H, W, C)
        mu_c = tf.reduce_mean(x_c, axis=axes, keepdims=True)
        mu_s = tf.reduce_mean(x_s, axis=axes, keepdims=True)
        # eps goes inside the sqrt so constant channels stay finite in the backward pass
        sig_c = tf.sqrt(tf.math.reduce_variance(x_c, axis=axes, keepdims=True) + self.eps)
        sig_s = tf.sqrt(tf.math.reduce_variance(x_s, axis=axes, keepdims=True) + self.eps)
        return sig_s * (x_c - mu_c) / sig_c + mu_s


# FiLM layer
class FiLM(tf.keras.layers.Layer):
    def __init__(self, num_channels: int, eps=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.proj = tf.keras.layers.Dense(num_channels * 2)

    def call(self, x, cond):
        # x: (B, H, W, C)   cond: (B, cond_dim)
        params = self.proj(cond)                         # (B, 2C)
        gamma, beta = tf.split(params, 2, axis=-1)       # (B, C) each
        gamma = gamma[:, tf.newaxis, tf.newaxis, :]      # (B, 1, 1, C)
        beta = beta[:, tf.newaxis, tf.newaxis, :]
        # Parameter-free instance normalization over the spatial dims, then modulation
        mu = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        sigma = tf.sqrt(tf.math.reduce_variance(x, axis=[1, 2], keepdims=True) + self.eps)
        return gamma * (x - mu) / sigma + beta