Normalization Techniques in TensorFlow
- Implement BatchNorm as a custom tf.keras.layers.Layer, handling the training= flag correctly in both model.fit and custom GradientTape loops
- Implement RMSNorm, compare it with tf.keras.layers.LayerNormalization, and benchmark throughput on long sequences
- Apply tf.keras.layers.SpectralNormalization to a discriminator; verify the Lipschitz constraint and compare against the PyTorch Lab 1 results
- Use tf.keras.layers.GroupNormalization in a small detection model; reproduce the batch-size stability experiment from Lab 1 and compare BN vs GN validation curves
- Implement AdaIN on channels-last (B, H, W, C) feature tensors; verify that channel statistics match the style image after the forward pass
Lab Overview
TensorFlow companion to the PyTorch lab. Re-implements each normalization technique using Keras idioms, with emphasis on the training= flag, custom layer subclassing, and model.fit integration.
Key API Differences vs PyTorch
| Concept | PyTorch | TensorFlow / Keras |
|---|---|---|
| BatchNorm mode | model.train() / model.eval() | layer(x, training=True/False) |
| LayerNorm | nn.LayerNorm(normalized_shape) | tf.keras.layers.LayerNormalization(axis=-1) |
| RMSNorm | nn.RMSNorm(dim) (PyTorch 2.4+) | Custom layer (no built-in) |
| GroupNorm | nn.GroupNorm(G, C) | tf.keras.layers.GroupNormalization(groups=G) (TF 2.11+) |
| SpectralNorm | nn.utils.spectral_norm(layer) | tf.keras.layers.SpectralNormalization(layer) |
| AdaIN / FiLM | Custom nn.Module | Custom tf.keras.layers.Layer |
Sections
| Section | Topic | Key experiment |
|---|---|---|
| 1 | BatchNorm custom layer | training= flag in GradientTape loop |
| 2 | LayerNorm & RMSNorm | Throughput benchmark on sequences |
| 3 | SpectralNormalization | Lipschitz verification; parity with PyTorch |
| 4 | GroupNorm batch-size stability | BN vs GN validation curves |
| 5 | AdaIN style transfer | Channel stats verification |
| 6 | FiLM conditioning | Class-conditional generation |
Section 1 — BatchNorm as a Custom Keras Layer
Implement MyBatchNorm by subclassing tf.keras.layers.Layer. The critical TF-specific pattern is branching on the training argument inside call():
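```python
def call(self, x, training=False):
    if training:
        # Batch statistics; axes=[0] suits (batch, features) inputs,
        # use axes=[0, 1, 2] for (B, H, W, C) feature maps.
        mean, var = tf.nn.moments(x, axes=[0])
        # Update the running stats (momentum 0.9) for later inference.
        self.running_mean.assign(0.9 * self.running_mean + 0.1 * mean)
        self.running_var.assign(0.9 * self.running_var + 0.1 * var)
    else:
        # Inference: use the accumulated running stats.
        mean, var = self.running_mean, self.running_var
    # Argument order: x, mean, variance, offset (beta), scale (gamma), variance_epsilon.
    return tf.nn.batch_normalization(x, mean, var, self.beta, self.gamma, self.eps)
```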
In model.fit, Keras passes training=True to every layer automatically (and training=False in model.evaluate and model.predict). In a custom GradientTape loop you must pass it explicitly; forgetting to do so is the TF counterpart of a missing model.train() / model.eval() call in PyTorch.
Section 2 — LayerNorm and RMSNorm
Implement MyRMSNorm as a Keras layer that normalizes over axis=-1. tf.keras has no built-in RMSNorm layer, a common gap in TF vs PyTorch parity. Verify numerical agreement with a PyTorch reference computed on the same tensors (saved to disk), using the same eps on both sides: PyTorch's nn.RMSNorm defaults to eps=None, which falls back to torch.finfo(dtype).eps.
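Benchmark both on (8, 2048, 4096) tensors (batch, sequence length, hidden size) using tf.function compilation. Expect RMSNorm to be somewhat faster than LayerNorm, since it skips the mean subtraction and the bias; the lab's target is roughly a 15% throughput improvement, but the measured gap depends on hardware and on how well the compiler fuses each op.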
Section 3 — SpectralNormalization
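tf.keras.layers.SpectralNormalization(layer) wraps a layer that has a kernel (e.g. Dense, Conv2D) or embeddings attribute and, on each training-mode call, divides the kernel by an estimate of its spectral norm (largest singular value) obtained by power iteration. Apply it to every Dense layer in a discriminator-style MLP: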
```python
import tensorflow as tf

disc = tf.keras.Sequential([
    tf.keras.layers.SpectralNormalization(tf.keras.layers.Dense(256, activation='relu')),
    tf.keras.layers.SpectralNormalization(tf.keras.layers.Dense(256, activation='relu')),
    tf.keras.layers.SpectralNormalization(tf.keras.layers.Dense(1)),
])
```
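After training, extract each wrapped layer's kernel and verify that tf.linalg.svd(kernel)[0][0] (the largest singular value) is ≤ 1.0 + 1e-4. Because ReLU is 1-Lipschitz, per-layer spectral norms at most 1 bound the whole discriminator's Lipschitz constant (in the L2 norm) by their product, i.e. by 1. Read the kernels right after a training-mode forward pass: the wrapper rescales the kernel during that pass, and a subsequent optimizer step can push σ slightly above 1. Power iteration only estimates σ, so the tolerance may need loosening early in training.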
Section 4 — GroupNorm Stability
Reproduce the batch-size experiment from Lab 1 in TF:
```python
tf.keras.layers.GroupNormalization(groups=32)  # stable at all batch sizes; channel count must be divisible by 32
tf.keras.layers.BatchNormalization()           # degrades below batch size ~8
```
Train the same detection head architecture for 5 epochs at each batch size in {1, 2, 4, 8, 16, 64}. Plot validation loss for both normalizations and confirm that the TF curves reproduce the trend from PyTorch Lab 1; exact loss values will differ across frameworks.
Section 5 — AdaIN Style Transfer
Implement AdaIN as a Keras layer that takes (content, style) as inputs. TF uses channels-last (B, H, W, C) by default:
```python
import tensorflow as tf

class AdaIN(tf.keras.layers.Layer):
    def call(self, content, style):
        # Per-sample, per-channel statistics over the spatial axes H and W.
        c_mean = tf.reduce_mean(content, axis=[1, 2], keepdims=True)
        c_std = tf.math.reduce_std(content, axis=[1, 2], keepdims=True) + 1e-5
        s_mean = tf.reduce_mean(style, axis=[1, 2], keepdims=True)
        s_std = tf.math.reduce_std(style, axis=[1, 2], keepdims=True) + 1e-5
        # Normalize the content, then re-scale and shift it with the style statistics.
        return s_std * (content - c_mean) / c_std + s_mean
```
Note the axis difference: PyTorch uses dim=[2, 3] (channels-first), TF uses axis=[1, 2] (channels-last).
Section 6 — FiLM Conditioning
FiLM (Feature-wise Linear Modulation) predicts a per-channel scale $\gamma$ and shift $\beta$ from a conditioning vector $c$ and applies $\mathrm{FiLM}(x) = \gamma(c) \odot x + \beta(c)$:
```python
import tensorflow as tf

class FiLM(tf.keras.layers.Layer):
    def __init__(self, num_features, **kwargs):
        super().__init__(**kwargs)
        self.gamma_proj = tf.keras.layers.Dense(num_features)  # predicts gamma(c)
        self.beta_proj = tf.keras.layers.Dense(num_features)   # predicts beta(c)

    def call(self, x, condition):
        gamma = self.gamma_proj(condition)[:, None, None, :]  # (B,1,1,C)
        beta = self.beta_proj(condition)[:, None, None, :]    # (B,1,1,C)
        return gamma * x + beta
```
Apply FiLM conditioning to a small class-conditional image generator: the conditioning vector is a learned class embedding. Train on CIFAR-10 and verify that the model generates class-specific features by interpolating between two class embeddings.