Supplement · Normalization in Deep Learning

Weight Normalization and Spectral Normalization

15 min read
By the end of this reading you will be able to:
  • Describe the weight normalization reparameterization w = g·v/‖v‖ and explain how decoupling magnitude from direction simplifies the optimization geometry
  • Compare weight normalization to batch normalization: what weight norm avoids (batch statistics, train/eval mismatch), what it loses (automatic scale adaptation), and when each is preferable
  • Explain the spectral norm of a matrix — its relationship to the largest singular value and the Lipschitz constant — and describe how power iteration estimates it efficiently
  • Identify where spectral normalization is applied in GAN discriminators and explain why constraining the Lipschitz constant is required by the Wasserstein distance objective

Weight Normalization

Weight normalization (Salimans & Kingma, 2016) takes a different approach from all the normalizations seen so far. Instead of normalizing activations at runtime, it reparameterizes the weight vectors themselves.

For each output neuron $i$, replace the weight vector $\mathbf{w}_i$ with a magnitude scalar $g_i$ and a direction vector $\mathbf{v}_i$:

$$\mathbf{w}_i = g_i \cdot \frac{\mathbf{v}_i}{\|\mathbf{v}_i\|}$$

The network now has two parameters per weight vector: $g_i \in \mathbb{R}$ (scalar magnitude) and $\mathbf{v}_i \in \mathbb{R}^{n_{\text{in}}}$ (direction, unnormalized). The actual weight used in the forward pass is always the unit-normalized direction scaled by $g_i$.

Crucially: weight normalization adds no computation to the forward pass beyond the reparameterization itself. There is no batch statistic computation, no running average, no train/eval mode difference.
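The reparameterization can be illustrated with a minimal plain-Python sketch. The helper names `weight_norm_vector` and `l2_norm` are hypothetical, introduced only for this example; they are not part of any library.

```python
import math

def l2_norm(vec):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in vec))

def weight_norm_vector(g, v):
    """Return w = g * v / ||v|| for a scalar g and a direction list v."""
    v_norm = l2_norm(v)
    return [g * x / v_norm for x in v]

v = [3.0, 4.0]                      # unnormalized direction, ||v|| = 5
w = weight_norm_vector(2.0, v)      # effective weight used in the forward pass

# Rescaling v leaves w unchanged: only the direction of v matters,
# and the magnitude of w is controlled by g alone.
w_rescaled = weight_norm_vector(2.0, [30.0, 40.0])
```

Here `w` comes out as approximately `[1.2, 1.6]`, whose norm is `|g| = 2`, and `w_rescaled` matches it even though `v` was multiplied by 10.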


Why Reparameterizing Helps Optimization

Consider gradient descent on the original weight $\mathbf{w}$. The learning dynamics couple magnitude and direction: a step that adjusts the direction of $\mathbf{w}$ also changes its magnitude, and vice versa.

With weight norm, the gradients w.r.t. $g$ and $\mathbf{v}$ are:

$$\nabla_g \mathcal{L} = \frac{\nabla_{\mathbf{w}} \mathcal{L} \cdot \mathbf{v}}{\|\mathbf{v}\|}$$

$$\nabla_{\mathbf{v}} \mathcal{L} = \frac{g}{\|\mathbf{v}\|}\left(\nabla_{\mathbf{w}} \mathcal{L} - \nabla_g \mathcal{L} \cdot \hat{\mathbf{v}}\right), \qquad \hat{\mathbf{v}} = \frac{\mathbf{v}}{\|\mathbf{v}\|}$$

The gradient w.r.t. $\mathbf{v}$ is automatically orthogonalized: because $\nabla_g \mathcal{L} = \nabla_{\mathbf{w}}\mathcal{L} \cdot \hat{\mathbf{v}}$, the component of $\nabla_{\mathbf{w}}\mathcal{L}$ in the direction of $\hat{\mathbf{v}}$ is subtracted out. Updates to $\mathbf{v}$ only change the direction of $\mathbf{w}$, not its magnitude, which is set by $g$ alone.

This decoupling makes the optimization problem better conditioned: updates to magnitude and direction do not interfere with each other.
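The decoupling can be checked numerically. The sketch below is a plain-Python illustration, not the authors' code: the quadratic toy loss and its `TARGET` vector are hypothetical. It writes the direction gradient in projection form, $(g/\|\mathbf{v}\|)\,(\nabla_{\mathbf{w}}\mathcal{L} - (\nabla_{\mathbf{w}}\mathcal{L}\cdot\hat{\mathbf{v}})\,\hat{\mathbf{v}})$, and compares both gradients against central finite differences.

```python
import math

TARGET = [1.0, -2.0, 0.5]  # hypothetical target for the toy loss

def norm(a):
    return math.sqrt(sum(x * x for x in a))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def weight(g, v):
    """Effective weight w = g * v / ||v||."""
    n = norm(v)
    return [g * x / n for x in v]

def loss(w):
    """Toy loss 0.5 * ||w - TARGET||^2, for illustration only."""
    return 0.5 * sum((wi - ti) ** 2 for wi, ti in zip(w, TARGET))

def grad_w(w):
    """Gradient of the toy loss with respect to w."""
    return [wi - ti for wi, ti in zip(w, TARGET)]

def weight_norm_grads(g, v):
    """Gradients of the loss with respect to g and v."""
    n = norm(v)
    v_hat = [x / n for x in v]
    gw = grad_w(weight(g, v))
    grad_g = dot(gw, v_hat)  # component of grad_w along the direction
    # Remove that component, then rescale by g / ||v||
    grad_v = [(g / n) * (gi - grad_g * vi) for gi, vi in zip(gw, v_hat)]
    return grad_g, grad_v

g, v = 1.5, [0.3, -0.7, 2.0]
grad_g, grad_v = weight_norm_grads(g, v)

# The direction gradient has no component along v
along_v = dot(grad_v, v)

# Central finite differences as an independent check
eps = 1e-6
fd_grad_g = (loss(weight(g + eps, v)) - loss(weight(g - eps, v))) / (2 * eps)
fd_grad_v = []
for i in range(len(v)):
    plus = list(v)
    minus = list(v)
    plus[i] += eps
    minus[i] -= eps
    fd_grad_v.append((loss(weight(g, plus)) - loss(weight(g, minus))) / (2 * eps))
```

`along_v` is zero up to rounding error, and the analytic gradients agree with the finite-difference estimates, which is what the orthogonality argument predicts.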


Weight Norm + Mean-Only BatchNorm

Salimans & Kingma also proposed pairing weight normalization with mean-only batch normalization: subtract the batch mean but do not divide by the batch standard deviation. This is lighter-weight than full BN and avoids the noise of estimating the variance, which is problematic for small batches.

The combination was shown to match full BN in some settings while being more suitable for online / generative modeling tasks.
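As a rough sketch of the mean-only variant, assuming pre-activations arrive as a list of rows (one per example), the function below centers each feature over the batch and adds a learned bias, with no variance scaling. The name `mean_only_batch_norm` and the sample numbers are hypothetical.

```python
def mean_only_batch_norm(batch, bias):
    """Subtract the per-feature batch mean and add a bias.

    batch: list of rows, one per example, each a list of pre-activations.
    bias:  list with one learned offset per feature.
    Unlike full batch norm, nothing is divided by a standard deviation.
    """
    n = len(batch)
    num_features = len(batch[0])
    means = [sum(row[j] for row in batch) / n for j in range(num_features)]
    return [
        [row[j] - means[j] + bias[j] for j in range(num_features)]
        for row in batch
    ]

pre_acts = [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]]
out = mean_only_batch_norm(pre_acts, bias=[0.0, 0.0])
```

Both columns of `out` are centered at zero, but the second column keeps its tenfold larger spread. In this scheme the scale is left to the weight-normalized magnitude $g$ rather than to batch statistics.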


Weight Norm vs. Batch Norm

| Property | Weight Norm | Batch Norm |
| --- | --- | --- |
| Normalizes | Weight vectors | Activations (per-feature, across batch) |
| Batch-dependent | No | Yes |
| Train/eval difference | No | Yes (running stats) |
| Overhead | Reparameterization only | Batch statistics per layer |
| Adapts to data scale | No | Yes (batch statistics track data) |
| Works with batch size 1 | Yes | No (needs LayerNorm instead) |
| Common use | Generative models, RL | CNNs, supervised vision |

Spectral Normalization

Spectral normalization (Miyato et al., 2018) constrains the weight matrix at the level of its largest singular value.

The Spectral Norm

For a matrix $W \in \mathbb{R}^{m \times n}$, the spectral norm is its largest singular value:

$$\sigma_1(W) = \max_{\mathbf{x} \neq 0} \frac{\|W\mathbf{x}\|_2}{\|\mathbf{x}\|_2}$$

This equals the induced $\ell_2$ operator norm: the maximum factor by which the linear map $\mathbf{x} \mapsto W\mathbf{x}$ can amplify any input vector. It bounds the layer's Lipschitz constant:

$$\|W\mathbf{x} - W\mathbf{y}\|_2 \leq \sigma_1(W)\,\|\mathbf{x} - \mathbf{y}\|_2$$

The Normalization

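Spectral normalization divides $W$ by its spectral norm each time the layer is applied:

$$\hat{W} = W / \sigma_1(W)$$

The unnormalized $W$ remains the trainable parameter; the forward pass uses $\hat{W}$. This ensures $\sigma_1(\hat{W}) = 1$: the normalized layer is 1-Lipschitz. If every layer is 1-Lipschitz and the activations are too (ReLU and leaky ReLU are), the entire network has Lipschitz constant at most 1, because the constant of a composition is bounded by the product of the per-layer constants.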
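A small worked example can make the definition and the normalization concrete. The sketch below is restricted to 2×2 matrices, where the largest singular value has a closed form as the square root of the larger eigenvalue of $W^\top W$. The matrix and the helper names are hypothetical choices for illustration.

```python
import math
import random

def norm(x):
    return math.sqrt(sum(xi * xi for xi in x))

def matvec(W, x):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in W]

def spectral_norm_2x2(W):
    """Largest singular value of a 2x2 matrix via the eigenvalues of W^T W."""
    (a, b), (c, d) = W
    # Entries of the symmetric matrix W^T W = [[p, q], [q, r]]
    p = a * a + c * c
    q = a * b + c * d
    r = b * b + d * d
    # Larger eigenvalue of a symmetric 2x2 matrix
    lam_max = 0.5 * (p + r) + math.sqrt(0.25 * (p - r) ** 2 + q * q)
    return math.sqrt(lam_max)

W = [[3.0, 0.0], [4.0, 5.0]]
sigma = spectral_norm_2x2(W)                      # sqrt(45), about 6.708
W_hat = [[w / sigma for w in row] for row in W]   # spectrally normalized copy

# Empirically probe the largest gain ||W_hat x|| / ||x|| over random inputs
random.seed(0)
worst_gain = 0.0
for _ in range(1000):
    x = [random.gauss(0.0, 1.0), random.gauss(0.0, 1.0)]
    worst_gain = max(worst_gain, norm(matvec(W_hat, x)) / norm(x))
```

For this matrix $W^\top W$ has eigenvalues 45 and 5, so $\sigma_1 = \sqrt{45}$. After dividing by it, no sampled input is amplified by more than a factor of 1.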

Efficient Estimation via Power Iteration

Computing $\sigma_1(W)$ exactly via SVD is $O(\min(m,n) \cdot mn)$, which is too expensive to do at every step. Instead, power iteration provides a fast estimate:

Initialize: random unit vectors $\tilde{\mathbf{u}} \in \mathbb{R}^m$, $\tilde{\mathbf{v}} \in \mathbb{R}^n$

At each training step (1 iteration is typically enough):

$$\tilde{\mathbf{v}} \leftarrow \frac{W^\top \tilde{\mathbf{u}}}{\|W^\top \tilde{\mathbf{u}}\|} \qquad \tilde{\mathbf{u}} \leftarrow \frac{W \tilde{\mathbf{v}}}{\|W \tilde{\mathbf{v}}\|}$$

$$\sigma_1(W) \approx \tilde{\mathbf{u}}^\top W \tilde{\mathbf{v}}$$

The vectors $\tilde{\mathbf{u}}$ and $\tilde{\mathbf{v}}$ converge to the left and right singular vectors corresponding to $\sigma_1$. They are stored between steps so each update starts from the previous estimate; one iteration per step is sufficient because $W$ changes slowly.
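The update rule above can be sketched in plain Python. This is an illustrative version under simplifying assumptions (a fixed list-of-rows matrix, no training loop), and the function names are hypothetical. The vector `u` is carried from one call to the next, mirroring how the estimate is warm-started between training steps.

```python
import math
import random

def matvec(W, x):
    """Compute W x for a list-of-rows matrix."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in W]

def matvec_t(W, y):
    """Compute W^T y without forming the transpose."""
    n_rows, n_cols = len(W), len(W[0])
    return [sum(W[i][j] * y[i] for i in range(n_rows)) for j in range(n_cols)]

def normalize(x):
    n = math.sqrt(sum(xi * xi for xi in x))
    return [xi / n for xi in x]

def power_iteration_step(W, u):
    """One update: v <- W^T u / ||W^T u||, u <- W v / ||W v||, sigma ~ u^T W v."""
    v = normalize(matvec_t(W, u))
    Wv = matvec(W, v)
    u = normalize(Wv)
    sigma = sum(ui * wvi for ui, wvi in zip(u, Wv))
    return sigma, u, v

random.seed(0)
W = [[3.0, 0.0], [4.0, 5.0]]   # largest singular value is sqrt(45)
u = normalize([random.gauss(0.0, 1.0) for _ in range(len(W))])

estimates = []
for _ in range(20):
    # Reuse u from the previous step instead of restarting from scratch
    sigma, u, v = power_iteration_step(W, u)
    estimates.append(sigma)
```

On this fixed matrix the estimate approaches $\sqrt{45} \approx 6.708$ within a few iterations. During training, the same warm start is what lets a single iteration per step track a slowly changing $W$.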

Why GANs Need Lipschitz Constraints

The Wasserstein GAN (Arjovsky et al., 2017) trains its discriminator (the critic) to maximize:

$$\mathcal{L}_{\text{WGAN}} = \mathbb{E}_{x \sim p_{\text{data}}}[D(x)] - \mathbb{E}_{x \sim p_g}[D(x)]$$

For the maximum of this objective to equal the true Wasserstein-1 distance, the discriminator $D$ must be restricted to 1-Lipschitz functions (Kantorovich-Rubinstein duality). WGAN-GP enforces this with a gradient penalty; spectral normalization enforces it directly by rescaling each layer's weight matrix, and does so without adding any loss term or penalty coefficient to tune.
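A one-dimensional toy case, with hypothetical sample values, hints at why the constraint matters. The fake samples below are the real samples shifted by 2, so the Wasserstein-1 distance between the two empirical distributions is 2. A critic with slope 1 recovers that value, while a steeper critic inflates the objective without the distributions having changed.

```python
def critic_objective(critic, real, fake):
    """Estimate E[D(x_real)] - E[D(x_fake)] from two lists of samples."""
    mean_real = sum(critic(x) for x in real) / len(real)
    mean_fake = sum(critic(x) for x in fake) / len(fake)
    return mean_real - mean_fake

real = [2.0, 3.0, 4.0]
fake = [0.0, 1.0, 2.0]   # real samples shifted by 2, so W1 = 2

# D(x) = x is 1-Lipschitz; D(x) = 10x has Lipschitz constant 10
value_lipschitz_1 = critic_objective(lambda x: x, real, fake)
value_lipschitz_10 = critic_objective(lambda x: 10.0 * x, real, fake)
```

`value_lipschitz_1` is 2.0 and `value_lipschitz_10` is 20.0. Without a bound on the critic's slope the objective can be made arbitrarily large, so it stops measuring a distance.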

Spectral normalization markedly stabilizes GAN training and is used in large-scale image GANs such as SAGAN and BigGAN, which apply it to both the generator and the discriminator.


PyTorch and TensorFlow

PyTorch (torch.nn.utils.weight_norm and spectral_norm):

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, spectral_norm, remove_weight_norm

# Weight normalization: reparameterize w = g * v / ||v||
linear = nn.Linear(64, 32)
wn = weight_norm(linear, name='weight', dim=0)
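# wn now exposes 'weight_g' (magnitude) and 'weight_v' (direction);
# they are registered after the existing bias, so bias is listed first.
# Newer PyTorch releases prefer torch.nn.utils.parametrizations.weight_norm.
print([n for n, _ in wn.named_parameters()])  # ['bias', 'weight_g', 'weight_v']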

x   = torch.randn(8, 64)
out = wn(x)                     # reconstructs weight = g * v/||v|| each forward pass

remove_weight_norm(wn)          # fuse back into a plain weight before saving/inference

# Spectral normalization: W_hat = W / sigma_1(W)
conv   = nn.Conv2d(64, 128, 3, padding=1)
sn_conv = spectral_norm(conv, n_power_iterations=1)
# sn_conv.weight_u / weight_v store the power-iteration vectors between steps

x   = torch.randn(4, 64, 8, 8)
out = sn_conv(x)                # sigma_1 estimated and applied on every forward pass

# Typical spectrally-normalized GAN discriminator block
class SNResBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = spectral_norm(nn.Conv2d(in_ch, out_ch, 3, padding=1))
        self.conv2 = spectral_norm(nn.Conv2d(out_ch, out_ch, 3, padding=1))
        self.skip  = spectral_norm(nn.Conv2d(in_ch, out_ch, 1))

    def forward(self, x):
        h = torch.relu(self.conv1(x))
        h = self.conv2(h)
        return torch.relu(h + self.skip(x))

TensorFlow / Keras:

import tensorflow as tf

# SpectralNormalization wrapper (built into recent tf.keras releases;
# older versions relied on tensorflow_addons instead)
dense   = tf.keras.layers.Dense(64)
sn_dense = tf.keras.layers.SpectralNormalization(dense)

x   = tf.random.normal((8, 128))
out = sn_dense(x)   # weight divided by estimated largest singular value

# Wrap Conv2D the same way
conv    = tf.keras.layers.Conv2D(128, 3, padding='same')
sn_conv = tf.keras.layers.SpectralNormalization(conv)

# Manual weight normalization
class WeightNormDense(tf.keras.layers.Layer):
    def __init__(self, units: int, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        fan_in = input_shape[-1]
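        # Pass name and shape as keywords: the positional order of add_weight differs across Keras versions
        self.v = self.add_weight(name='v', shape=(fan_in, self.units))
        self.g = self.add_weight(name='g', shape=(self.units,), initializer='ones')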

    def call(self, x):
        v_norm = tf.linalg.norm(self.v, axis=0, keepdims=True)  # (1, units)
        w = self.g * self.v / v_norm                             # (fan_in, units)
        return x @ w