Supplement · Normalization in Deep Learning

Normalization Techniques in TensorFlow

Colab Notebook · Python · ~45 min
Lab Objectives
1. Implement BatchNormalization from scratch as a tf.keras.layers.Layer, handling the training= flag correctly in both model.fit and custom GradientTape loops
2. Implement LayerNorm and RMSNorm as custom Keras layers; verify parity with tf.keras.layers.LayerNormalization and benchmark throughput on long sequences
3. Apply tf.keras.layers.SpectralNormalization to a discriminator; verify the Lipschitz constraint and compare against the PyTorch Lab 1 results
4. Use tf.keras.layers.GroupNormalization in a small detection model; reproduce the batch-size stability experiment from Lab 1 and compare BN vs GN validation curves
5. Implement AdaIN as a custom Keras layer operating on (B, H, W, C) feature tensors; verify channel statistics match the style image after the forward pass
6. Implement a FiLM (Feature-wise Linear Modulation) conditioning layer and apply it to a class-conditional image generation task

Lab Overview

TensorFlow companion to the PyTorch lab. Re-implements each normalization technique using Keras idioms, with emphasis on the training= flag, custom layer subclassing, and model.fit integration.

Key API Differences vs PyTorch

| Concept | PyTorch | TensorFlow / Keras |
|---|---|---|
| BatchNorm mode | model.train() / model.eval() | layer(x, training=True/False) |
| LayerNorm | nn.LayerNorm(normalized_shape) | tf.keras.layers.LayerNormalization(axis=-1) |
| RMSNorm | nn.RMSNorm(dim) (PyTorch 2.4+) | Custom layer (no built-in) |
| GroupNorm | nn.GroupNorm(G, C) | tf.keras.layers.GroupNormalization(groups=G) (TF 2.11+) |
| SpectralNorm | nn.utils.spectral_norm(layer) | tf.keras.layers.SpectralNormalization(layer) |
| AdaIN / FiLM | Custom nn.Module | Custom tf.keras.layers.Layer |

Sections

| Section | Topic | Key experiment |
|---|---|---|
| 1 | BatchNorm custom layer | training= flag in GradientTape loop |
| 2 | LayerNorm & RMSNorm | Throughput benchmark on sequences |
| 3 | SpectralNormalization | Lipschitz verification; parity with PyTorch |
| 4 | GroupNorm batch-size stability | BN vs GN validation curves |
| 5 | AdaIN style transfer | Channel stats verification |
| 6 | FiLM conditioning | Class-conditional generation |

Section 1 — BatchNorm as a Custom Keras Layer

Implement MyBatchNorm as a subclass of tf.keras.layers.Layer. The critical TF-specific pattern is branching on the training flag inside call():

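def call(self, x, training=False):
    if training:
        # Reduce over every axis except the channel (last) axis, so this
        # works for (B, C) and channels-last (B, H, W, C) inputs alike
        axes = list(range(len(x.shape) - 1))
        mean, var = tf.nn.moments(x, axes=axes)
        # update running stats (exponential moving average, momentum 0.9)
        self.running_mean.assign(0.9 * self.running_mean + 0.1 * mean)
        self.running_var.assign(0.9 * self.running_var + 0.1 * var)
    else:
        # inference: normalize with the accumulated running statistics
        mean, var = self.running_mean, self.running_var
    # argument order is (x, mean, variance, offset=beta, scale=gamma, epsilon)
    return tf.nn.batch_normalization(x, mean, var, self.beta, self.gamma, self.eps)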

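In model.fit, Keras propagates the training flag automatically. In a custom GradientTape loop you must pass it yourself: call the model with training=True inside the tape and training=False at evaluation time. With the training=False default shown above, omitting the flag silently runs the layer in inference mode, normalizing with running statistics that never update. This is the TF counterpart of getting model.train() / model.eval() wrong in PyTorch.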
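A runnable sketch that completes the pattern above. The build() method, the momentum/eps defaults and the tiny Dense regression task are illustrative assumptions (momentum=0.9 mirrors the update in the snippet; Keras's built-in BatchNormalization defaults to momentum=0.99). The loop runs eagerly and passes training=True explicitly inside the tape.

```python
import numpy as np
import tensorflow as tf


class MyBatchNorm(tf.keras.layers.Layer):
    """Minimal BatchNorm over the last (channel) axis; defaults are illustrative."""

    def __init__(self, momentum=0.9, eps=1e-3, **kwargs):
        super().__init__(**kwargs)
        self.momentum = momentum
        self.eps = eps

    def build(self, input_shape):
        c = input_shape[-1]
        self.gamma = self.add_weight(name="gamma", shape=(c,), initializer="ones")
        self.beta = self.add_weight(name="beta", shape=(c,), initializer="zeros")
        # Running statistics are non-trainable: the optimizer never updates them
        self.running_mean = self.add_weight(name="running_mean", shape=(c,),
                                            initializer="zeros", trainable=False)
        self.running_var = self.add_weight(name="running_var", shape=(c,),
                                           initializer="ones", trainable=False)

    def call(self, x, training=False):
        if training:
            axes = list(range(len(x.shape) - 1))
            mean, var = tf.nn.moments(x, axes=axes)
            m = self.momentum
            self.running_mean.assign(m * self.running_mean + (1.0 - m) * mean)
            self.running_var.assign(m * self.running_var + (1.0 - m) * var)
        else:
            mean, var = self.running_mean, self.running_var
        return tf.nn.batch_normalization(x, mean, var, self.beta, self.gamma, self.eps)


tf.random.set_seed(0)
bn = MyBatchNorm(momentum=0.9)
dense = tf.keras.layers.Dense(8)
opt = tf.keras.optimizers.SGD(learning_rate=0.1)

x = tf.random.normal((32, 4)) * 3.0 + 5.0   # inputs far from zero mean / unit variance
y = tf.random.normal((32, 8))


def train_step(x, y):
    with tf.GradientTape() as tape:
        h = bn(x, training=True)  # explicit: batch statistics + running-stat update
        loss = tf.reduce_mean(tf.square(dense(h) - y))
    variables = bn.trainable_weights + dense.trainable_weights
    grads = tape.gradient(loss, variables)
    opt.apply_gradients(zip(grads, variables))
    return loss


_ = bn(x, training=False)  # builds the layer; inference mode leaves running stats untouched
print("running mean before:", bn.running_mean.numpy())
for _ in range(5):
    train_step(x, y)
print("running mean after :", bn.running_mean.numpy())
```

Because the running mean starts at zero and sees the same batch five times, it ends at (1 − 0.9⁵) ≈ 0.41 times the per-channel batch mean; the inference-mode warm-up call does not move it.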

Section 2 — LayerNorm and RMSNorm

Implement MyRMSNorm as a Keras layer normalizing over axis=-1. There is no tf.keras.layers.RMSNorm built-in, a common gap in TF vs PyTorch parity. Verify numerical agreement with a PyTorch reference computed on the same input tensors (saved to disk).
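A minimal sketch of both layers, assuming normalization over the last axis only and the common RMSNorm formulation y = g · x / sqrt(mean(x²) + ε) with a learned scale g and no bias. MyLayerNorm, the ε defaults and the test tensor shape are illustrative. MyLayerNorm is compared against tf.keras.layers.LayerNormalization with a matching epsilon (Keras defaults to 1e-3); for RMSNorm, with the scale still at its initial ones, every output vector should have a root-mean-square of about 1.

```python
import numpy as np
import tensorflow as tf


class MyLayerNorm(tf.keras.layers.Layer):
    """y = gamma * (x - mean) / sqrt(var + eps) + beta, statistics over the last axis."""

    def __init__(self, eps=1e-3, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def build(self, input_shape):
        d = input_shape[-1]
        self.gamma = self.add_weight(name="gamma", shape=(d,), initializer="ones")
        self.beta = self.add_weight(name="beta", shape=(d,), initializer="zeros")

    def call(self, x):
        mean, var = tf.nn.moments(x, axes=[-1], keepdims=True)
        return self.gamma * (x - mean) * tf.math.rsqrt(var + self.eps) + self.beta


class MyRMSNorm(tf.keras.layers.Layer):
    """y = scale * x / sqrt(mean(x^2) + eps): no mean subtraction, no bias."""

    def __init__(self, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps

    def build(self, input_shape):
        self.scale = self.add_weight(name="scale", shape=(input_shape[-1],),
                                     initializer="ones")

    def call(self, x):
        ms = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)  # single reduction
        return self.scale * x * tf.math.rsqrt(ms + self.eps)


tf.random.set_seed(0)
x = tf.random.normal((4, 16, 64)) * 2.0 + 1.0   # (B, T, D), shifted and scaled

ours = MyLayerNorm(eps=1e-3)(x)
ref = tf.keras.layers.LayerNormalization(axis=-1, epsilon=1e-3)(x)
ln_diff = float(tf.reduce_max(tf.abs(ours - ref)))
print("LayerNorm max abs diff vs Keras:", ln_diff)

y = MyRMSNorm()(x)
rms = tf.sqrt(tf.reduce_mean(tf.square(y), axis=-1))
print("output RMS range:", float(tf.reduce_min(rms)), float(tf.reduce_max(rms)))
```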

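Benchmark both on (8, 2048, 4096) tensors (batch, sequence length, hidden size) using tf.function compilation. RMSNorm should show roughly 15% higher throughput than LayerNorm, because it skips the mean-centering step (one fewer reduction over the hidden axis) and has no bias term.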
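A timing sketch, assuming plain tf.function-compiled functions rather than layers (to isolate the normalization math), unit scale and zero bias, and a reduced shape so it finishes quickly on CPU; swap in (8, 2048, 4096) on a GPU to match the lab. The ratio you measure depends on hardware and on how well the graph fuses the reductions, so the ~15% figure is something to check, not assume.

```python
import time

import tensorflow as tf

# The lab uses (8, 2048, 4096); this smaller (B, T, D) keeps the sketch quick on CPU.
SHAPE = (8, 256, 1024)
EPS = 1e-6


@tf.function
def layer_norm(x, gamma, beta):
    mean, var = tf.nn.moments(x, axes=[-1], keepdims=True)  # mean, then variance
    return gamma * (x - mean) * tf.math.rsqrt(var + EPS) + beta


@tf.function
def rms_norm(x, gamma):
    ms = tf.reduce_mean(tf.square(x), axis=-1, keepdims=True)  # one reduction only
    return gamma * x * tf.math.rsqrt(ms + EPS)


def bench(fn, *args, iters=20):
    fn(*args)  # warm-up: trace the tf.function graph once, outside the timed loop
    start = time.perf_counter()
    for _ in range(iters):
        out = fn(*args)
    _ = out.numpy()  # wait for pending device work before stopping the clock
    return (time.perf_counter() - start) / iters


x = tf.random.normal(SHAPE)
gamma = tf.ones(SHAPE[-1])
beta = tf.zeros(SHAPE[-1])

t_ln = bench(layer_norm, x, gamma, beta)
t_rms = bench(rms_norm, x, gamma)
print(f"LayerNorm {t_ln * 1e3:.2f} ms | RMSNorm {t_rms * 1e3:.2f} ms | "
      f"speedup {t_ln / t_rms:.2f}x")
```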

Section 3 — SpectralNormalization

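tf.keras.layers.SpectralNormalization(layer) wraps a layer that has a kernel (such as Dense or Conv2D). On each training-mode forward pass it estimates the kernel's spectral norm (its largest singular value) by power iteration and divides the kernel by it. Apply it to every Dense layer in a discriminator-style MLP: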

import tensorflow as tf

disc = tf.keras.Sequential([
    tf.keras.layers.SpectralNormalization(tf.keras.layers.Dense(256, activation='relu')),
    tf.keras.layers.SpectralNormalization(tf.keras.layers.Dense(256, activation='relu')),
    tf.keras.layers.SpectralNormalization(tf.keras.layers.Dense(1))
])
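The spectral norm is not recomputed with a full SVD at every step. A minimal sketch of the power-iteration idea on a random stand-in kernel (W, u and the iteration count are illustrative assumptions): alternately multiplying a persistent vector by Wᵀ and W converges toward the top singular vectors, and v·W·uᵀ then approximates σ_max(W), the value the kernel is divided by.

```python
import tensorflow as tf

tf.random.set_seed(0)
W = tf.random.normal((64, 128))   # stand-in kernel, shape (in, out)
u = tf.random.normal((1, 128))    # persistent vector, kept across steps in practice

for _ in range(100):
    v = tf.math.l2_normalize(tf.matmul(u, W, transpose_b=True))  # (1, in)
    u = tf.math.l2_normalize(tf.matmul(v, W))                    # (1, out)

# Estimate of the largest singular value: v W u^T
sigma_est = float(tf.squeeze(tf.matmul(tf.matmul(v, W), u, transpose_b=True)))
sigma_svd = float(tf.linalg.svd(W, compute_uv=False)[0])  # exact, for comparison

W_sn = W / sigma_est                                       # spectrally normalized kernel
sigma_sn = float(tf.linalg.svd(W_sn, compute_uv=False)[0])
print(f"estimate {sigma_est:.4f} | SVD {sigma_svd:.4f} | normalized {sigma_sn:.4f}")
```

Keeping u between training steps (rather than restarting it) is what makes a single iteration per step sufficient in practice: the vector stays close to the top singular direction as the kernel changes slowly.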

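After training, extract each wrapped layer's kernel and verify that its largest singular value, tf.linalg.svd(kernel)[0][0], is at most 1.0 + 1e-4 (tf.linalg.svd returns the singular values first, sorted in descending order).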

Section 4 — GroupNorm Stability

Reproduce the batch-size experiment from Lab 1 in TF:

tf.keras.layers.GroupNormalization(groups=32)   # stable at all batch sizes
tf.keras.layers.BatchNormalization()             # degrades below ~8

Train the same detection head architecture for 5 epochs at batch sizes {1, 2, 4, 8, 16, 64}. Plot validation loss for both normalizations; confirm TF results match PyTorch Lab 1.
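The mechanism behind the stability result can be checked before any training. GroupNormalization computes statistics per sample (over H, W and the channels within each group), so a sample's output does not depend on its batch-mates; BatchNormalization in training mode uses batch statistics, which get noisier as the batch shrinks. A minimal sketch with a random feature map whose shape is an arbitrary assumption; note that groups must divide the channel count (64 here), or Keras raises an error.

```python
import tensorflow as tf

tf.random.set_seed(0)
feats = tf.random.normal((8, 16, 16, 64))  # hypothetical (B, H, W, C) feature map

gn = tf.keras.layers.GroupNormalization(groups=32)
bn = tf.keras.layers.BatchNormalization()

# Normalize image 0 on its own (batch size 1) and as part of the full batch of 8.
gn_alone, gn_batch = gn(feats[:1]), gn(feats)[:1]
bn_alone, bn_batch = bn(feats[:1], training=True), bn(feats, training=True)[:1]

gn_diff = float(tf.reduce_max(tf.abs(gn_alone - gn_batch)))  # per-sample statistics
bn_diff = float(tf.reduce_max(tf.abs(bn_alone - bn_batch)))  # batch statistics
print(f"GN: max output change {gn_diff:.2e} | BN: max output change {bn_diff:.2e}")
```

GN's output for the sample is unchanged by batch composition, while BN's shifts, which is the source of BN's degradation at batch sizes 1 to 4 in the experiment.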

Section 5 — AdaIN Style Transfer

Implement AdaIN as a Keras layer that takes (content, style) as inputs. TF uses channels-last (B, H, W, C) by default:

import tensorflow as tf

class AdaIN(tf.keras.layers.Layer):
    def call(self, content, style):
        # Per-sample, per-channel statistics over the spatial axes H and W
        c_mean = tf.reduce_mean(content, axis=[1, 2], keepdims=True)
        c_std  = tf.math.reduce_std(content, axis=[1, 2], keepdims=True) + 1e-5
        s_mean = tf.reduce_mean(style,   axis=[1, 2], keepdims=True)
        s_std  = tf.math.reduce_std(style,   axis=[1, 2], keepdims=True) + 1e-5
        # Normalize the content, then re-scale and shift with the style statistics
        return s_std * (content - c_mean) / c_std + s_mean

Note the axis difference: PyTorch uses dim=[2, 3] (channels-first), TF uses axis=[1, 2] (channels-last).
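A sketch of the channel-statistics check from the lab objectives, using random feature maps in place of encoder features (shapes and distributions are arbitrary assumptions; content and style may differ spatially but must share C). After the forward pass, each output channel's mean and standard deviation should match the style's, up to the 1e-5 epsilon.

```python
import tensorflow as tf


class AdaIN(tf.keras.layers.Layer):
    def call(self, content, style):
        c_mean = tf.reduce_mean(content, axis=[1, 2], keepdims=True)
        c_std = tf.math.reduce_std(content, axis=[1, 2], keepdims=True) + 1e-5
        s_mean = tf.reduce_mean(style, axis=[1, 2], keepdims=True)
        s_std = tf.math.reduce_std(style, axis=[1, 2], keepdims=True) + 1e-5
        return s_std * (content - c_mean) / c_std + s_mean


tf.random.set_seed(0)
content = tf.random.normal((2, 32, 32, 64)) * 0.5 + 1.0   # hypothetical content features
style = tf.random.normal((2, 16, 16, 64)) * 3.0 - 2.0     # hypothetical style features

out = AdaIN()(content, style)

# Per-sample, per-channel statistics over (H, W): shape (B, C)
mean_diff = tf.reduce_max(tf.abs(tf.reduce_mean(out, axis=[1, 2])
                                 - tf.reduce_mean(style, axis=[1, 2])))
std_diff = tf.reduce_max(tf.abs(tf.math.reduce_std(out, axis=[1, 2])
                                - tf.math.reduce_std(style, axis=[1, 2])))
print("max |mean diff|:", float(mean_diff), "| max |std diff|:", float(std_diff))
```

The output keeps the content's spatial size (32×32 here); only its per-channel statistics are taken from the style.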

Section 6 — FiLM Conditioning

FiLM (Feature-wise Linear Modulation) predicts per-channel γ(c) and β(c) from a conditioning vector c and applies FiLM(x) = γ(c) ⊙ x + β(c), broadcast over the spatial axes:

import tensorflow as tf

class FiLM(tf.keras.layers.Layer):
    def __init__(self, num_features, **kwargs):
        super().__init__(**kwargs)
        # Separate linear heads predict the per-channel scale and shift
        self.gamma_proj = tf.keras.layers.Dense(num_features)
        self.beta_proj  = tf.keras.layers.Dense(num_features)

    def call(self, x, condition):
        # x: (B, H, W, C) features; condition: (B, D) conditioning vectors
        gamma = self.gamma_proj(condition)[:, None, None, :]  # (B,1,1,C)
        beta  = self.beta_proj(condition)[:, None, None, :]
        return gamma * x + beta

Apply FiLM conditioning to a small class-conditional image generator: the conditioning vector is a learned class embedding. Train on CIFAR-10 and verify that the model generates class-specific features by interpolating between two class embeddings.
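A toy sketch of the wiring and the interpolation check, untrained. TinyConditionalGenerator is a hypothetical name, and its architecture (16-dim noise, an 8×8×64 feature map, two transposed convolutions up to CIFAR-10's 32×32×3) and the class ids are arbitrary assumptions. Interpolating the class embedding sweeps the FiLM parameters γ(c), β(c) smoothly from one class to the other; after training, the intermediate images are what you inspect.

```python
import tensorflow as tf


class FiLM(tf.keras.layers.Layer):
    def __init__(self, num_features, **kwargs):
        super().__init__(**kwargs)
        self.gamma_proj = tf.keras.layers.Dense(num_features)
        self.beta_proj = tf.keras.layers.Dense(num_features)

    def call(self, x, condition):
        gamma = self.gamma_proj(condition)[:, None, None, :]  # (B,1,1,C)
        beta = self.beta_proj(condition)[:, None, None, :]
        return gamma * x + beta


class TinyConditionalGenerator(tf.keras.Model):
    """Hypothetical toy generator: noise -> 8x8x64 -> FiLM(class embedding) -> 32x32x3."""

    def __init__(self, num_classes=10, embed_dim=32, **kwargs):
        super().__init__(**kwargs)
        self.embed = tf.keras.layers.Embedding(num_classes, embed_dim)  # learned class vectors
        self.fc = tf.keras.layers.Dense(8 * 8 * 64)
        self.to_map = tf.keras.layers.Reshape((8, 8, 64))
        self.film = FiLM(64)
        self.up1 = tf.keras.layers.Conv2DTranspose(32, 4, strides=2, padding="same",
                                                   activation="relu")
        self.up2 = tf.keras.layers.Conv2DTranspose(3, 4, strides=2, padding="same",
                                                   activation="tanh")

    def generate(self, z, cond_vec):
        h = self.to_map(self.fc(z))
        h = tf.nn.relu(self.film(h, cond_vec))  # class information enters only here
        return self.up2(self.up1(h))            # 8x8 -> 16x16 -> 32x32

    def call(self, z, labels):
        return self.generate(z, self.embed(labels))


tf.random.set_seed(0)
gen = TinyConditionalGenerator()
z = tf.random.normal((1, 16))
labels = tf.constant([3, 5])              # two arbitrary CIFAR-10 class ids
e = gen.embed(labels)                     # (2, embed_dim)

ts = tf.linspace(0.0, 1.0, 5)[:, None]    # interpolation weights, shape (5, 1)
interp = (1.0 - ts) * e[0:1] + ts * e[1:2]
imgs = gen.generate(tf.repeat(z, 5, axis=0), interp)  # same noise, varying class code
print(imgs.shape)
```

One common variant predicts 1 + Δγ instead of γ so that FiLM starts close to the identity; the sketch keeps the layer exactly as written above.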