Optimizers in TensorFlow
Lab Objectives
1. Map each PyTorch optimizer to its TensorFlow/Keras equivalent
2. Subclass tf.keras.optimizers.Optimizer to implement RAdam with moment slots
3. Apply learning rate schedules using schedule objects, callbacks, and manual stepping
4. Implement linear warmup + cosine decay without a built-in combined scheduler
5. Verify numerical parity between PyTorch and TensorFlow Adam implementations
6. Compare optimizer convergence on Fashion-MNIST and interpret the results
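Objective 4 comes down to composing two pieces by hand: a linear ramp up to the base rate, followed by a half-cosine decay. The framework-free sketch below shows only the arithmetic. The function name `warmup_cosine_lr`, its parameters, and the step counts are illustrative assumptions, not values taken from the lab notebook.

```python
import math


def warmup_cosine_lr(step, base_lr=1e-3, warmup_steps=500, total_steps=5000, min_lr=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay to `min_lr`.

    All names and default values here are hypothetical choices for illustration.
    """
    if step < warmup_steps:
        # Linear ramp from base_lr / warmup_steps up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (base_lr - min_lr) * cosine


# Sample the schedule at a few steps to check its shape.
for s in (0, 249, 499, 500, 2750, 5000):
    print(s, round(warmup_cosine_lr(s), 6))
```

A function like this could plausibly be wrapped for the LearningRateScheduler callback, which calls its schedule with the epoch index and the current rate, so the counts would then be in epochs rather than batches. Per-batch granularity would more likely call for a tf.keras.optimizers.schedules.LearningRateSchedule subclass that performs the same arithmetic with tensor ops.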
Lab Overview
This lab maps PyTorch optimizers to their TensorFlow/Keras equivalents, implements a custom RAdam optimizer as a tf.keras.optimizers.Optimizer subclass, and applies learning rate schedules through schedule objects and callbacks. It then validates numerical parity with the PyTorch lab.
What You'll Build
- TF/Keras optimizer tour: map every PyTorch optimizer to its tf.keras.optimizers equivalent; note API differences (e.g., weight_decay is a constructor arg in Keras, not a gradient modification)
- Custom RAdam optimizer: subclass tf.keras.optimizers.Optimizer to implement RAdam from scratch using tf.Variable slots for moment buffers
- Learning rate scheduling in TF: compare three approaches:
  - tf.keras.optimizers.schedules.* (schedule objects passed to the optimizer)
  - tf.keras.callbacks.ReduceLROnPlateau (metric-driven)
  - Custom LearningRateScheduler callback with a Python function
- Cosine decay with warmup: use tf.keras.optimizers.schedules.CosineDecayRestarts and compare it with a manual linear warmup + cosine decay schedule
- Parity check vs. PyTorch: run identical synthetic regression tasks with Adam (lr=1e-3, betas=(0.9, 0.999)) in both frameworks; verify final loss values match within 1e-4
- Fashion-MNIST optimizer comparison: train the same ConvNet with SGD+momentum, Adam, AdamW, and RMSprop; plot training curves and report test accuracy
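For the parity check, it helps to have a framework-free reference for what an Adam step computes. The sketch below applies the textbook Adam update to a single scalar with the lr and betas listed above. The toy loss, the helper names `adam_minimize` and `grad`, and the step count are assumptions made for illustration. It also varies epsilon, because the default differs between the two frameworks (1e-8 in torch.optim.Adam, 1e-7 in tf.keras.optimizers.Adam) and is worth setting explicitly when comparing runs.

```python
import math


def adam_minimize(grad_fn, w0, steps, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Run `steps` textbook Adam updates on a single scalar parameter."""
    w, m, v = w0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1.0 - beta1) * g        # first-moment estimate
        v = beta2 * v + (1.0 - beta2) * g * g    # second-moment estimate
        m_hat = m / (1.0 - beta1 ** t)           # bias correction
        v_hat = v / (1.0 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w


def grad(w):
    # Gradient of the hypothetical toy loss (w - 3)^2.
    return 2.0 * (w - 3.0)


w_eps_torch = adam_minimize(grad, w0=0.0, steps=200, eps=1e-8)  # PyTorch default eps
w_eps_keras = adam_minimize(grad, w0=0.0, steps=200, eps=1e-7)  # Keras default epsilon
gap = abs(w_eps_torch - w_eps_keras)
print(f"eps=1e-8: {w_eps_torch:.8f}  eps=1e-7: {w_eps_keras:.8f}  gap: {gap:.2e}")
```

On a well-scaled problem like this one the epsilon gap is far below the 1e-4 tolerance, but it can grow when gradients are very small. This is a sketch of the update rule only; it does not reproduce either framework's exact internal formulation.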
API Differences to Know
| Feature | PyTorch | TensorFlow/Keras |
|---|---|---|
| Weight decay | weight_decay arg (AdamW) or L2 in Adam | weight_decay constructor arg (TF 2.12+) or kernel_regularizer |
| Gradient clipping | clip_grad_norm_() before step | clipnorm / clipvalue in optimizer constructor |
| LR scheduling | lr_scheduler.* wraps optimizer | Schedule object passed as learning_rate arg |
| Step call | optimizer.step() | optimizer.apply_gradients() or model.fit() |
| State dict | optimizer.state_dict() | optimizer.get_weights() / set_weights() |
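One subtlety behind the gradient clipping row: as far as the Keras documentation describes it, clipnorm clips each weight's gradient individually, while clip_grad_norm_() in PyTorch computes a single norm over all gradients together. The Keras global_clipnorm argument is the closer match. The framework-free sketch below contrasts the two behaviors on flat lists. The helper names are hypothetical, and the code ignores the small stabilizing constants real implementations may add.

```python
import math


def l2_norm(values):
    return math.sqrt(sum(x * x for x in values))


def clip_per_tensor(grads, max_norm):
    """Clip each gradient tensor to `max_norm` independently."""
    clipped = []
    for g in grads:
        norm = l2_norm(g)
        scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
        clipped.append([x * scale for x in g])
    return clipped


def clip_global(grads, max_norm):
    """Scale all gradient tensors by one factor based on their joint norm."""
    total = l2_norm([x for g in grads for x in g])
    scale = min(1.0, max_norm / total) if total > 0 else 1.0
    return [[x * scale for x in g] for g in grads]


# Two toy "gradient tensors" represented as flat lists, with norms of about 5 and 1.
grads = [[3.0, 4.0], [0.6, 0.8]]
per_tensor = clip_per_tensor(grads, max_norm=1.0)
joint = clip_global(grads, max_norm=1.0)
print("per-tensor:", per_tensor)
print("global:    ", joint)
```

Per-tensor clipping changes the relative magnitudes between layers, since only the large gradient is shrunk. Global clipping preserves them, because every tensor is scaled by the same factor. That difference alone can make otherwise identical PyTorch and Keras runs diverge.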
Prerequisites
- Familiarity with TensorFlow GradientTape and model.fit()
- Completion of the PyTorch Optimizers lab is helpful but not required