Supplement · Weight Initialization

Random Distribution Initializations

By the end of this reading you will be able to:
  • Apply uniform_, normal_, trunc_normal_, and sparse_ in PyTorch, with their TF equivalents where they exist, using appropriate scale parameters
  • Explain why naïve N(0, 1) initialization fails in deep networks and derive the correct standard deviation for a stable forward pass
  • Distinguish truncated normal from standard normal initialization and state the practical benefit of truncating at ±2σ

The Scale Problem

Drawing weights from a random distribution breaks symmetry — different neurons start with different values, so they learn different features. But the choice of distribution and, crucially, its scale determines whether the network trains at all.

For a single linear layer $z = Wx$ with $n$ inputs, if $w_{ij} \sim \mathcal{N}(0, \sigma^2)$ independently and the inputs are zero-mean with a common variance $\text{Var}(x)$:

$$\text{Var}(z_j) = n \cdot \sigma^2 \cdot \text{Var}(x)$$

To preserve variance across layers ($\text{Var}(z) \approx \text{Var}(x)$):

$$\sigma^2 = \frac{1}{n}$$

This is the key constraint. The generic random initializers below are agnostic to activation functions: they let you set $\sigma$ (or the range) directly. The variance-scaling initializers in the next two readings derive $\sigma$ automatically from the layer dimensions and activation function.
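The effect of the constraint can be checked numerically. The following is a minimal sketch, assuming a stack of purely linear layers with no activations or biases; the helper name `output_variance` and the sizes (256 units, 10 layers, batch of 512) are illustrative choices, not part of any library.

```python
import numpy as np

def output_variance(std, n=256, depth=10, batch=512, seed=0):
    """Push unit-variance inputs through `depth` linear layers with N(0, std^2) weights."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n, batch))        # zero-mean inputs with Var(x) close to 1
    for _ in range(depth):
        W = rng.normal(0.0, std, size=(n, n))  # fresh weight matrix for each layer
        x = W @ x                              # z = W x, no activation
    return float(x.var())

n = 256
var_naive = output_variance(std=1.0, n=n)                 # each layer multiplies variance by about n
var_stable = output_variance(std=1.0 / np.sqrt(n), n=n)   # each layer multiplies variance by about 1

print(f"N(0, 1)   weights: Var(output) = {var_naive:.3e}")
print(f"N(0, 1/n) weights: Var(output) = {var_stable:.3f}")
```

With unit-variance weights the output variance grows by roughly a factor of $n$ per layer, while $\sigma^2 = 1/n$ keeps it near its starting value.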

uniform — Uniform Distribution

torch.nn.init.uniform_(tensor, a=0, b=1) / tf.keras.initializers.RandomUniform(minval=-0.05, maxval=0.05)

Draws each weight independently from $\mathcal{U}[a, b]$. The variance of $\mathcal{U}[a, b]$ is $(b - a)^2 / 12$, so to achieve $\text{Var}(w) = 1/n$ you need $b - a = \sqrt{12/n} = 2\sqrt{3/n}$. For a zero-centered range this means $a = -\sqrt{3/n}$ and $b = \sqrt{3/n}$.

The PyTorch default (a=0, b=1) is almost never appropriate directly: it initializes weights to values in $[0, 1]$, which is far too large for most layers and produces positive-only weights (breaking zero-centering). Always set a and b explicitly.
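The bound derived above can be sanity-checked with the standard library alone. This is a small sketch rather than a library routine; `symmetric_uniform_bound` is a hypothetical helper name, and the sample count is an arbitrary choice.

```python
import math
import random

def symmetric_uniform_bound(fan_in):
    """Half-width r of U[-r, r] whose variance (2r)^2 / 12 = r^2 / 3 equals 1 / fan_in."""
    return math.sqrt(3.0 / fan_in)

fan_in = 128
r = symmetric_uniform_bound(fan_in)

# Draw samples from U[-r, r] and compare their variance with the 1 / fan_in target
rng = random.Random(0)
samples = [rng.uniform(-r, r) for _ in range(200_000)]
mean = sum(samples) / len(samples)
empirical_var = sum((s - mean) ** 2 for s in samples) / len(samples)

print(f"bound = {r:.4f}")
print(f"empirical Var(w) = {empirical_var:.5f}, target 1/fan_in = {1 / fan_in:.5f}")
```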

PyTorch:

import math

import torch
import torch.nn as nn

w = torch.empty(256, 128)   # shape (fan_out, fan_in)

# Naïve default range [0, 1] (almost never correct)
nn.init.uniform_(w, a=0, b=1)

# Correct scale for a stable forward pass with fan_in=128
fan_in = w.shape[1]
bound = math.sqrt(3.0 / fan_in)   # ≈ 0.153
nn.init.uniform_(w, a=-bound, b=bound)

TensorFlow:

import tensorflow as tf
import math

fan_in = 128
bound = math.sqrt(3.0 / fan_in)

uniform_init = tf.keras.initializers.RandomUniform(minval=-bound, maxval=bound)
w = uniform_init(shape=(fan_in, 256))   # Keras kernels are (fan_in, fan_out)

# In a layer that receives fan_in=128 input features
dense = tf.keras.layers.Dense(256, kernel_initializer=uniform_init)

normal — Gaussian Distribution

torch.nn.init.normal_(tensor, mean=0, std=1) / tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.05)

Draws weights from $\mathcal{N}(\mu, \sigma^2)$. This is the natural choice for weight initialization: the Gaussian is isotropic, zero-centered by default, and has well-understood properties under linear transformations.

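Neither default accounts for fan-in. The PyTorch default std=1 multiplies the variance by $n$ at every layer, and the Keras default stddev=0.05 is variance-stable only when the fan-in happens to be about $1/0.05^2 = 400$. For a 128-input layer, variance stability requires $\sigma = 1/\sqrt{128} \approx 0.088$.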
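The arithmetic behind these defaults can be written out as a short sketch. It assumes the linear-layer relation $\text{Var}(z) = n \sigma^2 \text{Var}(x)$ from earlier in this reading; `per_layer_gain` is an illustrative helper and the depth of 10 is arbitrary.

```python
import math

def per_layer_gain(fan_in, std):
    """Var(z) / Var(x) for one linear layer with N(0, std^2) weights: n * sigma^2."""
    return fan_in * std ** 2

fan_in, depth = 128, 10
candidates = [
    ("PyTorch default std=1", 1.0),
    ("Keras default stddev=0.05", 0.05),
    ("1/sqrt(fan_in)", 1.0 / math.sqrt(fan_in)),
]
for label, std in candidates:
    gain = per_layer_gain(fan_in, std)
    # Compounding the per-layer gain shows how quickly the signal explodes or vanishes
    print(f"{label:>26}: gain per layer = {gain:.3g}, after {depth} layers = {gain ** depth:.3g}")

# Fan-in at which stddev=0.05 would be variance-stable: 1 / 0.05^2
stable_fan_in = 1.0 / 0.05 ** 2
print(f"stddev=0.05 is variance-stable for fan_in of about {round(stable_fan_in)}")
```

A gain above 1 compounds into exploding activations and a gain below 1 into vanishing ones, which is why the scale has to track the layer width.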

PyTorch:

import math

w = torch.empty(256, 128)

# Explicit scale
nn.init.normal_(w, mean=0, std=0.01)   # small std — common in early practice

# Variance-stable scale
std = 1.0 / math.sqrt(128)
nn.init.normal_(w, mean=0, std=std)

TensorFlow:

fan_in = 128
std = 1.0 / math.sqrt(fan_in)

normal_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=std)
w = normal_init(shape=(fan_in, 256))   # Keras kernels are (fan_in, fan_out)

trunc_normal — Truncated Normal

torch.nn.init.trunc_normal_(tensor, mean=0, std=1, a=-2, b=2) / tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.05)

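Draws from $\mathcal{N}(\mu, \sigma^2)$ but redraws any value that falls outside $[a, b]$. In PyTorch, a and b are absolute cutoffs, not multiples of std: the defaults $[-2, 2]$ coincide with $[\text{mean} - 2\sigma, \text{mean} + 2\sigma]$ only when mean=0 and std=1. Keras TruncatedNormal takes no bounds and always redraws values more than two standard deviations from the mean. Truncation eliminates extreme outliers that can cause spikes in activation variance early in training.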

Truncated normal is the default distribution in TF's VarianceScaling initializer, and therefore the default underlying distribution for GlorotNormal, HeNormal, and LecunNormal in Keras. It is increasingly preferred over standard normal for Transformer architectures (ViT, BERT) because it prevents initialization-time outlier activations.

PyTorch:

w = torch.empty(256, 128)

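# a and b are absolute cutoffs, so pass ±2 std explicitly
std = 0.02
nn.init.trunc_normal_(w, mean=0, std=std, a=-2 * std, b=2 * std)   # values in [-0.04, 0.04]

# Custom range (±3 std)
nn.init.trunc_normal_(w, mean=0, std=0.02, a=-0.06, b=0.06)

# Common in ViT-style models. With the default a=-2, b=2 the cutoffs sit at
# ±100 std, so this call behaves like an untruncated normal.
pos_embed = torch.empty(1, 196, 768)
nn.init.trunc_normal_(pos_embed, std=0.02)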

TensorFlow:

# TruncatedNormal redraws values more than 2 stddev from the mean automatically
trunc_init = tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02)
w = trunc_init(shape=(128, 256))   # Keras kernels are (fan_in, fan_out)

dense = tf.keras.layers.Dense(256,
    kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))
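To make the role of the bounds concrete, here is a standard-library sketch of truncated normal sampling by rejection. It is an illustration only: `trunc_normal_samples` is a hypothetical helper, its a and b are absolute cutoffs in the style of the PyTorch signature, and real libraries may use a different sampling method internally.

```python
import math
import random

def trunc_normal_samples(count, mean=0.0, std=1.0, a=-2.0, b=2.0, seed=0):
    """Rejection-sample N(mean, std^2) restricted to the absolute interval [a, b]."""
    rng = random.Random(seed)
    out = []
    while len(out) < count:
        v = rng.gauss(mean, std)
        if a <= v <= b:              # keep in-range draws, redraw everything else
            out.append(v)
    return out

def sample_std(values):
    """Population standard deviation of a list of floats."""
    m = sum(values) / len(values)
    return math.sqrt(sum((v - m) ** 2 for v in values) / len(values))

std = 0.02
tight = trunc_normal_samples(100_000, std=std, a=-2 * std, b=2 * std)  # cut at ±2 std
loose = trunc_normal_samples(100_000, std=std)                         # cut at ±2, i.e. ±100 std

print("tight max |w|:", max(abs(v) for v in tight))
print("loose max |w|:", max(abs(v) for v in loose))
print("tight std / nominal std:", sample_std(tight) / std)
```

Truncating at ±2 standard deviations also shrinks the realized standard deviation to roughly 0.88 of the nominal value, so a variance-scaling initializer built on a truncated normal has to compensate if it wants to hit an exact target variance.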

sparse — Sparse Random Initialization

torch.nn.init.sparse_(tensor, sparsity, std=0.01) (PyTorch only)

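Initializes a 2D weight matrix so that a fraction sparsity of the weights in each column are set to exactly zero; the remaining weights are drawn from $\mathcal{N}(0, \text{std}^2)$. This promotes sparse connectivity at initialization, loosely analogous to biological neural networks where each neuron connects to only a fraction of upstream neurons. Martens (2010), who introduced the scheme, argued that limiting the number of nonzero incoming weights keeps a unit's total input from growing with the width of the previous layer, so units saturate less easily.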

Sparse initialization is primarily used with very large weight matrices (e.g., wide feedforward layers) where full connectivity at initialization may be too rich a starting point.

PyTorch:

w = torch.empty(1000, 500)

# 90% of weights in each column = 0; rest from N(0, 0.01^2)
nn.init.sparse_(w, sparsity=0.9, std=0.01)

# Inspect sparsity
zero_fraction = (w == 0).float().mean()
print(f"Zero fraction: {zero_fraction:.3f}")   # ≈ 0.900

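TensorFlow: No built-in equivalent. A custom initializer gives a similar effect. Note that this version zeroes each element independently, so the zero fraction matches sparsity only in expectation rather than exactly per column: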

import numpy as np

def sparse_initializer(shape, sparsity=0.9, std=0.01, dtype=tf.float32):
    values = np.random.normal(0, std, shape)
    mask = np.random.uniform(0, 1, shape) < sparsity
    values[mask] = 0.0
    return tf.constant(values, dtype=dtype)
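For comparison, the per-column behavior described for sparse_ can be sketched in NumPy. This is an approximation under stated assumptions, not the library's implementation: `sparse_columns` is a hypothetical helper, and it assumes that exactly ceil(sparsity × rows) entries are zeroed in every column.

```python
import math
import numpy as np

def sparse_columns(rows, cols, sparsity=0.9, std=0.01, seed=0):
    """Gaussian matrix with exactly ceil(sparsity * rows) zeros in every column."""
    rng = np.random.default_rng(seed)
    w = rng.normal(0.0, std, size=(rows, cols))
    num_zeros = math.ceil(sparsity * rows)
    for col in range(cols):
        zero_rows = rng.permutation(rows)[:num_zeros]  # distinct row indices for this column
        w[zero_rows, col] = 0.0
    return w

w = sparse_columns(1000, 500, sparsity=0.9)
zeros_per_column = (w == 0).sum(axis=0)
print("zeros per column: min =", zeros_per_column.min(), "max =", zeros_per_column.max())
print("overall zero fraction:", float((w == 0).mean()))
```

Unlike the independent mask, every column here ends up with the same number of zeros, so no unit is left with an unusually dense or unusually empty set of connections.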

Choosing a Standard Deviation

If you use uniform or normal initialization directly (rather than delegating to Xavier or He), use this reference:

| Goal | $\sigma$ for normal | Bound for uniform |
|---|---|---|
| Variance-stable (forward) | $1 / \sqrt{\text{fan\_in}}$ | $\sqrt{3 / \text{fan\_in}}$ |
| Variance-stable (backward) | $1 / \sqrt{\text{fan\_out}}$ | $\sqrt{3 / \text{fan\_out}}$ |
| LeCun (SELU) | $1 / \sqrt{\text{fan\_in}}$ | $\sqrt{3 / \text{fan\_in}}$ (same as forward) |
| Xavier (sigmoid/tanh) | $\sqrt{2 / (\text{fan\_in} + \text{fan\_out})}$ | $\sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$ |
| He (ReLU) | $\sqrt{2 / \text{fan\_in}}$ | $\sqrt{6 / \text{fan\_in}}$ |

The last two rows are precisely what Xavier and He compute automatically — the readings that follow derive them from first principles.
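The reference values can be collected in a small lookup helper. This is a sketch for illustration: `init_scale` and its scheme names are hypothetical, and it relies only on the fact that a symmetric uniform range $[-r, r]$ has variance $r^2 / 3$, so the uniform bound is always $\sqrt{3}$ times the normal $\sigma$.

```python
import math

def init_scale(scheme, fan_in, fan_out=None):
    """Return (std for normal, bound for symmetric uniform) for a named scheme."""
    if scheme in ("forward", "lecun"):
        var = 1.0 / fan_in
    elif scheme == "backward":
        var = 1.0 / fan_out
    elif scheme == "xavier":
        var = 2.0 / (fan_in + fan_out)
    elif scheme == "he":
        var = 2.0 / fan_in
    else:
        raise ValueError(f"unknown scheme: {scheme}")
    std = math.sqrt(var)
    return std, math.sqrt(3.0) * std   # U[-r, r] has variance r^2 / 3

for scheme in ("forward", "backward", "lecun", "xavier", "he"):
    std, bound = init_scale(scheme, fan_in=128, fan_out=256)
    print(f"{scheme:>8}: std = {std:.4f}, bound = {bound:.4f}")
```

The returned values can be passed straight to normal_ (as std) or uniform_ (as a=-bound, b=bound).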

References
Dosovitskiy et al. (2020). An Image is Worth 16×16 Words (ViT). Widely used ViT implementations initialize positional embeddings and linear weights from a truncated normal with std=0.02.
Martens (2010). Deep Learning via Hessian-free Optimization. Introduced sparse initialization, which limits the number of nonzero incoming weights per unit.