Weight Normalization and Spectral Normalization
- Describe the weight normalization reparameterization w = g·v/‖v‖ and explain how decoupling magnitude from direction simplifies the optimization geometry
- Compare weight normalization to batch normalization: what weight norm avoids (batch statistics, train/eval mismatch), what it loses (automatic scale adaptation), and when each is preferable
- Explain the spectral norm of a matrix — its relationship to the largest singular value and the Lipschitz constant — and describe how power iteration estimates it efficiently
- Identify where spectral normalization is applied in GAN discriminators and explain why constraining the Lipschitz constant is required by the Wasserstein distance objective
Weight Normalization
Weight normalization (Salimans & Kingma, 2016) takes a different approach from all the normalizations seen so far. Instead of normalizing activations at runtime, it reparameterizes the weight vectors themselves.
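For each output neuron $i$, replace the weight vector $\mathbf{w}_i$ with a magnitude scalar $g_i$ and a direction vector $\mathbf{v}_i$:
$$\mathbf{w}_i = g_i \, \frac{\mathbf{v}_i}{\|\mathbf{v}_i\|}$$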
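The network now has two parameters per weight vector: $g_i$ (scalar magnitude) and $\mathbf{v}_i$ (direction, unnormalized). The weight actually used in the forward pass is always the unit-normalized direction $\mathbf{v}_i/\|\mathbf{v}_i\|$ scaled by $g_i$, so $\|\mathbf{w}_i\| = g_i$ no matter how large or small $\mathbf{v}_i$ becomes.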
Crucially: weight normalization adds no computation to the forward pass beyond the reparameterization itself. There is no batch statistic computation, no running average, no train/eval mode difference.
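A minimal NumPy sketch of the forward pass, assuming a dense layer whose `v` is stored one row per output neuron (the same layout as PyTorch's `nn.Linear` with `dim=0`). The function name `weight_norm_forward` is hypothetical. It shows that each effective weight row has norm $g_i$ and that rescaling `v` leaves the output unchanged.

```python
import numpy as np

def weight_norm_forward(x, g, v, b):
    """Dense layer with weight normalization (illustrative sketch).

    x: (batch, fan_in) inputs
    g: (fan_out,) per-neuron magnitudes
    v: (fan_out, fan_in) unnormalized directions, one row per output neuron
    b: (fan_out,) bias
    """
    # One L2 norm per output neuron (row-wise).
    v_norm = np.linalg.norm(v, axis=1, keepdims=True)  # (fan_out, 1)
    w = g[:, None] * v / v_norm                         # (fan_out, fan_in)
    return x @ w.T + b, w

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
g = np.array([1.0, 2.0])
v = rng.normal(size=(2, 3))
b = np.zeros(2)

out, w = weight_norm_forward(x, g, v, b)
print(np.linalg.norm(w, axis=1))  # row norms equal g (up to rounding)

# Scaling v by any positive constant does not change the effective weight.
out_scaled, _ = weight_norm_forward(x, g, 10.0 * v, b)
print(np.allclose(out, out_scaled))  # True
```

No batch statistics appear anywhere: the cost is one norm per output neuron, independent of batch size.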
Why Reparameterizing Helps Optimization
Consider gradient descent on the original weight $\mathbf{w}$. The learning dynamics couple magnitude and direction: a step that adjusts the direction of $\mathbf{w}$ also changes its magnitude, and vice versa.
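With weight norm (dropping the neuron index $i$), the gradients w.r.t. $g$ and $\mathbf{v}$ are:
$$\nabla_g L = \frac{\nabla_{\mathbf{w}} L \cdot \mathbf{v}}{\|\mathbf{v}\|}, \qquad \nabla_{\mathbf{v}} L = \frac{g}{\|\mathbf{v}\|}\,\nabla_{\mathbf{w}} L - \frac{g\,\nabla_g L}{\|\mathbf{v}\|^2}\,\mathbf{v} = \frac{g}{\|\mathbf{v}\|}\,M_{\mathbf{w}}\,\nabla_{\mathbf{w}} L, \qquad M_{\mathbf{w}} = I - \frac{\mathbf{w}\mathbf{w}^\top}{\|\mathbf{w}\|^2}$$
The gradient w.r.t. $\mathbf{v}$ is automatically orthogonalized: the projection $M_{\mathbf{w}}$ subtracts out the component of $\nabla_{\mathbf{w}} L$ in the direction of $\mathbf{w}$, so $\nabla_{\mathbf{v}} L \perp \mathbf{v}$. Updates to $\mathbf{v}$ therefore change only the direction of $\mathbf{w}$, not its magnitude, which is set by $g$ alone.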
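This decoupling improves the conditioning of the optimization: magnitude and direction are updated through separate parameters, so a step on one does not distort the other. Because the gradient w.r.t. $\mathbf{v}$ is orthogonal to $\mathbf{v}$, a plain gradient step can only increase $\|\mathbf{v}\|$ (by Pythagoras, $\|\mathbf{v} + \Delta\mathbf{v}\|^2 = \|\mathbf{v}\|^2 + \|\Delta\mathbf{v}\|^2$), and a larger $\|\mathbf{v}\|$ shrinks the effective step size on the direction of $\mathbf{w}$. Salimans & Kingma describe this as a self-stabilizing effect on the learning rate.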
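The chain rule for the weight-norm gradients can be checked numerically. This sketch assumes a toy quadratic loss $L(\mathbf{w}) = \tfrac12\|\mathbf{w} - \mathbf{t}\|^2$ on a single neuron; `weight_norm_grads` and the target `t` are hypothetical names for illustration. It verifies that the gradient w.r.t. $\mathbf{v}$ is orthogonal to $\mathbf{v}$ and that both gradients match central finite differences.

```python
import numpy as np

def loss_w(w, target):
    # Toy quadratic loss on the effective weight vector.
    return 0.5 * np.sum((w - target) ** 2)

def weight_norm_grads(g, v, target):
    """Gradients of loss_w w.r.t. (g, v) using the weight-norm chain rule."""
    v_norm = np.linalg.norm(v)
    w = g * v / v_norm
    grad_w = w - target                       # dL/dw for the toy loss
    grad_g = grad_w @ v / v_norm              # component along the direction
    grad_v = (g / v_norm) * grad_w - (g * grad_g / v_norm**2) * v
    return grad_g, grad_v

rng = np.random.default_rng(1)
g, v, target = 1.5, rng.normal(size=5), rng.normal(size=5)
grad_g, grad_v = weight_norm_grads(g, v, target)

# grad_v has no component along v: the radial part was projected out.
print(abs(grad_v @ v) < 1e-10)

# Central finite-difference check of both gradients.
eps = 1e-6

def loss_gv(g_, v_):
    return loss_w(g_ * v_ / np.linalg.norm(v_), target)

fd_g = (loss_gv(g + eps, v) - loss_gv(g - eps, v)) / (2 * eps)
fd_v = np.array([(loss_gv(g, v + eps * e) - loss_gv(g, v - eps * e)) / (2 * eps)
                 for e in np.eye(5)])
print(np.isclose(grad_g, fd_g), np.allclose(grad_v, fd_v, atol=1e-6))
```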
Weight Norm + Mean-Only BatchNorm
Salimans & Kingma also proposed pairing weight normalization with mean-only batch normalization: subtract the batch mean but do not divide by the batch variance. This is lighter-weight than full BN and avoids the variance estimation noise that is problematic for small batches.
The combination was shown to match full BN in some settings while being more suitable for online / generative modeling tasks.
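Mean-only batch normalization can be sketched in a few lines of NumPy. The class name `MeanOnlyBatchNorm` and the running-mean momentum of 0.1 are assumptions chosen for illustration, not the paper's code. Note that, unlike pure weight norm, this component does reintroduce a train/eval difference, because eval mode must use a running mean.

```python
import numpy as np

class MeanOnlyBatchNorm:
    """Mean-only batch normalization (illustrative sketch).

    Training: subtract the per-feature batch mean, then add a learned bias.
    Eval: subtract a running mean instead. No variance is estimated.
    """

    def __init__(self, num_features, momentum=0.1):
        self.bias = np.zeros(num_features)          # learned shift
        self.running_mean = np.zeros(num_features)  # used at eval time
        self.momentum = momentum                    # assumed EMA rate

    def __call__(self, x, training):
        if training:
            mean = x.mean(axis=0)
            # Exponential moving average of batch means for inference.
            self.running_mean = ((1 - self.momentum) * self.running_mean
                                 + self.momentum * mean)
        else:
            mean = self.running_mean
        return x - mean + self.bias

rng = np.random.default_rng(2)
bn = MeanOnlyBatchNorm(3)
x = rng.normal(loc=5.0, scale=2.0, size=(16, 3))
y = bn(x, training=True)
print(np.allclose(y.mean(axis=0), 0.0))           # batch mean removed
print(np.allclose(y.std(axis=0), x.std(axis=0)))  # scale left untouched
```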
Weight Norm vs. Batch Norm
| Property | Weight Norm | Batch Norm |
|---|---|---|
| Normalizes | Weight vectors | Activations (per-feature, across batch) |
| Batch-dependent | No | Yes |
| Train/eval difference | No | Yes (running stats) |
| Overhead | Reparameterization only | Batch statistics per layer |
| Adapts to data scale | No | Yes (batch statistics track data) |
| Works with batch size 1 | Yes | Poorly (per-feature batch statistics degenerate; LayerNorm or GroupNorm are the usual substitutes) |
| Common use | Generative models, RL | CNNs, supervised vision |
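The batch-size-1 row can be made concrete for a dense layer, where batch norm's statistics are taken over the batch axis only. This NumPy sketch assumes training-mode BN without affine parameters; `batchnorm_train` and `weight_norm_dense` are hypothetical helper names. With a single example, $x - \mu$ is exactly zero, while the weight-normalized layer gives the same output for an example whether it is alone or inside a larger batch.

```python
import numpy as np

def batchnorm_train(x, eps=1e-5):
    # Training-mode BN, no affine params: normalize each feature over the batch.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    return (x - mean) / np.sqrt(var + eps)

def weight_norm_dense(x, g, v):
    # Weight-normalized dense layer: w_i = g_i * v_i / ||v_i||, no batch statistics.
    w = g[:, None] * v / np.linalg.norm(v, axis=1, keepdims=True)
    return x @ w.T

rng = np.random.default_rng(3)
x1 = rng.normal(size=(1, 4))          # a batch containing a single example
g, v = np.ones(2), rng.normal(size=(2, 4))

print(batchnorm_train(x1))            # all zeros: the example equals its own mean
print(weight_norm_dense(x1, g, v))    # depends on x1 as usual

# Weight norm treats each example independently of the rest of the batch.
xb = np.vstack([x1, rng.normal(size=(7, 4))])
print(np.allclose(weight_norm_dense(xb, g, v)[0], weight_norm_dense(x1, g, v)[0]))
```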
Spectral Normalization
Spectral normalization (Miyato et al., 2018) rescales each weight matrix so that its largest singular value, the spectral norm, equals 1.
The Spectral Norm
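For a matrix $W \in \mathbb{R}^{m \times n}$, the spectral norm is its largest singular value:
$$\sigma(W) = \max_{\mathbf{h} \neq \mathbf{0}} \frac{\|W\mathbf{h}\|_2}{\|\mathbf{h}\|_2} = \sigma_1(W)$$
The middle expression is the induced operator norm: the maximum factor by which the linear map can amplify any input vector. It therefore bounds the Lipschitz constant of a layer $f(\mathbf{h}) = W\mathbf{h} + \mathbf{b}$:
$$\|f(\mathbf{h}_1) - f(\mathbf{h}_2)\|_2 \le \sigma(W)\,\|\mathbf{h}_1 - \mathbf{h}_2\|_2$$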
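A quick NumPy check of these facts on a random matrix (the shapes and seed are arbitrary choices for illustration): the largest singular value matches `np.linalg.norm(W, ord=2)`, no input is amplified by more than $\sigma_1$, the top right singular vector attains that gain, and the affine layer obeys the Lipschitz bound.

```python
import numpy as np

rng = np.random.default_rng(4)
W = rng.normal(size=(6, 4))

# Spectral norm = largest singular value (NumPy returns them in descending order).
sigma_1 = np.linalg.svd(W, compute_uv=False)[0]
# The matrix 2-norm computes the same quantity.
print(np.isclose(sigma_1, np.linalg.norm(W, ord=2)))

# The gain ||W h|| / ||h|| over many random inputs never exceeds sigma_1 ...
H = rng.normal(size=(4, 10_000))
gains = np.linalg.norm(W @ H, axis=0) / np.linalg.norm(H, axis=0)
print(gains.max() <= sigma_1 * (1 + 1e-9))

# ... and the top right singular vector attains it exactly.
_, _, Vt = np.linalg.svd(W)
v1 = Vt[0]
print(np.isclose(np.linalg.norm(W @ v1), sigma_1))

# Lipschitz bound for the affine layer f(h) = W h + b (the bias cancels).
b = rng.normal(size=6)
h1, h2 = rng.normal(size=4), rng.normal(size=4)
lhs = np.linalg.norm((W @ h1 + b) - (W @ h2 + b))
print(lhs <= sigma_1 * np.linalg.norm(h1 - h2) * (1 + 1e-9))
```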
The Normalization
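Spectral normalization divides $W$ by its spectral norm:
$$\bar{W}_{\text{SN}} = \frac{W}{\sigma(W)}$$
The optimizer updates the raw $W$; the normalized weight is recomputed from it each time the layer runs forward.
This ensures $\sigma(\bar{W}_{\text{SN}}) = 1$ (up to the accuracy of the $\sigma$ estimate): the normalized layer is 1-Lipschitz. If every layer is 1-Lipschitz and the activations are too (ReLU and leaky ReLU are), the entire network has Lipschitz constant at most 1, because the Lipschitz constant of a composition is bounded by the product of the per-layer constants.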
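The composition argument can be checked empirically. This sketch normalizes two random matrices exactly with an SVD-based norm (affordable for a small demo; `spectral_normalize` and `net` are hypothetical names), stacks them with a ReLU, and measures $\|f(\mathbf{x}) - f(\mathbf{y})\| / \|\mathbf{x} - \mathbf{y}\|$ over random input pairs.

```python
import numpy as np

def spectral_normalize(W):
    # Exact version via the matrix 2-norm (computed with an SVD), for illustration.
    return W / np.linalg.norm(W, ord=2)

def relu(z):
    # 1-Lipschitz activation.
    return np.maximum(z, 0.0)

rng = np.random.default_rng(5)
W1 = spectral_normalize(rng.normal(size=(32, 16)))
W2 = spectral_normalize(rng.normal(size=(8, 32)))

def net(x):
    # Two spectrally normalized layers with a ReLU in between; x has shape (16, N).
    return W2 @ relu(W1 @ x)

print(np.isclose(np.linalg.norm(W1, ord=2), 1.0),
      np.isclose(np.linalg.norm(W2, ord=2), 1.0))

# Empirical Lipschitz ratio over 1000 random input pairs stays at or below 1.
X, Y = rng.normal(size=(16, 1000)), rng.normal(size=(16, 1000))
ratios = np.linalg.norm(net(X) - net(Y), axis=0) / np.linalg.norm(X - Y, axis=0)
print(ratios.max() <= 1.0 + 1e-9)
```

The bound is usually loose in practice: the measured ratios tend to sit well below 1, since the product bound assumes every layer hits its worst-case gain on the same input difference.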
Efficient Estimation via Power Iteration
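Computing $\sigma(W)$ exactly via a full SVD costs $O(mn\min(m,n))$, too expensive to do at every step. Instead, power iteration provides a fast estimate using only two matrix-vector products, $O(mn)$, per iteration:
Initialize: a random unit vector $\tilde{\mathbf{u}} \in \mathbb{R}^m$.
At each training step (one iteration is typically enough):
$$\tilde{\mathbf{v}} \leftarrow \frac{W^\top \tilde{\mathbf{u}}}{\|W^\top \tilde{\mathbf{u}}\|_2}, \qquad \tilde{\mathbf{u}} \leftarrow \frac{W \tilde{\mathbf{v}}}{\|W \tilde{\mathbf{v}}\|_2}, \qquad \sigma(W) \approx \tilde{\mathbf{u}}^\top W \tilde{\mathbf{v}}$$
The vectors $\tilde{\mathbf{u}}$ and $\tilde{\mathbf{v}}$ converge to the left and right singular vectors corresponding to $\sigma_1(W)$. They are stored between steps so that each update starts from the previous estimate; one iteration per step is sufficient because $W$ changes only slightly between gradient steps.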
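A NumPy sketch of this scheme, assuming a simulated training loop in which a small random perturbation stands in for each gradient update. The helper `power_iteration_step` is hypothetical, not a library function. The key detail is that `u` is carried over between steps, so a single iteration per step keeps tracking $\sigma_1$ as $W$ drifts.

```python
import numpy as np

def power_iteration_step(W, u):
    """One power-iteration step for the top singular triplet of W (m x n).

    u: current estimate of the top left singular vector (length m, unit norm).
    Returns the updated (u, v, sigma) with sigma ~= u^T W v.
    """
    v = W.T @ u
    v /= np.linalg.norm(v)
    u = W @ v
    u /= np.linalg.norm(u)
    sigma = u @ W @ v
    return u, v, sigma

rng = np.random.default_rng(6)
W = rng.normal(size=(64, 32))
u = rng.normal(size=64)
u /= np.linalg.norm(u)

# Simulated training: W drifts slightly each step; u persists between steps.
for step in range(300):
    W += 1e-4 * rng.normal(size=W.shape)     # stand-in for a small gradient update
    u, v, sigma = power_iteration_step(W, u)  # one iteration per step

exact = np.linalg.norm(W, ord=2)
print(sigma, exact, abs(sigma - exact) / exact)
W_sn = W / sigma  # spectrally normalized weight
```

The estimate $\tilde{\mathbf{u}}^\top W \tilde{\mathbf{v}}$ can never exceed the true $\sigma_1$ (for unit vectors it is at most the operator norm), so any residual error slightly overestimates the normalized layer's gain rather than underestimating it.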
Why GANs Need Lipschitz Constraints
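The Wasserstein GAN (Arjovsky et al., 2017) defines the discriminator's (critic's) objective as:
$$\max_{\|D\|_{L} \le 1} \; \mathbb{E}_{\mathbf{x} \sim P_r}[D(\mathbf{x})] - \mathbb{E}_{\tilde{\mathbf{x}} \sim P_g}[D(\tilde{\mathbf{x}})]$$
where $P_r$ is the data distribution, $P_g$ is the generator's distribution, and $\|D\|_L$ is the Lipschitz constant of $D$.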
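For this objective to equal the true Wasserstein-1 distance, the discriminator must be restricted to 1-Lipschitz functions (Kantorovich-Rubinstein duality). Without the constraint the maximum is unbounded whenever $P_r \neq P_g$, since scaling up any $D$ that separates the two distributions scales the objective with it. The original WGAN enforced the constraint crudely by clipping weights; WGAN-GP enforces it softly with a gradient penalty whose weight must be tuned. Spectral normalization instead bounds the Lipschitz constant architecturally, by rescaling each layer's weights, without adding a loss term or penalty coefficient.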
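A toy illustration of why the constraint matters, under deliberately simple assumptions: the "fake" samples are the real samples translated by a fixed vector, so $W_1(P_r, P_g)$ equals the length of that shift, and the critic is linear, $D(\mathbf{x}) = \mathbf{w}^\top\mathbf{x}$, whose Lipschitz constant (and spectral norm, viewed as a $1 \times d$ matrix) is $\|\mathbf{w}\|$. The names `critic_objective` and `w_star` are hypothetical. This is not how a real critic is trained; it only shows the objective recovering $W_1$ under the constraint and blowing up without it.

```python
import numpy as np

rng = np.random.default_rng(7)
real = rng.normal(size=(1000, 2))    # samples from P_r
shift = np.array([3.0, 4.0])
fake = real + shift                  # P_g = P_r translated, so W1 = ||shift|| = 5

def critic_objective(w, real, fake):
    # E_{P_r}[D(x)] - E_{P_g}[D(x)] for a linear critic D(x) = w . x
    return (real @ w).mean() - (fake @ w).mean()

# Best 1-Lipschitz linear critic: unit vector pointing from the fake toward the real samples.
w_star = -shift / np.linalg.norm(shift)
print(critic_objective(w_star, real, fake))  # close to 5.0, i.e. W1

# Without the constraint, scaling the critic inflates the "distance" without bound.
for k in (1, 10, 100):
    print(k, critic_objective(k * w_star, real, fake))

# Spectral normalization of the 1 x d weight matrix divides by ||w||, restoring w_star.
w_big = 100 * w_star
w_sn = w_big / np.linalg.norm(w_big[None, :], ord=2)
print(critic_objective(w_sn, real, fake))  # close to 5.0 again
```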
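Spectral normalization substantially stabilizes GAN training and is widely used in image GANs; SAGAN and BigGAN, for example, apply it to both the generator and the discriminator. In practice these models usually pair it with a hinge loss rather than the exact WGAN objective, relying on the Lipschitz bound to keep the discriminator's gradients well-behaved.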
PyTorch and TensorFlow
PyTorch — torch.nn.utils.weight_norm and spectral_norm:
```python
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, spectral_norm, remove_weight_norm

# Weight normalization: reparameterize w = g * v / ||v||
# (newer PyTorch releases deprecate this helper in favor of
#  torch.nn.utils.parametrizations.weight_norm)
linear = nn.Linear(64, 32)
wn = weight_norm(linear, name='weight', dim=0)  # dim=0: one g per output neuron
# wn now exposes 'weight_g' (magnitude, shape (32, 1)) and 'weight_v' (direction)
print([n for n, _ in wn.named_parameters()])  # ['bias', 'weight_g', 'weight_v']

x = torch.randn(8, 64)
out = wn(x)  # reconstructs weight = g * v / ||v|| before each forward pass
remove_weight_norm(wn)  # fuse back into a plain weight before saving/inference

# Spectral normalization: W_hat = W / sigma_1(W)
conv = nn.Conv2d(64, 128, 3, padding=1)
sn_conv = spectral_norm(conv, n_power_iterations=1)
# sn_conv.weight_orig is the raw trainable weight; the buffers
# sn_conv.weight_u / weight_v store the power-iteration vectors between steps
x = torch.randn(4, 64, 8, 8)
out = sn_conv(x)  # in train mode, one power-iteration step runs per forward pass

# A typical spectrally-normalized GAN discriminator block
class SNResBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = spectral_norm(nn.Conv2d(in_ch, out_ch, 3, padding=1))
        self.conv2 = spectral_norm(nn.Conv2d(out_ch, out_ch, 3, padding=1))
        self.skip = spectral_norm(nn.Conv2d(in_ch, out_ch, 1))  # 1x1 conv matches channels

    def forward(self, x):
        h = torch.relu(self.conv1(x))
        h = self.conv2(h)
        return torch.relu(h + self.skip(x))
```
TensorFlow / Keras:
```python
import tensorflow as tf

# SpectralNormalization wrapper: built into Keras from TF 2.13
# (earlier versions: tensorflow_addons.layers.SpectralNormalization)
dense = tf.keras.layers.Dense(64)
sn_dense = tf.keras.layers.SpectralNormalization(dense, power_iterations=1)
x = tf.random.normal((8, 128))
# The wrapper only runs power iteration and rescales the kernel when training=True
out = sn_dense(x, training=True)

# Wrap Conv2D the same way
conv = tf.keras.layers.Conv2D(128, 3, padding='same')
sn_conv = tf.keras.layers.SpectralNormalization(conv)

# Manual weight normalization
class WeightNormDense(tf.keras.layers.Layer):
    def __init__(self, units: int, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        fan_in = input_shape[-1]
        # v: unnormalized directions, one column per output unit
        self.v = self.add_weight(name='v', shape=(fan_in, self.units),
                                 initializer='glorot_uniform')
        # g: per-unit magnitudes, initialized so that each ||w_j|| = 1
        self.g = self.add_weight(name='g', shape=(self.units,), initializer='ones')

    def call(self, x):
        v_norm = tf.linalg.norm(self.v, axis=0, keepdims=True)  # (1, units)
        w = self.g * self.v / v_norm  # (fan_in, units); column j has norm g_j
        return x @ w
```