He / Kaiming & LeCun Initialization
- Explain why ReLU violates Xavier's linear approximation assumption and derive the corrected variance formula Var(W) = 2 / fan_in
- Distinguish fan_in mode from fan_out mode in kaiming_normal_ and state when each is appropriate
- Apply kaiming_normal_ and kaiming_uniform_ in PyTorch and HeNormal / HeUniform in TensorFlow for ReLU and LeakyReLU networks
- Identify when to use LeCun initialization instead of He and explain its role in SELU self-normalization
Why Xavier Fails for ReLU
Xavier's derivation assumed that the activation function is approximately linear near $x = 0$, so it does not change the variance of the signal passing through it.
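ReLU violates this assumption: it sets half its inputs to zero (for a zero-mean, symmetric pre-activation $z$, $P(z < 0) = 1/2$). This means ReLU halves the signal's second moment, which is equivalent to halving the effective fan-in:

$$\mathbb{E}\left[\mathrm{ReLU}(z)^2\right] = \tfrac{1}{2}\,\mathrm{Var}(z)$$

With Xavier initialization ($\mathrm{Var}(W) \approx 1/\text{fan\_in}$ when fan-in and fan-out are similar), each ReLU layer halves the signal variance. In a 50-layer network, the signal variance falls to $(1/2)^{50} \approx 8.9 \times 10^{-16}$ times its original value. Training stalls.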
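The halving argument can be checked numerically. The sketch below is an illustration, not part of the original derivation: it assumes NumPy, an arbitrary width of 256, a batch of 512 random inputs, and weights drawn with Var(W) = 1/fan_in (the Xavier variance for a square layer). It tracks the mean squared activation through 50 ReLU layers.

```python
import numpy as np

rng = np.random.default_rng(0)
width, depth, batch = 256, 50, 512  # illustrative sizes, not from the text

x = rng.standard_normal((batch, width))
initial_power = np.mean(x ** 2)  # mean squared activation at the input

for _ in range(depth):
    # Var(W) = 1 / fan_in: Xavier's variance when fan_in == fan_out
    w = rng.standard_normal((width, width)) * np.sqrt(1.0 / width)
    x = np.maximum(x @ w.T, 0.0)  # linear layer followed by ReLU

xavier_ratio = np.mean(x ** 2) / initial_power
print(f"measured ratio: {xavier_ratio:.2e}")
print(f"(1/2)**{depth}:   {0.5 ** depth:.2e}")
```

The measured ratio varies with the seed and the width, because each finite-width layer only halves the signal on average, but it should stay in the neighborhood of $(1/2)^{50}$.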
The He Correction
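He et al. (2015) corrected the variance formula specifically for ReLU. Starting from the forward-pass constraint:

$$\text{fan\_in} \cdot \mathrm{Var}(W) = 1$$

and accounting for ReLU's halving effect:

$$\tfrac{1}{2}\,\text{fan\_in} \cdot \mathrm{Var}(W) = 1 \quad\Longrightarrow\quad \mathrm{Var}(W) = \frac{2}{\text{fan\_in}}$$

Note that Xavier uses $\mathrm{Var}(W) = 2 / (\text{fan\_in} + \text{fan\_out})$; He uses $\mathrm{Var}(W) = 2 / \text{fan\_in}$. For a square layer where fan-in equals fan-out, He initialization uses exactly twice the variance of Xavier.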
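As a rough sanity check of the corrected formula, the following sketch pushes random inputs through 50 ReLU layers whose weights have Var(W) = 2/fan_in. The width, depth, and batch size are arbitrary choices made for illustration, and NumPy stands in for a real framework.

```python
import numpy as np

rng = np.random.default_rng(0)
width, depth, batch = 256, 50, 512  # illustrative sizes, not from the text

x = rng.standard_normal((batch, width))
initial_power = np.mean(x ** 2)  # mean squared activation at the input

for _ in range(depth):
    # Var(W) = 2 / fan_in: the He variance for ReLU
    w = rng.standard_normal((width, width)) * np.sqrt(2.0 / width)
    x = np.maximum(x @ w.T, 0.0)  # linear layer followed by ReLU

he_ratio = np.mean(x ** 2) / initial_power
print(f"mean squared activation after {depth} layers / input: {he_ratio:.3f}")
```

With this variance the ratio should stay within an order of magnitude or so of 1 instead of collapsing toward zero; it is not exactly 1 because a finite-width layer preserves the signal only in expectation.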
He Normal
torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
tf.keras.initializers.HeNormal()
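$$W \sim \mathcal{N}(0, \sigma^2), \qquad \sigma = \sqrt{\frac{2}{(1 + a^2)\,\text{fan\_in}}}$$

where $a$ is the slope of the negative part of the activation (0 for ReLU, 0.01 for LeakyReLU with PyTorch's default slope).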
PyTorch:
import torch
import torch.nn as nn
w = torch.empty(256, 128)
# He Normal for ReLU (a=0, mode='fan_in')
nn.init.kaiming_normal_(w, a=0, mode='fan_in', nonlinearity='relu')
# std = sqrt(2 / fan_in) = sqrt(2 / 128) = 0.125
# He Normal for LeakyReLU with slope 0.01
nn.init.kaiming_normal_(w, a=0.01, mode='fan_in', nonlinearity='leaky_relu')
# Apply to a ReLU network
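def init_he(m):
    # Re-initialize only the layers that own a weight matrix or kernel
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
    # Pool and flatten (N, 128, H, W) -> (N, 128) so that Linear(128, 10) fits
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(128, 10)
)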
model.apply(init_he)
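The standard deviation quoted in the comments can be verified empirically. This is a small check, assuming a recent PyTorch install; the seed and the tolerance implied by the sample size are arbitrary.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

w = torch.empty(256, 128)  # nn.Linear layout: (fan_out, fan_in)
nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')

expected_std = math.sqrt(2.0 / 128)  # 0.125
empirical_std = w.std().item()
empirical_mean = w.mean().item()
print(f"expected std {expected_std:.4f}, empirical std {empirical_std:.4f}")
print(f"empirical mean {empirical_mean:.5f}")
```

With 256 × 128 samples the empirical value should agree with 0.125 to about two or three decimal places.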
TensorFlow:
import tensorflow as tf
# HeNormal for Dense
dense = tf.keras.layers.Dense(256,
    kernel_initializer=tf.keras.initializers.HeNormal(),
    bias_initializer='zeros')
# HeNormal for Conv2D
conv = tf.keras.layers.Conv2D(64, 3, padding='same',
    kernel_initializer=tf.keras.initializers.HeNormal())
# Standalone
he_n = tf.keras.initializers.HeNormal()
w = he_n(shape=(128, 256))
He Uniform
torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')
tf.keras.initializers.HeUniform()
PyTorch:
nn.init.kaiming_uniform_(w, a=0, mode='fan_in', nonlinearity='relu')
# bound = sqrt(3) * sqrt(2/128) ≈ 0.2165
TensorFlow:
dense = tf.keras.layers.Dense(256,
    kernel_initializer=tf.keras.initializers.HeUniform())
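The uniform bound follows from the fact that U(-b, b) has variance b²/3, so matching the He variance 2/fan_in requires b = sqrt(3) · sqrt(2/fan_in). The sketch below works through that arithmetic for fan_in = 128 using only the standard library; the sample count is an arbitrary choice.

```python
import math
import random

fan_in = 128
a = 0.0  # negative slope; 0 for plain ReLU

gain = math.sqrt(2.0 / (1.0 + a ** 2))
std = gain / math.sqrt(fan_in)   # target standard deviation
bound = math.sqrt(3.0) * std     # U(-bound, bound) has variance bound**2 / 3

uniform_variance = bound ** 2 / 3.0

# Monte Carlo check of the variance (the distribution has zero mean)
random.seed(0)
samples = [random.uniform(-bound, bound) for _ in range(200_000)]
sample_variance = sum(s * s for s in samples) / len(samples)

print(f"bound = {bound:.4f}")
print(f"analytic variance = {uniform_variance:.6f}, target = {2.0 / fan_in:.6f}")
print(f"sampled variance  = {sample_variance:.6f}")
```

He Uniform and He Normal therefore target the same variance; they differ only in the shape of the distribution.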
fan_in vs fan_out Mode
The mode parameter controls which dimension the variance is normalized by:
| Mode | Formula | Stabilizes |
|---|---|---|
| fan_in (default) | $\mathrm{Var}(W) = 2 / \text{fan\_in}$ | Forward pass: activation variance |
| fan_out | $\mathrm{Var}(W) = 2 / \text{fan\_out}$ | Backward pass: gradient variance |
fan_in is the standard choice for most feedforward networks: it ensures activations don't explode or collapse going forward. fan_out may be preferable when the backward pass is the primary concern (e.g., in very deep networks with many more outputs than inputs per layer).
# fan_in: normalize by inputs (stabilizes forward pass)
nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')
# fan_out: normalize by outputs (stabilizes backward pass)
nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
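For convolution kernels the two fans also include the kernel's spatial size. The sketch below uses two hypothetical helpers, `compute_fans` and `he_std`, which are written for this example and assume the PyTorch weight layout (out, in, *kernel); they are not library functions.

```python
import math

def compute_fans(shape):
    """Return (fan_in, fan_out) for a weight laid out as (out, in, *kernel)."""
    receptive_field = math.prod(shape[2:])  # 1 for a 2-D Linear weight
    fan_in = shape[1] * receptive_field
    fan_out = shape[0] * receptive_field
    return fan_in, fan_out

def he_std(shape, mode="fan_in"):
    """He standard deviation for ReLU under the chosen mode."""
    fan_in, fan_out = compute_fans(shape)
    fan = fan_in if mode == "fan_in" else fan_out
    return math.sqrt(2.0 / fan)

linear_shape = (256, 128)   # shape of nn.Linear(128, 256).weight
conv_shape = (64, 3, 3, 3)  # shape of nn.Conv2d(3, 64, 3).weight

for name, shape in [("linear", linear_shape), ("conv", conv_shape)]:
    fan_in, fan_out = compute_fans(shape)
    print(f"{name}: fan_in={fan_in}, fan_out={fan_out}, "
          f"std(fan_in)={he_std(shape, 'fan_in'):.4f}, "
          f"std(fan_out)={he_std(shape, 'fan_out'):.4f}")
```

The gap between the two modes grows with the ratio of output to input channels, which is why the choice matters most for layers that change width sharply.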
The a Parameter — Generalizing to LeakyReLU
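The derivation of He initialization assumed ReLU zeroes exactly half the inputs. For LeakyReLU with negative slope $a$, the negative half is scaled by $a$ instead of being zeroed, so the effective variance is:

$$\mathbb{E}\left[f(z)^2\right] = \frac{1 + a^2}{2}\,\mathrm{Var}(z)$$

The corrected formula replaces the factor 2 with $2 / (1 + a^2)$:

$$\mathrm{Var}(W) = \frac{2}{(1 + a^2)\,\text{fan\_in}}$$

For ReLU: $a = 0$, so the formula reduces to $2 / \text{fan\_in}$. For LeakyReLU with slope 0.01: $1 + a^2 = 1.0001$, which barely changes the result.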
# ReLU: a=0 (PyTorch ignores a when nonlinearity='relu')
nn.init.kaiming_normal_(w, a=0, nonlinearity='relu')
# LeakyReLU with slope 0.2
nn.init.kaiming_normal_(w, a=0.2, nonlinearity='leaky_relu')
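The $(1 + a^2)/2$ factor can be confirmed with a quick Monte Carlo estimate. This sketch assumes NumPy and standard-normal pre-activations, and uses the slope 0.2 from the example above; the sample size and fan_in = 128 are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
a = 0.2  # negative slope of the LeakyReLU
z = rng.standard_normal(1_000_000)  # zero-mean, unit-variance pre-activations

leaky = np.where(z > 0, z, a * z)   # LeakyReLU applied elementwise
measured = np.mean(leaky ** 2)
predicted = (1 + a ** 2) / 2        # 0.52 for a = 0.2

fan_in = 128
he_variance = 2.0 / ((1 + a ** 2) * fan_in)
print(f"E[f(z)^2]: measured {measured:.4f}, predicted {predicted:.4f}")
print(f"Var(W) for fan_in={fan_in}: {he_variance:.6f}")
```

Multiplying `fan_in * he_variance * predicted` gives 1, which is the forward-pass constraint the formula is built to satisfy.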
LeCun Initialization
tf.keras.initializers.LecunNormal() / tf.keras.initializers.LecunUniform()
LeCun initialization uses $\mathrm{Var}(W) = 1 / \text{fan\_in}$, half the variance of He. It predates SELU (Scaled Exponential Linear Unit) but is the initialization that SELU's self-normalizing property assumes: with LeCun initialization, a deep SELU network keeps its activations at approximately zero mean and unit variance from layer to layer without BatchNorm.
Note that LeCun normal uses truncated normal internally in Keras.
PyTorch has no built-in LeCun initializer; implement it directly:
import math
w = torch.empty(256, 128)
std = math.sqrt(1.0 / 128) # LeCun normal
nn.init.normal_(w, mean=0, std=std)
bound = math.sqrt(3.0 / 128) # LeCun uniform
nn.init.uniform_(w, a=-bound, b=bound)
TensorFlow:
# LeCun Normal — use with SELU
dense_selu = tf.keras.layers.Dense(256, activation='selu',
    kernel_initializer=tf.keras.initializers.LecunNormal())
# LeCun Uniform
dense_selu_u = tf.keras.layers.Dense(256, activation='selu',
    kernel_initializer=tf.keras.initializers.LecunUniform())
# Standalone
lecun_n = tf.keras.initializers.LecunNormal()
w = lecun_n(shape=(128, 256))
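The self-normalizing behavior can be observed in a small simulation. The sketch below assumes NumPy, an untruncated normal with Var(W) = 1/fan_in (unlike Keras, which truncates), and arbitrary sizes; the SELU constants are the standard published values, and `selu` is defined locally for the example.

```python
import numpy as np

# Standard SELU constants
ALPHA = 1.6732632423543772
SCALE = 1.0507009873554805

def selu(z):
    # np.minimum keeps exp() from overflowing on large positive inputs
    negative_branch = ALPHA * (np.exp(np.minimum(z, 0.0)) - 1.0)
    return SCALE * np.where(z > 0, z, negative_branch)

rng = np.random.default_rng(0)
width, depth, batch = 256, 50, 512  # illustrative sizes, not from the text

x = rng.standard_normal((batch, width))
for _ in range(depth):
    # LeCun normal: Var(W) = 1 / fan_in
    w = rng.standard_normal((width, width)) * np.sqrt(1.0 / width)
    x = selu(x @ w.T)

final_mean, final_var = x.mean(), x.var()
print(f"after {depth} SELU layers: mean {final_mean:.3f}, variance {final_var:.3f}")
```

The mean should remain near 0 and the variance near 1 after 50 layers, with no normalization layer involved; the exact numbers depend on the seed and width.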
VarianceScaling — The Unified Abstraction
TensorFlow's VarianceScaling initializer is the common base for GlorotUniform, GlorotNormal, HeNormal, HeUniform, LecunNormal, and LecunUniform. Understanding it clarifies the relationship between the three families (Xavier/Glorot, He, and LeCun):
# GlorotUniform = VarianceScaling(scale=1, mode='fan_avg', distribution='uniform')
# GlorotNormal = VarianceScaling(scale=1, mode='fan_avg', distribution='truncated_normal')
# HeNormal = VarianceScaling(scale=2, mode='fan_in', distribution='truncated_normal')
# HeUniform = VarianceScaling(scale=2, mode='fan_in', distribution='uniform')
# LecunNormal = VarianceScaling(scale=1, mode='fan_in', distribution='truncated_normal')
# LecunUniform = VarianceScaling(scale=1, mode='fan_in', distribution='uniform')
# Directly:
vs = tf.keras.initializers.VarianceScaling(
    scale=2.0,
    mode='fan_in',
    distribution='truncated_normal'
)  # equivalent to HeNormal
The scale parameter directly multiplies the variance: scale=1 → LeCun, scale=2 → He, scale=1 with fan_avg → Xavier.
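To make the mapping concrete, here is a standard-library sketch of the arithmetic that VarianceScaling is understood to perform. `variance_scaling_params` is a hypothetical helper written for this example, and the truncation correction constant is an assumption based on the Keras implementation rather than something stated in this section.

```python
import math

def variance_scaling_params(scale, mode, distribution, fan_in, fan_out):
    """Hypothetical helper: returns ('limit', value) or ('stddev', value)."""
    n = {
        "fan_in": fan_in,
        "fan_out": fan_out,
        "fan_avg": (fan_in + fan_out) / 2,
    }[mode]
    variance = scale / max(1.0, n)  # target variance of the weights
    if distribution == "uniform":
        # U(-limit, limit) has variance limit**2 / 3
        return "limit", math.sqrt(3.0 * variance)
    if distribution == "truncated_normal":
        # Assumption: Keras divides by ~0.8796 so that a normal truncated at
        # two standard deviations still ends up with the target variance.
        return "stddev", math.sqrt(variance) / 0.87962566103423978
    return "stddev", math.sqrt(variance)  # untruncated normal

fan_in, fan_out = 128, 256
he_uniform = variance_scaling_params(2.0, "fan_in", "uniform", fan_in, fan_out)
glorot_uniform = variance_scaling_params(1.0, "fan_avg", "uniform", fan_in, fan_out)
lecun_normal = variance_scaling_params(1.0, "fan_in", "truncated_normal", fan_in, fan_out)

for name, (kind, value) in [("HeUniform", he_uniform),
                            ("GlorotUniform", glorot_uniform),
                            ("LecunNormal", lecun_normal)]:
    print(f"{name}: {kind} = {value:.4f}")
```

Only `scale` and `mode` distinguish the three families; `distribution` selects the normal or uniform variant within a family.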
Comparison
| Initializer | $\mathrm{Var}(W)$ | Activation | PyTorch | TF |
|---|---|---|---|---|
| Xavier normal | $2 / (\text{fan\_in} + \text{fan\_out})$ | sigmoid, tanh, linear | xavier_normal_ | GlorotNormal |
| Xavier uniform | same | same | xavier_uniform_ | GlorotUniform |
| He normal | $2 / \text{fan\_in}$ | ReLU, LeakyReLU | kaiming_normal_ | HeNormal |
| He uniform | same | same | kaiming_uniform_ | HeUniform |
| LeCun normal | $1 / \text{fan\_in}$ | SELU | (manual) | LecunNormal |
| LeCun uniform | same | SELU | (manual) | LecunUniform |