Supplement · Normalization in Deep Learning

Adaptive and Conditional Normalization: AdaIN, SPADE, and DiT

14 min read
By the end of this reading you will be able to:
  • Explain Adaptive Instance Normalization (AdaIN) — how it transfers style by replacing the normalized content features' mean and variance with those of the style image — and identify what this implies about where style information is encoded
  • Describe Conditional Batch Normalization — replacing fixed γ and β with vectors predicted from a conditioning signal — and state why this is more parameter-efficient than concatenating the condition to every feature map
  • Explain SPADE's spatially-adaptive denormalization — how a segmentation mask is convolved to produce pixel-wise γ and β maps — and state why spatial variation in the affine parameters is necessary for semantic image synthesis
  • Identify adaLN-Zero as the normalization mechanism in Diffusion Transformers (DiT), describe how the timestep and class embeddings are projected to produce γ and β, and explain the 'zero' initialization of the final projection

From Fixed to Predicted Parameters

Every normalization technique seen so far has fixed learned parameters: $\gamma$ and $\beta$ are vectors updated by the optimizer during training and then shared across all inputs at inference time. The affine step is unconditional: the same scale and shift are applied regardless of what you are generating.

Adaptive and conditional normalization breaks this assumption. The affine parameters $\gamma$ and $\beta$ are instead predicted at runtime from a conditioning signal: a style image, a class label, a diffusion timestep, or a segmentation mask. This turns the normalization layer into a conditioning mechanism, a structured way to inject external information into the feature representations.


Adaptive Instance Normalization (AdaIN)

Huang & Belongie (2017) introduced AdaIN for arbitrary neural style transfer — the ability to apply any artistic style to any content image in a single forward pass, without retraining.

The operation is:

$$\text{AdaIN}(\mathbf{x}_c, \mathbf{x}_s) = \sigma(\mathbf{x}_s) \cdot \frac{\mathbf{x}_c - \mu(\mathbf{x}_c)}{\sigma(\mathbf{x}_c)} + \mu(\mathbf{x}_s)$$

where:

  • $\mathbf{x}_c$ is the content feature map (shape $C \times H \times W$)
  • $\mathbf{x}_s$ is the style feature map (same encoder, different image)
  • $\mu(\cdot)$ and $\sigma(\cdot)$ denote per-channel spatial means and standard deviations (the InstanceNorm statistics)

Step 1: Instance-normalize the content features — remove the content image's per-channel statistics (mean and variance).

Step 2: Re-scale and re-shift using the style image's statistics — inject the style image's mean and variance as the new affine parameters.

The result: the content structure (spatial arrangement of features) is preserved, but the statistics that encode style (color palette, texture density) are replaced with those of the style image.
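The two steps can be traced by hand on a single channel. The following is a minimal, dependency-free sketch rather than a practical implementation: the function name `adain_channel` and the four-value "feature maps" are illustrative assumptions, and each list stands for one flattened channel.

```python
from statistics import fmean, pstdev

def adain_channel(content, style, eps=1e-5):
    """AdaIN for one flattened channel: normalize content, then apply style stats."""
    mu_c, sigma_c = fmean(content), pstdev(content)
    mu_s, sigma_s = fmean(style), pstdev(style)
    # Step 1: remove the content statistics (instance normalization).
    normalized = [(v - mu_c) / (sigma_c + eps) for v in content]
    # Step 2: re-scale and re-shift with the style statistics.
    return [sigma_s * n + mu_s for n in normalized]

content = [1.0, 2.0, 3.0, 4.0]   # hypothetical content activations
style = [0.0, 0.0, 4.0, 4.0]     # hypothetical style activations (mean 2, std 2)
out = adain_channel(content, style)

print(round(fmean(out), 3), round(pstdev(out), 3))  # statistics now match the style
print(out == sorted(out))  # relative ordering of the content values is unchanged
```

The output values keep the ordering of the content activations (a stand-in for spatial structure), while their mean and standard deviation are, up to the small `eps` term, those of the style channel.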

Why Channel Statistics Encode Style

This is not arbitrary. Gram matrix-based style representations (Gatys et al., 2015), the original method for neural style transfer, compute correlations between feature channels, summed over spatial positions. The per-channel mean and variance determine the diagonal of this $C \times C$ structure, so they capture a compact summary of the same texture statistics. AdaIN approximates Gram matrix matching while being far cheaper: it matches two scalar statistics per channel instead of a full $C \times C$ matrix.
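The link between the Gram matrix and the channel statistics can be written out explicitly. As a sketch, assume the Gram matrix is normalized by the number of spatial positions (conventions vary between implementations), write $x_{c,h,w}$ for the activation of channel $c$ at position $(h, w)$, and let $\mu_c$ and $\sigma_c$ be that channel's spatial mean and (population) standard deviation:

```latex
G_{cc'} = \frac{1}{HW} \sum_{h,w} x_{c,h,w}\, x_{c',h,w}
\qquad\Longrightarrow\qquad
G_{cc} = \frac{1}{HW} \sum_{h,w} x_{c,h,w}^{2} = \sigma_c^{2} + \mu_c^{2}
```

Matching $\mu_c$ and $\sigma_c$ therefore matches the diagonal of the Gram matrix, while the off-diagonal cross-channel terms are left unconstrained.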

Architecture: a shared encoder maps both content and style images to feature maps. In Huang & Belongie's network, AdaIN is applied once, to the encoder output, and a learned decoder maps the result back to image space. There are no learned $\gamma$ or $\beta$ in the AdaIN layer: the statistics come entirely from the style input.


Conditional Batch Normalization

Conditional Batch Normalization (De Vries et al., 2017; Dumoulin et al., 2017) keeps the normalization step identical to standard BatchNorm but replaces the fixed $\gamma$ and $\beta$ vectors with ones predicted from a conditioning signal:

$$\hat{z}_i = \frac{z_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}} \qquad y_i = \gamma(\mathbf{c}) \cdot \hat{z}_i + \beta(\mathbf{c})$$

where $\gamma(\mathbf{c}) = W_\gamma \mathbf{c} + b_\gamma$ and $\beta(\mathbf{c}) = W_\beta \mathbf{c} + b_\beta$ for condition vector $\mathbf{c}$.

The network learns projection matrices $W_\gamma, W_\beta$ that map the conditioning vector to per-channel scale and shift.

Parameter efficiency: for a network with $L$ normalized layers and $C$ channels per layer, storing a separate $\gamma$ and $\beta$ for every class costs $2 \times L \times C$ parameters per class, so the total grows linearly with the number of classes $K$. Conditional BN with a class embedding of dimension $d$ needs $2 \times L \times (d \times C + C)$ projection parameters in total, shared across all classes, plus a $K \times d$ embedding table. When $K$ is much larger than $d$, this is far more efficient.
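The comparison can be made concrete with a few lines of arithmetic. The sketch below counts both $\gamma$ and $\beta$ (hence the factor of 2) and includes the class-embedding table; the function names and the values of `L`, `C`, `d` and `K` are illustrative assumptions, not figures from any paper.

```python
def per_class_affine_params(num_layers, channels, num_classes):
    """A separate gamma and beta vector for every class at every layer."""
    return 2 * num_layers * channels * num_classes

def conditional_bn_params(num_layers, channels, embed_dim, num_classes):
    """Shared linear projections (weight + bias) for gamma and beta at every
    layer, plus one class-embedding table shared by all layers."""
    projections = 2 * num_layers * (embed_dim * channels + channels)
    embedding_table = num_classes * embed_dim
    return projections + embedding_table

L, C, d, K = 10, 256, 128, 1000   # hypothetical sizes
separate = per_class_affine_params(L, C, K)
shared = conditional_bn_params(L, C, d, K)

print(separate)  # 5120000
print(shared)    # 788480
```

The advantage depends on $K$ being large relative to $d$: with the same sizes but only 10 classes, the per-class tables need 51,200 parameters, fewer than the shared projections.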

Applications:

  • Visual question answering (FiLM): Perez et al. (2018) generalized this to Feature-wise Linear Modulation (FiLM) layers, which predict $\gamma$ and $\beta$ from a language question to modulate image features.
  • Class-conditional image generation (BigGAN): class embedding projected to $\gamma$ and $\beta$ at each ResNet block in the generator.
  • Diffusion models: timestep and class conditioning projected to $\gamma$ and $\beta$ at each block.

SPADE: Spatially-Adaptive Denormalization

SPADE (Park et al., 2019) extends conditional normalization to handle spatially varying conditioning — specifically for semantic image synthesis from segmentation maps.

The standard approach feeds the segmentation mask to the network once, as its input. Park et al. argue that this is suboptimal because normalization layers tend to wash away the semantic information in the mask as it passes through the network. SPADE instead injects the mask at every normalization layer via spatially-varying affine parameters:

$$y_{b,c,h,w} = \gamma_{c,h,w}(\mathbf{m}) \cdot \frac{x_{b,c,h,w} - \mu_{b,c}}{\sigma_{b,c}} + \beta_{c,h,w}(\mathbf{m})$$

where $\mathbf{m}$ is the segmentation mask, $\mu_{b,c}$ and $\sigma_{b,c}$ are the mean and standard deviation of channel $c$ for sample $b$ (Park et al. compute these statistics across the batch as well, as in BatchNorm), and $\gamma_{c,h,w}(\mathbf{m})$ is a spatial map produced by two convolutions applied to the resized mask:

$$\gamma_{c,h,w}(\mathbf{m}) = \text{Conv}_\gamma(\text{Conv}_{\text{shared}}(\text{resize}(\mathbf{m})))_{c,h,w}$$

The same holds for $\beta_{c,h,w}$. The mask is resized to match the current feature resolution, passed through a shared convolutional layer, then projected to $C$-channel $\gamma$ and $\beta$ maps.

Why spatial variation matters: a segmentation mask has different semantic labels at different spatial locations (sky, grass, building). If you want the generator to produce the right texture and color for sky pixels versus grass pixels, the normalization parameters must vary spatially to communicate this. A single global $\gamma$ and $\beta$ vector cannot distinguish sky from grass at the feature map level.
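A SPADE layer along these lines can be sketched in PyTorch. This is a minimal illustration under stated assumptions, not the reference implementation: the class name `SPADESketch`, the hidden width of 64, the 3×3 kernels, the ReLU after the shared convolution, and the use of per-sample (InstanceNorm) statistics to match the formula above are all choices made here, and the mask is assumed to be one-hot encoded with one channel per label.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SPADESketch(nn.Module):
    def __init__(self, num_channels: int, num_labels: int, hidden: int = 64):
        super().__init__()
        # Parameter-free normalization; gamma and beta come from the mask instead.
        self.norm = nn.InstanceNorm2d(num_channels, affine=False)
        self.shared = nn.Conv2d(num_labels, hidden, kernel_size=3, padding=1)
        self.to_gamma = nn.Conv2d(hidden, num_channels, kernel_size=3, padding=1)
        self.to_beta = nn.Conv2d(hidden, num_channels, kernel_size=3, padding=1)

    def forward(self, x, mask):
        # x: (B, C, H, W) features; mask: (B, num_labels, H0, W0) one-hot map
        mask = F.interpolate(mask, size=x.shape[2:], mode="nearest")
        h = F.relu(self.shared(mask))
        gamma = self.to_gamma(h)   # (B, C, H, W): one scale per channel per pixel
        beta = self.to_beta(h)     # (B, C, H, W): one shift per channel per pixel
        return gamma * self.norm(x) + beta

# Usage: a 3-label mask at 32x32 modulating 8-channel features at 16x16.
labels = torch.randint(0, 3, (2, 32, 32))
mask = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()
x = torch.randn(2, 8, 16, 16)
y = SPADESketch(num_channels=8, num_labels=3)(x, mask)
print(y.shape)  # torch.Size([2, 8, 16, 16])
```

Because `gamma` and `beta` have the full `(B, C, H, W)` shape, pixels under different labels receive different scales and shifts. Implementations often apply the scale as `1 + gamma` so that a zero prediction leaves the normalized features unchanged; the sketch follows the formula above instead.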

SPADE (the basis of GauGAN) reported state-of-the-art results on semantic image synthesis at the time of publication and remains a widely used component in conditional image generation architectures.


adaLN-Zero: Normalization in Diffusion Transformers

Diffusion Transformers (DiT, Peebles & Xie, 2023) apply the transformer architecture to image generation via diffusion. The conditioning signal consists of:

  • The diffusion timestep $t$ (a scalar indicating how much noise is in the image)
  • The class label $y$ (for class-conditional generation)

Both are embedded and summed: $\mathbf{c} = \text{embed}(t) + \text{embed}(y) \in \mathbb{R}^d$.

DiT tested several ways to inject this conditioning into the transformer blocks. The winner was adaLN-Zero (adaptive LayerNorm with zero initialization):

$$[\gamma, \beta, \alpha] = \text{Linear}(\mathbf{c}) \in \mathbb{R}^{6d}$$

For each transformer block, a single linear layer projects $\mathbf{c}$ to six vectors: scale and shift for the pre-attention LayerNorm, scale and shift for the pre-FFN LayerNorm, and scaling gates $\alpha_{\text{attn}}$ and $\alpha_{\text{FFN}}$ applied to the residual branches:

$$\mathbf{x} \leftarrow \mathbf{x} + \alpha_{\text{attn}} \cdot \text{Attention}(\text{adaLN}(\mathbf{x}))$$

$$\mathbf{x} \leftarrow \mathbf{x} + \alpha_{\text{FFN}} \cdot \text{FFN}(\text{adaLN}(\mathbf{x}))$$

The 'Zero' Initialization

The linear projection layer that outputs $[\gamma, \beta, \alpha]$ is initialized to zero weights and zero bias. At the start of training, this produces $\gamma = \mathbf{0}$, $\beta = \mathbf{0}$, $\alpha = \mathbf{0}$ for every input.

  • $\gamma = \mathbf{0}$, $\beta = \mathbf{0}$: the scale is applied as $(1 + \gamma)$, so adaLN reduces to a plain LayerNorm with no affine correction and produces zero-mean, unit-variance features
  • $\alpha = \mathbf{0}$: both residual branches contribute zero, so the entire block is an identity function at initialization

This is an extension of the zero-init residual trick: every DiT block starts as an identity, so the stack of transformer blocks as a whole begins as the identity function. Training can then incrementally add structure. Peebles & Xie reported that adaLN-Zero achieved lower FID than plain adaLN and the other conditioning mechanisms they compared.
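The identity-at-initialization property is easy to check numerically. The sketch below is a deliberately reduced, single-branch block (FFN only) written for this purpose; the class name `TinyAdaLNZero` and all sizes are hypothetical, and applying the scale as `1 + gamma` is an assumption made so that a zero scale leaves the LayerNorm output unchanged.

```python
import torch
import torch.nn as nn

class TinyAdaLNZero(nn.Module):
    def __init__(self, d_model: int, cond_dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        # One projection predicts scale (gamma), shift (beta) and gate (alpha).
        self.proj = nn.Linear(cond_dim, 3 * d_model)
        nn.init.zeros_(self.proj.weight)   # the "Zero" in adaLN-Zero
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        # x: (B, T, d_model); cond: (B, cond_dim)
        gamma, beta, alpha = self.proj(cond).unsqueeze(1).chunk(3, dim=-1)
        h = (1 + gamma) * self.norm(x) + beta
        return x + alpha * self.ffn(h)

torch.manual_seed(0)
block = TinyAdaLNZero(d_model=8, cond_dim=4)
x, cond = torch.randn(2, 5, 8), torch.randn(2, 4)

print(torch.equal(block(x, cond), x))   # True: identity at initialization

# Once the projection moves away from zero, the output depends on cond.
nn.init.normal_(block.proj.weight, std=0.02)
print(torch.equal(block(x, cond), x))
```

At initialization the gate is exactly zero, so the FFN branch is multiplied away regardless of its own random weights; after the projection weights become nonzero, the second check no longer reports equality.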


Summary: The Normalization Design Space

The readings in this module have traced a progression:

| Technique | $\gamma$, $\beta$ source | Spatial variation | Primary domain |
|---|---|---|---|
| BatchNorm | Fixed, learned | None | Supervised vision (large batch) |
| LayerNorm / RMSNorm | Fixed, learned | None | Language models, transformers |
| InstanceNorm | Fixed, learned | Per-channel statistics vary | Style transfer |
| GroupNorm | Fixed, learned | Per-group statistics vary | Detection, small-batch vision |
| AdaIN | Style image statistics | Per-channel statistics vary | Arbitrary style transfer |
| Cond. BN / FiLM | Projected from condition | None (global $\gamma$/$\beta$) | Conditional generation, VQA |
| SPADE | Conv over mask, spatially | Full spatial $\gamma$/$\beta$ maps | Semantic image synthesis |
| adaLN-Zero | Projected from $t$ + class | None (global $\gamma$/$\beta$) | Diffusion transformers |

The normalization layer has evolved from a training stabilizer into a conditioning interface — the primary mechanism by which external signals (style, class, timestep, text) are injected into generative models.


PyTorch and TensorFlow

PyTorch — AdaIN, FiLM, and adaLN-Zero:

import torch
import torch.nn as nn
import torch.nn.functional as F

# AdaIN: replace content statistics with style statistics
class AdaIN(nn.Module):
    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x_c, x_s):
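        # x_c / x_s: (B, C, H, W); statistics are taken per channel over H and W
        # eps is added inside the sqrt so gradients stay finite for constant channels
        mu_c    = x_c.mean(dim=[2, 3], keepdim=True)
        sigma_c = (x_c.var(dim=[2, 3], unbiased=False, keepdim=True) + self.eps).sqrt()
        mu_s    = x_s.mean(dim=[2, 3], keepdim=True)
        sigma_s = (x_s.var(dim=[2, 3], unbiased=False, keepdim=True) + self.eps).sqrt()
        return sigma_s * (x_c - mu_c) / sigma_c + mu_s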

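# FiLM / Conditional BatchNorm: predict gamma and beta from a condition vector
# (this sketch normalizes with parameter-free InstanceNorm; Conditional BatchNorm
#  would normalize with batch statistics instead)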
class FiLM(nn.Module):
    def __init__(self, cond_dim: int, num_channels: int):
        super().__init__()
        self.proj = nn.Linear(cond_dim, num_channels * 2)

    def forward(self, x, cond):
        # x: (B, C, H, W)  cond: (B, cond_dim)
        params        = self.proj(cond)                   # (B, 2C)
        gamma, beta   = params.chunk(2, dim=-1)           # (B, C)
        gamma = gamma[:, :, None, None]                   # (B, C, 1, 1)
        beta  = beta[:, :, None, None]
        return gamma * F.instance_norm(x) + beta

# adaLN-Zero: timestep+class conditioned LayerNorm with zero-initialized projection
class AdaLNZeroBlock(nn.Module):
    def __init__(self, d_model: int, nhead: int, cond_dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(d_model, elementwise_affine=False)
        self.attn  = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn   = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )
        # Projects condition to (gamma1, beta1, alpha1, gamma2, beta2, alpha2)
        self.adaLN_proj = nn.Linear(cond_dim, 6 * d_model)
        nn.init.zeros_(self.adaLN_proj.weight)   # zero-init: all blocks start as identity
        nn.init.zeros_(self.adaLN_proj.bias)

    def forward(self, x, cond):
        # cond: embed(timestep) + embed(class)  shape (B, cond_dim)
        g1, b1, a1, g2, b2, a2 = self.adaLN_proj(cond).chunk(6, dim=-1)  # (B, d_model)
        # Pre-attention adaLN
        x1 = (1 + g1.unsqueeze(1)) * self.norm1(x) + b1.unsqueeze(1)
        x  = x + a1.unsqueeze(1) * self.attn(x1, x1, x1)[0]
        # Pre-FFN adaLN
        x2 = (1 + g2.unsqueeze(1)) * self.norm2(x) + b2.unsqueeze(1)
        x  = x + a2.unsqueeze(1) * self.ffn(x2)
        return x

TensorFlow / Keras:

import tensorflow as tf

# AdaIN
class AdaIN(tf.keras.layers.Layer):
    def __init__(self, eps=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def call(self, x_c, x_s):
        axes   = [1, 2]   # spatial dims for channels-last (B, H, W, C)
        # eps is added inside the sqrt so gradients stay finite for constant channels
        mu_c   = tf.reduce_mean(x_c, axis=axes, keepdims=True)
        sig_c  = tf.sqrt(tf.math.reduce_variance(x_c, axis=axes, keepdims=True) + self.eps)
        mu_s   = tf.reduce_mean(x_s, axis=axes, keepdims=True)
        sig_s  = tf.sqrt(tf.math.reduce_variance(x_s, axis=axes, keepdims=True) + self.eps)
        return sig_s * (x_c - mu_c) / sig_c + mu_s

# FiLM layer
class FiLM(tf.keras.layers.Layer):
    def __init__(self, num_channels: int, **kwargs):
        super().__init__(**kwargs)
        self.proj = tf.keras.layers.Dense(num_channels * 2)

    def call(self, x, cond):
        # x: (B, H, W, C)  cond: (B, cond_dim)
        params       = self.proj(cond)                      # (B, 2C)
        gamma, beta  = tf.split(params, 2, axis=-1)         # (B, C)
        gamma = gamma[:, tf.newaxis, tf.newaxis, :]         # (B, 1, 1, C)
        beta  = beta[:, tf.newaxis, tf.newaxis, :]
        # Parameter-free instance normalization over the spatial dims
        mu    = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        sigma = tf.sqrt(tf.math.reduce_variance(x, axis=[1, 2], keepdims=True) + 1e-5)
        return gamma * (x - mu) / sigma + beta