Supplement · Neural Network Architectures

Neural Network Architectures in TensorFlow

Colab Notebook · Python · ~55 min
Lab Objectives
1. Build a configurable MLP using Keras functional and subclassing APIs; verify the non-linearity collapse using matrix rank analysis and compare parameter counts to theoretical formulas
2. Implement custom Conv2D and depthwise separable layers as Keras layers; verify output-size formula and benchmark parameter reduction vs. standard convolutions on CIFAR-10
3. Build ResNet basic and bottleneck blocks as custom Keras layers; reproduce the gradient-flow experiment using GradientTape and compare early-layer gradient magnitudes with and without residual connections
4. Implement a vanilla RNN and an LSTM cell from scratch using tf.keras.layers.Layer; compare performance on a long-sequence synthetic task and inspect gradient norms via GradientTape
5. Implement scaled dot-product attention and multi-head attention as custom Keras layers; verify softmax saturation at large d_k when the √d_k scaling is removed
6. Assemble a mini decoder-only transformer and train it on character-level language modeling; visualize per-head attention patterns using matplotlib

Lab Overview

TensorFlow/Keras companion to the PyTorch lab. Re-implements each architecture as custom Keras layers and models, with emphasis on the Keras subclassing API, training= flag propagation, and GradientTape for gradient inspection.

Key API Differences vs PyTorch

Concept              | PyTorch                      | TensorFlow / Keras
Model definition     | nn.Module + forward          | tf.keras.layers.Layer + call
Sequential model     | nn.Sequential                | tf.keras.Sequential
Gradient computation | loss.backward()              | tf.GradientTape
Gradient access      | param.grad                   | tape.gradient(loss, vars)
Data channels order  | NCHW (default)               | NHWC (default)
Causal mask          | Upper-triangular bool tensor | Same pattern, tf.linalg.band_part

Sections

Section | Topic                        | Key experiment
1       | MLP: depth and non-linearity | Rank collapse without activations
2       | CNN and depthwise separable  | Parameter counts; CIFAR-10 accuracy
3       | ResNet residual connections  | GradientTape gradient-flow visualization
4       | Vanilla RNN vs. LSTM         | Long-sequence task; gradient norms
5       | Scaled dot-product attention | Saturation experiment; MHA module
6       | Mini transformer             | Character LM; per-head attention heatmaps

Section 1 — MLP with Keras Subclassing

Implement MLP as a tf.keras.layers.Layer with configurable depth:

import tensorflow as tf

class MLP(tf.keras.layers.Layer):
    def __init__(self, hidden_dims, d_out, activation='relu'):
        super().__init__()
        # One Dense layer per entry in hidden_dims; activation=None gives a purely linear stack
        self.hidden = [tf.keras.layers.Dense(d, activation=activation)
                       for d in hidden_dims]
        # Linear output projection (no activation)
        self.out    = tf.keras.layers.Dense(d_out)

    def call(self, x):
        for layer in self.hidden:
            x = layer(x)
        return self.out(x)

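Rank collapse experiment: build a 4-layer linear MLP (activation=None) and a 4-layer ReLU MLP. After a forward pass on a batch X, compute tf.linalg.matrix_rank on the (batch, d_out) output matrix. Without activations the network reduces to a single affine map, so the linear MLP's output rank is at most min(rank(W₁W₂W₃W₄), rank(X)), plus one if the biases are non-zero. That bound is set by d_in, d_out, and the narrowest layer; making the hidden layers wider or adding more linear layers cannot raise it. The ReLU MLP's output rank is limited only by the batch size, d_out, and the last hidden width.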
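The rank check could be sketched as follows. This is a minimal illustration rather than the lab's reference solution: it uses tf.keras.Sequential instead of the MLP class above, turns biases off so the bound is exact, and picks a hypothetical bottleneck width of 3. The helper names `make_mlp` and `output_rank` and all dimensions are assumptions made for the example.

```python
import tensorflow as tf

def make_mlp(hidden_dims, d_out, activation):
    """Stack of bias-free Dense layers (biases off so the rank bound is exact)."""
    layers = [tf.keras.layers.Dense(d, activation=activation, use_bias=False)
              for d in hidden_dims]
    layers.append(tf.keras.layers.Dense(d_out, use_bias=False))
    return tf.keras.Sequential(layers)

def output_rank(model, x):
    """Numerical rank of the (batch, d_out) output matrix."""
    return int(tf.linalg.matrix_rank(model(x)).numpy())

tf.keras.utils.set_random_seed(0)
x = tf.random.normal((64, 16))   # batch of 64 samples, d_in = 16
hidden_dims = [32, 3, 32]        # hypothetical widths; 3 is the bottleneck

linear_rank = output_rank(make_mlp(hidden_dims, 10, None), x)
relu_rank = output_rank(make_mlp(hidden_dims, 10, "relu"), x)
print("linear:", linear_rank, "relu:", relu_rank)
```

With these widths the linear stack cannot exceed rank 3, however wide the other hidden layers are. The ReLU stack is not bound by the bottleneck in the same way and will typically report a higher rank.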

Section 2 — CNN and Depthwise Separable Convolutions

TF/Keras uses NHWC layout by default. Implement depthwise separable convolution:

class DepthwiseSeparable(tf.keras.layers.Layer):
    def __init__(self, out_ch, K=3):
        super().__init__()
        self.dw = tf.keras.layers.DepthwiseConv2D(K, padding='same')
        self.pw = tf.keras.layers.Conv2D(out_ch, 1)

    def call(self, x):
        return self.pw(self.dw(x))

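Note: TF provides tf.keras.layers.DepthwiseConv2D as a built-in, and tf.keras.layers.SeparableConv2D combines the depthwise and pointwise steps if you want a reference to check against. Parameter count comparison: for in_ch=64, out_ch=128, K=3, a standard convolution has 3·3·64·128 = 73,728 weights, while the depthwise separable version has 3·3·64 + 64·128 = 8,768, a reduction of about 8.4× excluding biases. Verify that this matches the PyTorch Lab results.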
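The counts can be reproduced with plain arithmetic before building any layers. The sketch below assumes Keras defaults (biases enabled on every layer, depth_multiplier=1); the function names are hypothetical helpers, not part of TensorFlow.

```python
def conv2d_params(in_ch, out_ch, k, bias=True):
    """Standard convolution: one k x k x in_ch filter per output channel."""
    return k * k * in_ch * out_ch + (out_ch if bias else 0)

def depthwise_separable_params(in_ch, out_ch, k, bias=True):
    """Depthwise k x k (one filter per input channel) followed by a 1 x 1 pointwise conv."""
    depthwise = k * k * in_ch + (in_ch if bias else 0)
    pointwise = in_ch * out_ch + (out_ch if bias else 0)
    return depthwise + pointwise

standard = conv2d_params(64, 128, 3)                 # 73,856 including biases
separable = depthwise_separable_params(64, 128, 3)   # 8,960 including biases
print(standard, separable, round(standard / separable, 2))  # ratio is about 8.24
```

Once the layers have been called on an input with 64 channels, their count_params() values should agree with these numbers.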

Section 3 — ResNet with GradientTape

Implement BasicBlock as a Keras layer and build a 20-layer plain vs. ResNet model. Use GradientTape to inspect per-layer gradient norms:

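with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss   = loss_fn(y, logits)

grads = tape.gradient(loss, model.trainable_variables)
for var, grad in zip(model.trainable_variables, grads):
    # Keras 3 keeps the full layer path in var.path; older versions keep it in var.name
    name = getattr(var, 'path', var.name)
    # grad is None for variables that do not influence the loss
    if grad is not None and 'conv' in name:
        print(name, tf.norm(grad).numpy())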

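Key difference from PyTorch: gradients are not stored on the parameters by a .backward() call. Instead, you record the forward pass inside a GradientTape and request them with tape.gradient. Verify that gradient norms in the early blocks of the ResNet stay larger than those in the early layers of the plain 20-layer CNN.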
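A toy version of the comparison is sketched below. It is not the lab's BasicBlock: it stacks 20 small Conv2D layers without batch normalization, feeds random data, and reports only the gradient norm of the first layer's kernel. The helper name `early_grad_norm` and all sizes are assumptions chosen to keep the example fast.

```python
import tensorflow as tf

def early_grad_norm(residual, depth=20, channels=8, seed=0):
    """Gradient norm of the first conv kernel in a plain or residual conv stack."""
    tf.keras.utils.set_random_seed(seed)
    convs = [tf.keras.layers.Conv2D(channels, 3, padding="same") for _ in range(depth)]
    head = tf.keras.layers.Dense(10)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    x = tf.random.normal((16, 8, 8, channels))               # random NHWC batch
    y = tf.random.uniform((16,), maxval=10, dtype=tf.int32)  # random labels

    with tf.GradientTape() as tape:
        h = x
        for conv in convs:
            out = tf.nn.relu(conv(h))
            h = h + out if residual else out                 # skip connection is the only difference
        logits = head(tf.reduce_mean(h, axis=[1, 2]))        # global average pooling
        loss = loss_fn(y, logits)

    # trainable_variables[0] is the kernel of the first conv layer
    grad = tape.gradient(loss, convs[0].trainable_variables)[0]
    return float(tf.norm(grad).numpy())

plain_norm = early_grad_norm(residual=False)
resnet_norm = early_grad_norm(residual=True)
print("plain:", plain_norm, "residual:", resnet_norm)
```

Running both variants with the same seed keeps the initial kernels and data identical, so any difference in the printed norms comes from the skip connections.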

Section 4 — Vanilla RNN and LSTM from Scratch

Implement a minimal RNN cell and LSTM cell as Keras layers:

class VanillaRNNCell(tf.keras.layers.Layer):
    def __init__(self, d_h):
        super().__init__()
        self.Wh = tf.keras.layers.Dense(d_h, use_bias=False)
        self.Wx = tf.keras.layers.Dense(d_h)

    def call(self, x_t, h_prev):
        return tf.tanh(self.Wh(h_prev) + self.Wx(x_t))

Wrap the cells in a loop over timesteps and compare against tf.keras.layers.SimpleRNN and tf.keras.layers.LSTM. Run the long-sequence synthetic task (sequence length 200, label at position 0) and compare accuracy. For gradient analysis, run the full forward pass inside a GradientTape, keep every hidden state h_t in a list, and request the gradient of the loss with respect to each stored state to see how its norm changes across time steps.
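One way to collect per-step gradients is sketched below, under several assumptions: the sequence is shortened to 50 steps, the inputs and targets are random stand-ins rather than the lab's synthetic task, and the cell is written inline with two Dense layers that mirror VanillaRNNCell. The names `states` and `grad_norms` are illustrative.

```python
import tensorflow as tf

tf.keras.utils.set_random_seed(0)
T, d_in, d_h = 50, 4, 16                          # illustrative sizes (the lab uses T = 200)
Wx = tf.keras.layers.Dense(d_h)                   # input-to-hidden projection
Wh = tf.keras.layers.Dense(d_h, use_bias=False)   # hidden-to-hidden projection
readout = tf.keras.layers.Dense(1)

x = tf.random.normal((8, T, d_in))                # random stand-in batch
y = tf.random.normal((8, 1))                      # random stand-in targets

with tf.GradientTape() as tape:
    h = tf.zeros((8, d_h))
    states = []
    for t in range(T):
        h = tf.tanh(Wh(h) + Wx(x[:, t, :]))
        states.append(h)                          # keep h_t so it can be used as a gradient source
    loss = tf.reduce_mean((readout(h) - y) ** 2)  # loss depends on the final state only

# One gradient per stored hidden state: dL/dh_t for every time step
state_grads = tape.gradient(loss, states)
grad_norms = [float(tf.norm(g).numpy()) for g in state_grads]
print("first step:", grad_norms[0], "last step:", grad_norms[-1])
```

Plotting grad_norms against the time index on a log scale makes the decay toward early steps easy to compare between the vanilla cell and the LSTM.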

Section 5 — Scaled Dot-Product Attention

Implement attention with optional scaling:

class ScaledDotProductAttention(tf.keras.layers.Layer):
    def call(self, Q, K, V, mask=None, scale=True):
        d_k = tf.cast(tf.shape(K)[-1], tf.float32)
        scores = tf.matmul(Q, K, transpose_b=True)
        if scale:
            scores /= tf.math.sqrt(d_k)
        if mask is not None:
            # mask holds 1.0 at blocked positions and 0.0 at allowed positions
            scores += mask * -1e9
        weights = tf.nn.softmax(scores, axis=-1)
        return tf.matmul(weights, V), weights

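Causal mask in TF: tf.linalg.band_part(tf.ones((T, T)), -1, 0) creates a lower-triangular matrix of ones that marks the allowed positions. The layer above expects 1 at blocked positions, so pass mask = 1.0 - tf.linalg.band_part(tf.ones((T, T)), -1, 0); the layer multiplies it by -1e9 to form the additive mask. Do not pre-multiply the mask by 1e9 yourself, or it will be scaled twice.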
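A small check of the mask values for T = 4 is sketched below. It uses uniform (all-zero) scores purely for illustration, so each row of the softmax output should spread evenly over the positions it is allowed to see. The variable names are assumptions made for the example.

```python
import tensorflow as tf

T = 4
lower = tf.linalg.band_part(tf.ones((T, T)), -1, 0)  # 1 on and below the diagonal
blocked = 1.0 - lower                                 # 1 strictly above the diagonal (future positions)
additive = blocked * -1e9                             # 0 where allowed, -1e9 where blocked

scores = tf.zeros((T, T))                             # uniform scores, for illustration only
weights = tf.nn.softmax(scores + additive, axis=-1)

print(blocked.numpy())
print(weights.numpy())
```

Row i of weights is non-zero only for columns 0 through i, which is the causal pattern the decoder needs.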

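Saturation experiment: verify that calling the layer with scale=False causes softmax entropy to drop as d_k increases, because the variance of the unscaled dot products grows linearly with d_k when Q and K have unit variance. Replicate the PyTorch Lab plot using TF operations.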
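The entropy measurement could look like the sketch below. It assumes random unit-variance Q and K with 32 queries and 64 keys, and it computes the scores directly rather than through the ScaledDotProductAttention layer; `mean_attention_entropy` is a hypothetical helper name.

```python
import tensorflow as tf

def mean_attention_entropy(d_k, scale, n_queries=32, n_keys=64, seed=0):
    """Mean entropy (in nats) of the softmax rows for random unit-variance Q and K."""
    tf.keras.utils.set_random_seed(seed)
    q = tf.random.normal((n_queries, d_k))
    k = tf.random.normal((n_keys, d_k))
    scores = tf.matmul(q, k, transpose_b=True)
    if scale:
        scores = scores / tf.math.sqrt(tf.cast(d_k, tf.float32))
    log_w = tf.nn.log_softmax(scores, axis=-1)   # numerically stable log-probabilities
    entropy = -tf.reduce_sum(tf.exp(log_w) * log_w, axis=-1)
    return float(tf.reduce_mean(entropy).numpy())

for d_k in (4, 16, 64, 256, 1024):
    scaled = mean_attention_entropy(d_k, scale=True)
    unscaled = mean_attention_entropy(d_k, scale=False)
    print(d_k, "scaled:", scaled, "unscaled:", unscaled)
```

The maximum possible entropy here is ln(64), reached by a uniform distribution over the keys. Entropy near zero means the softmax has saturated onto a single key.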

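Build MultiHeadAttention as a Keras layer and compare numerically against tf.keras.layers.MultiHeadAttention on the same input tensors. The two layers initialize independently, so copy the projection weights from one to the other before comparing outputs, reshaping as needed because the built-in layer stores a separate kernel slice per head.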

Section 6 — Mini Decoder-Only Transformer

Assemble the transformer using custom Keras layers:

class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1)
        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1)
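        # Custom layer from Section 5 (not tf.keras.layers.MultiHeadAttention); assumed to return only the output tensor
        self.attn  = MultiHeadAttention(d_model, n_heads)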
        self.ffn   = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation='gelu'),
            tf.keras.layers.Dense(d_model)
        ])
        self.drop  = tf.keras.layers.Dropout(dropout)

    def call(self, x, causal_mask, training=False):
        x = x + self.drop(self.attn(self.norm1(x), causal_mask), training=training)
        x = x + self.drop(self.ffn(self.norm2(x)), training=training)
        return x

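Key TF-specific pattern: in a custom GradientTape training loop you must call the model with training=True yourself, and with training=False at evaluation time; Keras model.fit and model.evaluate set the flag automatically. Forwarding training= explicitly to layers such as Dropout, as in the block above, makes the behavior unambiguous inside custom layers.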
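The effect of the flag in a custom loop can be seen in the short sketch below. It is a stand-in, not the transformer: a tiny Sequential model with one Dropout layer, random placeholder data, and hypothetical helper names `train_step` and `predict`.

```python
import tensorflow as tf

tf.keras.utils.set_random_seed(0)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()

x = tf.random.normal((16, 8))   # placeholder batch
y = tf.random.normal((16, 1))   # placeholder targets

def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x, training=True)    # dropout is active
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return float(loss.numpy())

def predict(x):
    return model(x, training=False)       # dropout is disabled

loss_value = train_step(x, y)
repeatable = bool(tf.reduce_all(predict(x) == predict(x)).numpy())
print(loss_value, repeatable)
```

Two calls with training=False give identical outputs, while repeated calls with training=True would differ because a new dropout mask is sampled each time.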

Train on a character-level text dataset for 5 epochs with context length 256. After training, extract the attention weight tensors for each head and visualize them as heatmaps. Compare the head specialization patterns against the PyTorch Lab results to check cross-framework parity.