Supplement · Optimizers

Optimizers in TensorFlow

Google Colab Notebook · Python · ~40 min

Lab Objectives

  1. Map each PyTorch optimizer to its TensorFlow/Keras equivalent
  2. Subclass tf.keras.optimizers.Optimizer to implement RAdam with moment slots
  3. Apply learning rate schedules using schedule objects, callbacks, and manual per-step updates
  4. Implement linear warmup + cosine decay without a built-in combined scheduler
  5. Verify numerical parity between PyTorch and TensorFlow Adam implementations
  6. Compare optimizer convergence on Fashion-MNIST and interpret the results
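
As a framework-independent reference for objective 2, the RAdam update for a single scalar parameter can be sketched in plain Python. This is a minimal sketch, not the notebook's implementation: the function name radam_step is hypothetical, the threshold of 4 follows the original RAdam paper (some library implementations use 5), and adding eps to the denominator is an assumption borrowed from common Adam implementations.

```python
import math

def radam_step(param, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One RAdam update for a scalar parameter; t is the 1-based step count."""
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_hat = m / (1.0 - beta1 ** t)

    # Length of the approximated simple moving average (maximum and current).
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    rho_t = rho_inf - 2.0 * t * beta2 ** t / (1.0 - beta2 ** t)

    if rho_t > 4.0:  # assumption: threshold from the paper
        # Variance estimate is tractable: take a rectified adaptive step.
        v_hat_sqrt = math.sqrt(v / (1.0 - beta2 ** t))
        r_t = math.sqrt(((rho_t - 4.0) * (rho_t - 2.0) * rho_inf)
                        / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
        param -= lr * r_t * m_hat / (v_hat_sqrt + eps)
    else:
        # Early steps: fall back to a bias-corrected momentum step.
        param -= lr * m_hat
    return param, m, v

# First step from zero moments reduces to a plain gradient step.
first, _, _ = radam_step(1.0, 2.0, 0.0, 0.0, t=1, lr=0.1)

# Minimise f(x) = x**2 for a few steps.
x, m, v = 1.0, 0.0, 0.0
for t in range(1, 51):
    x, m, v = radam_step(x, 2.0 * x, m, v, t, lr=0.01)
```

In a Keras subclass, m and v would live in per-variable optimizer slots rather than being passed around. A scalar reference like this one can be used to check the subclass's output on a one-parameter model.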

Lab Overview

This lab maps PyTorch optimizers to their TensorFlow/Keras equivalents, implements a custom RAdam optimizer subclass, and applies learning rate schedules and callbacks. It then validates numerical parity with the PyTorch lab.

What You'll Build

  1. TF/Keras optimizer tour: map every PyTorch optimizer to its tf.keras.optimizers equivalent and note API differences (e.g., the Keras weight_decay constructor argument applies decay to the variables directly rather than modifying the gradient)
  2. Custom RAdam optimizer: subclass tf.keras.optimizers.Optimizer to implement RAdam from scratch using tf.Variable slots for moment buffers
  3. Learning rate scheduling in TF: a comparison of three approaches
    • tf.keras.optimizers.schedules.* (schedule objects passed to optimizer)
    • tf.keras.callbacks.ReduceLROnPlateau (metric-driven)
    • Custom LearningRateScheduler callback with Python function
  4. Cosine decay with warmup: use tf.keras.optimizers.schedules.CosineDecayRestarts as a baseline and compare it with a hand-written linear warmup + cosine decay schedule
  5. Parity check vs. PyTorch: run identical synthetic regression tasks with Adam (lr=1e-3, betas=(0.9, 0.999)) in both frameworks, starting from the same initial weights; verify final loss values match within 1e-4. Set epsilon explicitly in both, since the defaults differ (1e-8 in PyTorch, 1e-7 in Keras)
  6. Fashion-MNIST optimizer comparison: train the same ConvNet with SGD+momentum, Adam, AdamW, and RMSprop; plot training curves and report test accuracy
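
The manual schedule in item 4 can be sketched as a plain Python function of the step index. The name warmup_cosine_lr and its parameters (peak_lr, warmup_steps, total_steps, min_lr) are illustrative assumptions rather than names from the notebook, and the default values are arbitrary.

```python
import math

def warmup_cosine_lr(step, peak_lr=1e-3, warmup_steps=500,
                     total_steps=10_000, min_lr=0.0):
    """Learning rate at a 0-based step: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # Linear ramp that reaches peak_lr on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    # Cosine factor falls smoothly from 1 to 0 over the decay phase.
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (peak_lr - min_lr) * cosine

# Sample the schedule at a few points for plotting or inspection.
samples = [warmup_cosine_lr(s) for s in (0, 499, 500, 5250, 10_000)]
```

A function like this could be called from a tf.keras.callbacks.LearningRateScheduler callback, which invokes its function once per epoch, so the index passed in would be an epoch number rather than a batch step. For per-batch updates, the same arithmetic would need to be ported to TensorFlow ops inside the __call__(step) method of a tf.keras.optimizers.schedules.LearningRateSchedule subclass.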

API Differences to Know

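Feature           | PyTorch                                                        | TensorFlow/Keras
------------------|----------------------------------------------------------------|-----------------------------------------------------------------
Weight decay      | weight_decay arg (decoupled in AdamW, L2 penalty in Adam)      | weight_decay constructor arg (TF 2.12+) or kernel_regularizer
Gradient clipping | clip_grad_norm_() before step                                  | clipnorm / clipvalue in optimizer constructor
LR scheduling     | lr_scheduler.* wraps optimizer                                 | Schedule object passed as learning_rate arg
Step call         | optimizer.step()                                               | optimizer.apply_gradients() or model.fit()
State dict        | optimizer.state_dict()                                         | optimizer.variables, saved via tf.train.Checkpoint (get_weights() / set_weights() on legacy optimizers)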
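
The weight decay row hides a behavioral difference that matters for parity checks. The sketch below contrasts a coupled L2 penalty with decoupled weight decay in a scalar Adam step. It is a plain Python illustration under stated assumptions: adam_step and its l2 / decoupled_wd parameters are hypothetical names, and the mapping to specific framework arguments should be confirmed against each library's documentation.

```python
import math

def adam_step(param, grad, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8,
              l2=0.0, decoupled_wd=0.0):
    """One scalar Adam step with optional coupled L2 or decoupled decay."""
    # Coupled L2: the penalty gradient enters the moment estimates,
    # so Adam's normalisation rescales it like any other gradient.
    grad = grad + l2 * param
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    # Decoupled decay: shrink the parameter directly, bypassing the moments.
    param = param * (1.0 - lr * decoupled_wd)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# Zero task gradient, so any movement comes from the regulariser alone.
coupled, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, l2=0.01)
decoupled, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, decoupled_wd=0.01)
print(f"coupled L2: {coupled:.6f}   decoupled decay: {decoupled:.6f}")
```

With these inputs the coupled variant moves the parameter by roughly the full learning rate, because Adam normalises the small penalty gradient, while the decoupled variant shrinks it by only lr * decoupled_wd. Two runs that set "the same" weight decay value can therefore diverge if one framework's setting is coupled and the other's is decoupled.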

Prerequisites

  • Familiarity with TensorFlow GradientTape and model.fit()
  • Completion of the PyTorch Optimizers lab is helpful but not required