Supplement · Normalization in Deep Learning

Layer Norm, RMSNorm, and the Pre-Norm / Post-Norm Debate

15 min read
By the end of this reading you will be able to:
  • Trace the LayerNorm forward pass — identifying what dimension is normalized, why no running statistics are needed, and why it works at batch size 1
  • Distinguish pre-norm (norm before sub-layer) from post-norm (norm after residual add) placement in transformer blocks, and explain why pre-norm training is more stable for deep transformers
  • Explain RMSNorm's simplification over LayerNorm — dropping mean centering — and state the speed advantage and which modern LLMs use it
  • Explain DeepNorm's scaled residual connection, y = LayerNorm(α·x + SubLayer(x)), and state why up-weighting the residual allows stable training of very deep (up to 1000-layer) transformers

Layer Normalization

Layer normalization (Ba et al., 2016) addresses the two main limitations of BatchNorm: small-batch instability and the inability to normalize recurrent networks cleanly.

The key change: normalize over the feature dimension of each example independently, rather than across the batch dimension.

For a vector $\mathbf{x} \in \mathbb{R}^d$ (the activations of one example at one layer):

$$\mu = \frac{1}{d}\sum_{j=1}^d x_j \qquad \sigma^2 = \frac{1}{d}\sum_{j=1}^d (x_j - \mu)^2$$

$$\hat{x}_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \epsilon}} \qquad y_j = \gamma_j\hat{x}_j + \beta_j$$
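The two equations above can be traced step by step in a minimal plain-Python sketch, with lists standing in for tensors. The helper name `layer_norm` is hypothetical and ε = 1e-5 is an illustrative choice; the variance divides by d, matching the definition of σ² above.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over one example's feature vector x (a list of d floats)."""
    d = len(x)
    mu = sum(x) / d                                   # mean over the d features
    var = sum((xj - mu) ** 2 for xj in x) / d         # biased variance (divide by d)
    inv_std = 1.0 / math.sqrt(var + eps)              # eps keeps the division safe
    x_hat = [(xj - mu) * inv_std for xj in x]         # normalized activations
    return [g * xh + b for g, xh, b in zip(gamma, x_hat, beta)]  # per-feature scale and shift

x = [2.0, 4.0, 6.0, 8.0]
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 3) for v in y])   # [-1.342, -0.447, 0.447, 1.342]
```

With γ = 1 and β = 0 the output has mean 0 and variance just under 1 (the ε term shrinks it very slightly); learned γ and β then let the layer restore any scale and offset it needs.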

What is the same as BatchNorm: a learned scale $\gamma$ and shift $\beta$ per feature, and $\epsilon$ for numerical stability.

What is different:

  • $\mu$ and $\sigma^2$ are computed per example, not per batch: they depend on nothing outside the current input
  • Training and inference are identical: there are no running statistics, so the layer behaves the same under model.train() and model.eval()
  • Works with batch size 1
  • Works inside RNNs: applied at each timestep independently
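To see the difference in which dimension is normalized, the sketch below (plain Python, γ = 1 and β = 0, hypothetical helper names) computes LayerNorm statistics per row and training-mode BatchNorm statistics per column of a small batch, then checks whether the first example's output depends on the rest of the batch.

```python
import math

def layer_norm_rows(batch, eps=1e-5):
    """LayerNorm: statistics per row, i.e. per example over its features."""
    out = []
    for x in batch:
        mu = sum(x) / len(x)
        var = sum((v - mu) ** 2 for v in x) / len(x)
        out.append([(v - mu) / math.sqrt(var + eps) for v in x])
    return out

def batch_norm_train_cols(batch, eps=1e-5):
    """Training-mode BatchNorm: statistics per column, i.e. per feature across the batch."""
    n, d = len(batch), len(batch[0])
    out = [[0.0] * d for _ in range(n)]
    for j in range(d):
        col = [row[j] for row in batch]
        mu = sum(col) / n
        var = sum((v - mu) ** 2 for v in col) / n
        for i in range(n):
            out[i][j] = (batch[i][j] - mu) / math.sqrt(var + eps)
    return out

a = [1.0, 2.0, 3.0, 4.0]
batch1 = [a, [0.0, 0.0, 1.0, 1.0]]
batch2 = [a, [10.0, -5.0, 7.0, 3.0]]

# LayerNorm output for `a` is the same in either batch, and with batch size 1.
print(layer_norm_rows(batch1)[0] == layer_norm_rows(batch2)[0] == layer_norm_rows([a])[0])  # True
# BatchNorm output for `a` changes with its batch-mates.
print(batch_norm_train_cols(batch1)[0] == batch_norm_train_cols(batch2)[0])  # False
# With a batch of one, every per-feature variance is 0 and the output collapses.
print(batch_norm_train_cols([a]))  # [[0.0, 0.0, 0.0, 0.0]]
```

The batch-of-one case shows why BatchNorm struggles at batch size 1, while LayerNorm still has d features to compute its statistics over.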

Where LayerNorm Is Applied in Transformers

Every transformer block contains two sub-layers (self-attention and FFN), each wrapped with a residual connection and LayerNorm. The placement of LayerNorm relative to these components has a significant impact on training stability.

Post-Norm (Original Transformer, Vaswani et al. 2017)

$$\mathbf{y} = \text{LayerNorm}(\mathbf{x} + \text{SubLayer}(\mathbf{x}))$$

The residual is added first, then normalized. At initialization, the residual branch is near zero (small weights), so the sum $\mathbf{x} + \text{SubLayer}(\mathbf{x}) \approx \mathbf{x}$ and LayerNorm essentially normalizes the pass-through signal.

Problem: as the network depth increases, gradients must flow through many LayerNorm operations before reaching early layers. This can cause instability at the start of training — the network needs learning rate warmup to survive the first few thousand steps.
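A toy sketch of post-norm placement, assuming a hypothetical `tiny_sublayer` that just scales its input by 0.01 to mimic small initial weights. Because each block ends in LayerNorm, even a stack of near-identity blocks returns LN(x) rather than x: the skip path is renormalized at every block, so every route from input to output crosses a LayerNorm.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm with gamma = 1, beta = 0, over one feature vector."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def post_norm_block(x, sublayer):
    """Post-norm: y = LayerNorm(x + SubLayer(x)). The skip path sits inside the norm."""
    s = sublayer(x)
    return layer_norm([xi + si for xi, si in zip(x, s)])

def tiny_sublayer(x):
    """Hypothetical stand-in for attention/FFN with small initial weights."""
    return [0.01 * v for v in x]

x = [3.0, -1.0, 10.0, 0.5]
y = x
for _ in range(6):                       # a stack of six post-norm blocks
    y = post_norm_block(y, tiny_sublayer)

# Even with a near-zero sub-layer, the stack outputs LN(x), not x.
print(max(abs(p - q) for p, q in zip(y, layer_norm(x))))   # tiny
print(max(abs(p - q) for p, q in zip(y, x)))               # large
```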

Pre-Norm (Most Modern Transformers)

$$\mathbf{y} = \mathbf{x} + \text{SubLayer}(\text{LayerNorm}(\mathbf{x}))$$

Normalization is applied to the input before the sub-layer, and the raw (unnormalized) residual is added. The skip path carries the unnormalized signal directly.

Why it is more stable: at initialization, $\text{SubLayer}(\text{LayerNorm}(\mathbf{x}))$ starts near zero, so $\mathbf{y} \approx \mathbf{x}$ and the block is approximately an identity. The skip path contains no normalization, so gradients flow through it unchanged regardless of depth. Pre-norm networks can often be trained with little or no warmup and are more stable for very deep stacks (24, 48, 96 layers).
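The same toy setup with pre-norm placement (again a hypothetical `tiny_sublayer` scaling its input by 0.01) shows the identity behaviour: the raw x travels along the skip path untouched, and each block only adds a small correction.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm with gamma = 1, beta = 0, over one feature vector."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def pre_norm_block(x, sublayer):
    """Pre-norm: y = x + SubLayer(LayerNorm(x)). The skip path carries raw x."""
    s = sublayer(layer_norm(x))
    return [xi + si for xi, si in zip(x, s)]

def tiny_sublayer(x):
    """Hypothetical stand-in for attention/FFN with small initial weights."""
    return [0.01 * v for v in x]

x = [3.0, -1.0, 10.0, 0.5]
y = x
for _ in range(6):                       # a stack of six pre-norm blocks
    y = pre_norm_block(y, tiny_sublayer)

# The stack stays close to the identity: y is x plus six small sub-layer outputs.
print(max(abs(p - q) for p, q in zip(y, x)))
```

The residual stream itself is never normalized here (its variance stays close to that of x), which is why pre-norm models typically apply one final LayerNorm after the last block.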

Trade-off: pre-norm networks have been observed to converge to slightly worse final accuracy than well-tuned post-norm networks on some tasks. Post-norm therefore remains a reasonable choice when the model is shallow enough to train stably and final quality matters more than training robustness.

All major LLM families (GPT-3, LLaMA, Mistral, Gemma, Falcon) use pre-norm.


RMSNorm

RMSNorm (Zhang & Sennrich, 2019) simplifies LayerNorm by removing the mean-centering step:

$$\text{RMS}(\mathbf{x}) = \sqrt{\frac{1}{d}\sum_{j=1}^d x_j^2 + \epsilon}$$

$$y_j = \frac{x_j}{\text{RMS}(\mathbf{x})} \cdot \gamma_j$$

What is removed: the mean subtraction ($x_j - \mu$) and the bias parameter $\beta$.
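The two equations translate directly. A minimal plain-Python sketch, where `rms_norm` is a hypothetical helper name and ε = 1e-6 is an illustrative default placed inside the square root as in the RMS definition:

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """RMSNorm: y_j = x_j / RMS(x) * gamma_j."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)   # no mean is computed
    return [v / rms * g for v, g in zip(x, gamma)]          # scale only, no beta

x = [1.0, -2.0, 3.0, -4.0]
y = rms_norm(x, gamma=[1.0] * 4)

print(sum(v * v for v in y) / len(y))  # mean square close to 1: the output has unit RMS
print(sum(y) / len(y))                 # nonzero: the mean of x is not removed
```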

Motivation: Zhang & Sennrich hypothesize that LayerNorm's benefit comes mostly from re-scaling (dividing by a measure of magnitude), while re-centering contributes little. Removing the centering step is therefore an approximation, but one that holds well in practice for language models.
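The approximation can be made concrete. When an activation vector already has zero mean, σ² equals the mean of the squares, so with γ = 1, β = 0 and the same ε, LayerNorm and RMSNorm give the same output; a nonzero mean is exactly what RMSNorm leaves in place. A plain-Python sketch with illustrative vectors and hypothetical helper names:

```python
import math

def layer_norm(x, eps=1e-6):
    """LayerNorm with gamma = 1, beta = 0."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-6):
    """RMSNorm with gamma = 1."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def max_abs_diff(a, b):
    return max(abs(p - q) for p, q in zip(a, b))

centered = [1.5, -0.5, 2.0, -3.0]          # mean is exactly 0
shifted = [v + 5.0 for v in centered]      # same shape, mean 5

print(max_abs_diff(layer_norm(centered), rms_norm(centered)))  # about 0: identical when the mean is 0
print(max_abs_diff(layer_norm(shifted), rms_norm(shifted)))    # large: RMSNorm keeps the offset
```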

Performance gain: RMSNorm requires roughly 15–20% fewer floating-point operations than LayerNorm (no mean accumulation pass, no subtraction, no $\beta$). Zhang & Sennrich report running-time reductions of 7% to 64% across the models they tested; the savings matter most for very large models trained on large clusters.

Used in: LLaMA (all versions), Mistral, Gemma, and most decoder-only LLMs released since 2023. It has become the common default for new transformer architectures, although some families (Falcon, for example) kept standard LayerNorm.


DeepNorm

DeepNorm (Wang et al., 2022) makes a different intervention: instead of changing what LayerNorm computes, it keeps post-norm placement and up-weights the residual before normalization.

$$\mathbf{y} = \text{LayerNorm}(\alpha\,\mathbf{x} + \text{SubLayer}(\mathbf{x}))$$

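The residual $\mathbf{x}$ is multiplied by a constant $\alpha > 1$ before LayerNorm. The sub-layer weights $W$ are also initialized at a fraction $\beta < 1$ of their normal scale (this $\beta$ is an initialization gain, unrelated to LayerNorm's shift parameter $\beta$). Both constants are derived from the number of layers.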
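For a sense of scale, the sketch below computes α and β for a decoder-only stack, assuming the formulas α = (2M)^(1/4) and β = (8M)^(-1/4) for M layers (the decoder-only setting as given in the DeepNet paper; encoder-only and encoder-decoder models use different expressions, so verify against the paper before reuse). The function name is hypothetical.

```python
def deepnorm_constants_decoder_only(num_layers):
    """Return (alpha, beta) for a decoder-only DeepNorm stack of M = num_layers layers.

    Assumed formulas:
        alpha = (2M) ** (1/4)    multiplies the residual x before LayerNorm
        beta  = (8M) ** (-1/4)   gain used when initializing sub-layer weights
    """
    if num_layers < 1:
        raise ValueError("num_layers must be positive")
    alpha = (2 * num_layers) ** 0.25
    beta = (8 * num_layers) ** -0.25
    return alpha, beta

for m in (12, 100, 1000):
    alpha, beta = deepnorm_constants_decoder_only(m)
    print(f"M={m:4d}  alpha={alpha:.3f}  beta={beta:.3f}")
```

Both constants move slowly with depth because of the fourth root. As described in the paper, β serves as the gain of Xavier-normal initialization for the feed-forward, value, and output projections, while the query and key projections keep gain 1.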

Why this works: with $\alpha > 1$, the expected magnitude of the residual is larger relative to the sub-layer output at initialization. The normalized output is dominated by the pass-through signal, and the sub-layer starts as a small perturbation. This is a precisely controlled version of the pre-norm stability argument.
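A toy check of this argument: because LayerNorm is invariant to rescaling its input, LN(αx + SubLayer(x)) ≈ LN(x + SubLayer(x)/α), so α shrinks the sub-layer's relative contribution by a factor of 1/α. The sketch uses a hypothetical sub-layer that returns a fixed vector unrelated to x and measures the cosine similarity between the block output and LN(x) for a few illustrative values of α.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm with gamma = 1, beta = 0, over one feature vector."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def deepnorm_block(x, sublayer, alpha):
    """Post-norm with an up-weighted residual: y = LayerNorm(alpha * x + SubLayer(x))."""
    s = sublayer(x)
    return layer_norm([alpha * xi + si for xi, si in zip(x, s)])

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(p * q for p, q in zip(a, b))
    return dot / (math.sqrt(sum(p * p for p in a)) * math.sqrt(sum(q * q for q in b)))

FIXED_OUTPUT = [0.9, -1.3, 0.4, 1.1, -0.7, 0.2]

def unrelated_sublayer(_x):
    """Hypothetical sub-layer whose output ignores its input entirely."""
    return FIXED_OUTPUT

x = [1.0, 2.0, -1.0, 0.5, -2.0, 1.5]
for alpha in (1.0, 4.0, 16.0):
    y = deepnorm_block(x, unrelated_sublayer, alpha)
    print(f"alpha={alpha:5.1f}  cos(y, LN(x))={cosine(y, layer_norm(x)):.4f}")
```

As α grows, the block output approaches LN(x), so the sub-layer acts as a small, bounded perturbation on the pass-through signal.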

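DeepNorm made it possible to train transformers with up to 1,000 layers (the DeepNet paper) without divergence. The paper's theoretical analysis shows that, with α and β chosen this way, the magnitude of each model update stays bounded by a constant, which is what keeps training stable at that depth.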


LayerNorm in Non-Transformer Architectures

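RNNs and LSTMs: apply LayerNorm to the pre-activation at each timestep, before the nonlinearity $f$: $\mathbf{h}_t = f\big(\text{LN}(W_h\mathbf{h}_{t-1} + W_x\mathbf{x}_t)\big)$. This works naturally because the statistics are computed over the hidden units of one example at one timestep, so no batch statistics (and no separate statistics per timestep) are needed.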
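A minimal sketch of one such cell in plain Python, assuming an Elman RNN with f = tanh; the weights, sizes, and helper names are hypothetical. Each step normalizes the 3-unit pre-activation of a single sequence, so the loop runs with batch size 1 and keeps no running statistics.

```python
import math

def layer_norm(a, gamma, beta, eps=1e-5):
    """LayerNorm over the hidden units of one pre-activation vector."""
    mu = sum(a) / len(a)
    var = sum((v - mu) ** 2 for v in a) / len(a)
    return [g * (v - mu) / math.sqrt(var + eps) + b for v, g, b in zip(a, gamma, beta)]

def matvec(W, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def ln_rnn_step(h_prev, x_t, W_h, W_x, gamma, beta):
    """h_t = tanh(LN(W_h h_{t-1} + W_x x_t))."""
    pre = [p + q for p, q in zip(matvec(W_h, h_prev), matvec(W_x, x_t))]
    return [math.tanh(v) for v in layer_norm(pre, gamma, beta)]

# Hypothetical sizes and weights: hidden size 3, input size 2.
W_h = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.5], [-0.4, 0.2, 0.9]]
W_x = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]
gamma, beta = [1.0] * 3, [0.0] * 3

h = [0.0, 0.0, 0.0]
for x_t in ([1.0, 2.0], [0.5, -1.0], [3.0, 0.0]):   # one sequence of length 3
    h = ln_rnn_step(h, x_t, W_h, W_x, gamma, beta)
print(h)

# Starting from h_0 = 0, scaling the first input by 100 barely changes h_1.
h1_small = ln_rnn_step([0.0] * 3, [1.0, 2.0], W_h, W_x, gamma, beta)
h1_big = ln_rnn_step([0.0] * 3, [100.0, 200.0], W_h, W_x, gamma, beta)
print(max(abs(p - q) for p, q in zip(h1_small, h1_big)))   # near 0
```

Because LayerNorm divides out the overall scale of the pre-activation, the hidden state cannot blow up or vanish merely because the input magnitude changes from step to step.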

Graph Neural Networks: LayerNorm on node features, where each node is a separate example.

Vision Transformers: LayerNorm on patch embeddings — same as language transformers.


PyTorch and TensorFlow

PyTorch: nn.LayerNorm, nn.RMSNorm, and a pre-norm block:

import torch
import torch.nn as nn

# LayerNorm: normalized_shape = dimensions to normalize over (the last D dims)
ln = nn.LayerNorm(normalized_shape=512, eps=1e-5)
x  = torch.randn(2, 10, 512)     # (batch, seq_len, d_model)
out = ln(x)                      # normalizes over last dim

# RMSNorm — built-in from PyTorch 2.4
rms = nn.RMSNorm(512)
out_rms = rms(x)                 # no mean centering; ~15-20% fewer ops

# Custom RMSNorm for older PyTorch
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))   # learned gamma only

    def forward(self, x):
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.weight

# Pre-norm transformer block: y = x + SubLayer(LN(x))
class PreNormBlock(nn.Module):
    def __init__(self, d_model: int, nhead: int, ffn_dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn  = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.ffn   = nn.Sequential(
            nn.Linear(d_model, ffn_dim), nn.GELU(), nn.Linear(ffn_dim, d_model)
        )

    def forward(self, x):
        n = self.norm1(x)
        x = x + self.attn(n, n, n, need_weights=False)[0]   # un-normalized x in the residual
        x = x + self.ffn(self.norm2(x))
        return x

TensorFlow / Keras:

import tensorflow as tf

# LayerNorm: axis=-1 normalizes the last dimension
ln = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-5)
x  = tf.random.normal((2, 10, 512))
out = ln(x)   # shape (2, 10, 512)

# Custom RMSNorm (older TF/Keras releases have no built-in RMSNorm layer)
class RMSNorm(tf.keras.layers.Layer):
    def __init__(self, dim: int, eps: float = 1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps, self.dim = eps, dim

    def build(self, input_shape):
        self.weight = self.add_weight(
            shape=(self.dim,), initializer='ones', trainable=True, name='weight'
        )

    def call(self, x):
        rms = tf.sqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + self.eps)
        return x / rms * self.weight