Neural Network Architectures in TensorFlow
Lab Overview
TensorFlow/Keras companion to the PyTorch lab. Re-implements each architecture as custom Keras layers and models, with emphasis on the Keras subclassing API, training= flag propagation, and GradientTape for gradient inspection.
Key API Differences vs PyTorch
| Concept | PyTorch | TensorFlow / Keras |
|---|---|---|
| Model definition | nn.Module + forward | tf.keras.layers.Layer + call |
| Sequential model | nn.Sequential | tf.keras.Sequential |
| Gradient computation | loss.backward() | tf.GradientTape |
| Gradient access | param.grad | tape.gradient(loss, vars) |
| Data channels order | NCHW (default) | NHWC (default) |
| Causal mask | Upper-triangular bool tensor | Same pattern, tf.linalg.band_part |
Sections
| Section | Topic | Key experiment |
|---|---|---|
| 1 | MLP: depth and non-linearity | Rank collapse without activations |
| 2 | CNN and depthwise separable | Parameter counts; CIFAR-10 accuracy |
| 3 | ResNet residual connections | GradientTape gradient-flow visualization |
| 4 | Vanilla RNN vs. LSTM | Long-sequence task; gradient norms |
| 5 | Scaled dot-product attention | Saturation experiment; MHA module |
| 6 | Mini transformer | Character LM; per-head attention heatmaps |
Section 1 — MLP with Keras Subclassing
Implement MLP as a tf.keras.layers.Layer with configurable depth:
import tensorflow as tf

class MLP(tf.keras.layers.Layer):
    def __init__(self, hidden_dims, d_out, activation='relu'):
        super().__init__()
        # One Dense layer per entry in hidden_dims; activation=None gives a purely linear stack.
        self.hidden = [tf.keras.layers.Dense(d, activation=activation)
                       for d in hidden_dims]
        self.out = tf.keras.layers.Dense(d_out)  # linear output head

    def call(self, x):
        for layer in self.hidden:
            x = layer(x)
        return self.out(x)
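Rank collapse experiment: build a 4-layer linear MLP (activation=None) and a 4-layer ReLU MLP. After a forward pass on a batch with more rows than d_in, compute tf.linalg.matrix_rank on the output. The linear MLP collapses to a single affine map, so its output rank is bounded by rank(W₁W₂W₃W₄) ≤ min(d_in, d_out, narrowest hidden width), plus at most 1 for the bias terms, no matter how many layers are stacked. The ReLU MLP is not subject to this bound.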
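The rank collapse experiment can be sketched as follows. This is a minimal illustration, not the lab's reference solution: the sizes (d_in=8, width=64, d_out=32, batch=256) and the helper name build are assumptions chosen so that d_in is the narrowest dimension, and biases are disabled so the bound is exactly d_in.

```python
import tensorflow as tf

# Hypothetical sizes: d_in is deliberately the smallest dimension.
d_in, width, d_out, batch = 8, 64, 32, 256

def build(activation):
    """Three hidden Dense layers plus a linear head, all without biases."""
    layers = [tf.keras.layers.Dense(width, activation=activation, use_bias=False)
              for _ in range(3)]
    layers.append(tf.keras.layers.Dense(d_out, use_bias=False))
    return tf.keras.Sequential(layers)

x = tf.random.normal((batch, d_in))

# Output matrices have shape (batch, d_out); matrix_rank returns an int32 tensor.
linear_rank = int(tf.linalg.matrix_rank(build(None)(x)))
relu_rank = int(tf.linalg.matrix_rank(build('relu')(x)))

print("linear MLP output rank:", linear_rank)  # bounded by d_in
print("ReLU MLP output rank:  ", relu_rank)    # not bounded by d_in
```

Because the linear stack is equivalent to one matrix product, widening the hidden layers cannot raise its output rank above d_in; the ReLU stack is free to exceed it.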
Section 2 — CNN and Depthwise Separable Convolutions
TF/Keras uses NHWC layout by default. Implement depthwise separable convolution:
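class DepthwiseSeparable(tf.keras.layers.Layer):
    def __init__(self, out_ch, K=3):
        super().__init__()
        # Depthwise: one KxK filter per input channel (channel count inferred from the input).
        self.dw = tf.keras.layers.DepthwiseConv2D(K, padding='same')
        # Pointwise: 1x1 conv that mixes channels and sets the output width.
        self.pw = tf.keras.layers.Conv2D(out_ch, 1)

    def call(self, x):
        return self.pw(self.dw(x))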
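Note: TF provides tf.keras.layers.DepthwiseConv2D as a built-in, and tf.keras.layers.SeparableConv2D fuses the depthwise and pointwise steps into one layer. Parameter count comparison: for in_ch=64, out_ch=128, K=3, a standard convolution has 73,856 parameters and the separable version 8,960 (biases included), a reduction of about 8.2×. Verify that this matches the PyTorch lab results.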
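One way to check the parameter counts is to build the layers on a dummy NHWC input and read count_params(). The 32×32 spatial size below is an arbitrary assumption; it does not affect the counts.

```python
import tensorflow as tf

in_ch, out_ch, K = 64, 128, 3
x = tf.zeros((1, 32, 32, in_ch))  # dummy NHWC input, used only to create the weights

standard = tf.keras.layers.Conv2D(out_ch, K, padding='same')
depthwise = tf.keras.layers.DepthwiseConv2D(K, padding='same')
pointwise = tf.keras.layers.Conv2D(out_ch, 1)

# Calling the layers once builds their kernels and biases.
standard(x)
pointwise(depthwise(x))

standard_params = standard.count_params()                              # K*K*in_ch*out_ch + out_ch
separable_params = depthwise.count_params() + pointwise.count_params() # (K*K*in_ch + in_ch) + (in_ch*out_ch + out_ch)

print(standard_params, separable_params, standard_params / separable_params)
```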
Section 3 — ResNet with GradientTape
Implement BasicBlock as a Keras layer and build a 20-layer plain vs. ResNet model. Use GradientTape to inspect per-layer gradient norms:
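with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss = loss_fn(y, logits)
grads = tape.gradient(loss, model.trainable_variables)
for var, grad in zip(model.trainable_variables, grads):
    # Keras 3 shortens var.name (e.g. to 'kernel'); var.path keeps the layer prefix when available.
    name = getattr(var, 'path', var.name)
    # tape.gradient returns None for variables that did not contribute to the loss.
    if grad is not None and 'conv' in name:
        print(name, tf.norm(grad).numpy())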
Key difference from PyTorch: instead of calling loss.backward() and reading param.grad, you record the forward pass inside a GradientTape context and request gradients with tape.gradient. Verify that gradient norms in the early blocks stay larger in the ResNet than in the plain 20-layer CNN.
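A possible shape for BasicBlock is sketched below. The lab does not specify the block's internals, so the two 3×3 convolutions with BatchNormalization, the residual flag for switching to the plain variant, and the squared-activation loss are all assumptions made for illustration.

```python
import tensorflow as tf

class BasicBlock(tf.keras.layers.Layer):
    """Two 3x3 convs with BatchNorm; residual=False gives the plain (no-shortcut) variant."""
    def __init__(self, channels, residual=True):
        super().__init__()
        self.residual = residual
        self.conv1 = tf.keras.layers.Conv2D(channels, 3, padding='same', use_bias=False)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(channels, 3, padding='same', use_bias=False)
        self.bn2 = tf.keras.layers.BatchNormalization()

    def call(self, x, training=False):
        y = tf.nn.relu(self.bn1(self.conv1(x), training=training))
        y = self.bn2(self.conv2(y), training=training)
        if self.residual:
            y = x + y  # identity shortcut; assumes x already has `channels` channels
        return tf.nn.relu(y)

block = BasicBlock(16)
x = tf.random.normal((2, 8, 8, 16))  # NHWC

with tf.GradientTape() as tape:
    out = block(x, training=True)
    loss = tf.reduce_mean(out ** 2)  # placeholder loss, only to produce gradients

grads = tape.gradient(loss, block.trainable_variables)
grad_norms = [float(tf.norm(g)) for g in grads]
print(grad_norms)
```

Stacking such blocks with residual=True versus residual=False gives the two models to compare; the identity shortcut assumes matching channel counts, so a projection would be needed wherever the width changes.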
Section 4 — Vanilla RNN and LSTM from Scratch
Implement a minimal RNN cell and LSTM cell as Keras layers:
class VanillaRNNCell(tf.keras.layers.Layer):
    def __init__(self, d_h):
        super().__init__()
        self.Wh = tf.keras.layers.Dense(d_h, use_bias=False)  # hidden-to-hidden
        self.Wx = tf.keras.layers.Dense(d_h)                  # input-to-hidden (carries the bias)

    def call(self, x_t, h_prev):
        # h_t = tanh(W_h h_{t-1} + W_x x_t + b)
        return tf.tanh(self.Wh(h_prev) + self.Wx(x_t))
Wrap the cells in a loop over timesteps and compare against tf.keras.layers.SimpleRNN and tf.keras.layers.LSTM. Run the long-sequence synthetic task (sequence length 200, label at position 0) and compare accuracy. For gradient analysis, run the full forward pass inside a GradientTape, keep each step's hidden state, and request the gradient of the loss with respect to every hidden state to see how its norm changes across time steps.
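The per-timestep gradient inspection might look like the sketch below. The batch size, sequence length of 20 (shortened from 200 to keep the example fast), and the loss on the final hidden state are assumptions for illustration.

```python
import tensorflow as tf

class VanillaRNNCell(tf.keras.layers.Layer):
    def __init__(self, d_h):
        super().__init__()
        self.Wh = tf.keras.layers.Dense(d_h, use_bias=False)
        self.Wx = tf.keras.layers.Dense(d_h)

    def call(self, x_t, h_prev):
        return tf.tanh(self.Wh(h_prev) + self.Wx(x_t))

B, T, d_in, d_h = 4, 20, 3, 8  # hypothetical sizes
cell = VanillaRNNCell(d_h)
x = tf.random.normal((B, T, d_in))

with tf.GradientTape() as tape:
    h = tf.zeros((B, d_h))
    states = []
    for t in range(T):            # unrolled Python loop in eager mode
        h = cell(x[:, t, :], h)
        states.append(h)          # keep every h_t so it can be a gradient target
    loss = tf.reduce_mean(h ** 2) # placeholder loss on the final state

# d(loss)/d(h_t) for every t; each entry has shape (B, d_h).
grads = tape.gradient(loss, states)
grad_norms = [float(tf.norm(g)) for g in grads]

for t, n in enumerate(grad_norms):
    print(f"t={t:2d}  ||dL/dh_t|| = {n:.3e}")
```

Plotting grad_norms against t for the vanilla cell and for an LSTM cell gives the gradient-norm comparison the section asks for.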
Section 5 — Scaled Dot-Product Attention
Implement attention with optional scaling:
class ScaledDotProductAttention(tf.keras.layers.Layer):
    def call(self, Q, K, V, mask=None, scale=True):
        scores = tf.matmul(Q, K, transpose_b=True)  # (..., T_q, T_k)
        if scale:
            d_k = tf.cast(tf.shape(K)[-1], scores.dtype)
            scores /= tf.math.sqrt(d_k)
        if mask is not None:
            # mask holds 1.0 at blocked positions and 0.0 elsewhere.
            scores += mask * -1e9
        weights = tf.nn.softmax(scores, axis=-1)
        return tf.matmul(weights, V), weights
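Causal mask in TF: tf.linalg.band_part(tf.ones((T, T)), -1, 0) creates a lower-triangular matrix of ones. The layer above expects 1.0 at blocked positions, so pass mask = 1.0 - tf.linalg.band_part(tf.ones((T, T)), -1, 0); the mask * -1e9 term then adds -1e9 to every future position and 0 elsewhere.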
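A small check of the causal mask, assuming the convention used by the mask * -1e9 line in the attention layer (1.0 marks a blocked position). The all-zero scores are a stand-in so the resulting weights are easy to read.

```python
import tensorflow as tf

T = 4
lower = tf.linalg.band_part(tf.ones((T, T)), -1, 0)  # ones on and below the diagonal
blocked = 1.0 - lower                                # ones strictly above the diagonal

scores = tf.zeros((T, T))                            # placeholder attention scores
weights = tf.nn.softmax(scores + blocked * -1e9, axis=-1)

print(blocked.numpy())
print(weights.numpy())  # row i is uniform over positions 0..i and zero afterwards
```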
Saturation experiment: verify that calling the layer with scale=False causes the softmax entropy to drop as d_k increases, and replicate the PyTorch lab plot using TF operations.
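The saturation experiment could be set up roughly as below. The helper name mean_entropy, the choice of standard-normal Q and K, the sequence length of 32, and the d_k values are assumptions for illustration; plotting is left out.

```python
import tensorflow as tf

def mean_entropy(d_k, scale, T=32, seed=0):
    """Average entropy (in nats) of the softmax rows for random Q, K of width d_k."""
    g = tf.random.Generator.from_seed(seed)
    Q = g.normal((T, d_k))
    K = g.normal((T, d_k))
    scores = tf.matmul(Q, K, transpose_b=True)
    if scale:
        scores /= tf.math.sqrt(tf.cast(d_k, tf.float32))
    log_w = tf.nn.log_softmax(scores, axis=-1)  # numerically stable log of the weights
    entropy = -tf.reduce_sum(tf.exp(log_w) * log_w, axis=-1)
    return float(tf.reduce_mean(entropy))

dims = [4, 64, 1024]
unscaled = [mean_entropy(d, scale=False) for d in dims]
scaled = [mean_entropy(d, scale=True) for d in dims]

for d, u, s in zip(dims, unscaled, scaled):
    print(f"d_k={d:5d}  entropy unscaled={u:.3f}  scaled={s:.3f}")
```

Without scaling, the score variance grows with d_k, so the softmax concentrates on a few keys and the entropy falls toward zero; dividing by sqrt(d_k) keeps the variance near 1 and the entropy roughly constant.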
Build MultiHeadAttention as a Keras layer and compare numerically against tf.keras.layers.MultiHeadAttention on the same input tensors.
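A minimal sketch of a custom MultiHeadAttention layer follows, assuming self-attention only, four separate Dense projections, and a mask with 1.0 at blocked positions. These design choices are assumptions, and a numerical comparison against tf.keras.layers.MultiHeadAttention would additionally require copying weights between the two layers, which is not shown here.

```python
import tensorflow as tf

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.wo = tf.keras.layers.Dense(d_model)

    def _split(self, x):
        # (B, T, d_model) -> (B, n_heads, T, d_head)
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        x = tf.reshape(x, (B, T, self.n_heads, self.d_head))
        return tf.transpose(x, (0, 2, 1, 3))

    def call(self, x, causal_mask=None):
        q, k, v = self._split(self.wq(x)), self._split(self.wk(x)), self._split(self.wv(x))
        scores = tf.matmul(q, k, transpose_b=True)
        scores /= tf.math.sqrt(tf.cast(self.d_head, scores.dtype))
        if causal_mask is not None:
            # (T, T) mask with 1.0 at blocked positions; broadcasts over batch and heads.
            scores += causal_mask * -1e9
        weights = tf.nn.softmax(scores, axis=-1)
        ctx = tf.transpose(tf.matmul(weights, v), (0, 2, 1, 3))  # (B, T, n_heads, d_head)
        B, T = tf.shape(x)[0], tf.shape(x)[1]
        return self.wo(tf.reshape(ctx, (B, T, self.n_heads * self.d_head)))

T, d_model = 5, 16
mha = MultiHeadAttention(d_model, n_heads=4)
blocked = 1.0 - tf.linalg.band_part(tf.ones((T, T)), -1, 0)
x = tf.random.normal((2, T, d_model))
out = mha(x, blocked)
print(out.shape)
```

With the causal mask applied, the output at position 0 depends only on the input at position 0, which makes a convenient sanity check before comparing against the built-in layer.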
Section 6 — Mini Decoder-Only Transformer
Assemble the transformer using custom Keras layers:
class TransformerBlock(tf.keras.layers.Layer):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = tf.keras.layers.LayerNormalization(axis=-1)
        self.norm2 = tf.keras.layers.LayerNormalization(axis=-1)
        self.attn = MultiHeadAttention(d_model, n_heads)  # custom layer from Section 5
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation='gelu'),
            tf.keras.layers.Dense(d_model)
        ])
        self.drop = tf.keras.layers.Dropout(dropout)

    def call(self, x, causal_mask, training=False):
        # Pre-norm residual sub-blocks: x + Dropout(Sublayer(LayerNorm(x)))
        x = x + self.drop(self.attn(self.norm1(x), causal_mask), training=training)
        x = x + self.drop(self.ffn(self.norm2(x)), training=training)
        return x
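Key TF-specific pattern: in a custom GradientTape training loop you must call the model with training=True yourself, whereas Keras model.fit sets the flag automatically. Inside call, forward training= explicitly to every sub-layer whose behavior depends on it (Dropout here, BatchNormalization in Section 3) so the block behaves the same way regardless of how it is invoked.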
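The effect of the training= flag in a custom loop can be seen with a small stand-in model. The model, optimizer settings, and random data below are assumptions for illustration only; they are not the lab's transformer.

```python
import tensorflow as tf

# Hypothetical stand-in model with a Dropout layer.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1),
])
optimizer = tf.keras.optimizers.Adam(1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()

def train_step(x, y):
    with tf.GradientTape() as tape:
        pred = model(x, training=True)  # dropout is active only because of this flag
        loss = loss_fn(y, pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.random.normal((32, 8))
y = tf.random.normal((32, 1))
loss = train_step(x, y)

# Inference mode is deterministic: dropout is a no-op.
eval_a = model(x, training=False)
eval_b = model(x, training=False)

# Training mode draws a fresh dropout mask on every call.
train_a = model(x, training=True)
train_b = model(x, training=True)

print("eval difference: ", float(tf.reduce_max(tf.abs(eval_a - eval_b))))
print("train difference:", float(tf.reduce_max(tf.abs(train_a - train_b))))
```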
Train on a character-level text dataset for 5 epochs with context length 256. After training, extract attention weight tensors for each head and visualize them as heatmaps — compare head specialization patterns against the PyTorch Lab results to verify cross-framework parity.