Second-Order & Advanced Techniques
- Explain why L-BFGS requires a closure function and identify the model sizes and problem types where quasi-Newton methods are practical
- Implement layer-wise learning rate decay using param groups and explain why earlier layers receive smaller learning rates during fine-tuning
- Apply clip_grad_norm_ and clip_grad_value_ at the correct point in the training loop and explain the difference in their effect on gradient direction
- Identify when fused optimizer kernels improve throughput and explain why failing to save optimizer state causes momentum resets on training resumption
LBFGS — Limited-Memory BFGS
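BFGS is a quasi-Newton method that approximates the inverse Hessian using gradient differences across iterations. This enables Newton-like steps without ever computing the full Hessian: $\theta_{t+1} = \theta_t - \eta_t H_t \nabla L(\theta_t)$, where $H_t \approx [\nabla^2 L(\theta_t)]^{-1}$ and the step size $\eta_t$ is chosen by a line search.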
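For reference, the standard BFGS update builds the inverse-Hessian approximation $H_t$ from the parameter step $s_t = \theta_{t+1} - \theta_t$ and the gradient change $y_t = \nabla L(\theta_{t+1}) - \nabla L(\theta_t)$:

$$
H_{t+1} = \left(I - \rho_t\, s_t y_t^\top\right) H_t \left(I - \rho_t\, y_t s_t^\top\right) + \rho_t\, s_t s_t^\top,
\qquad \rho_t = \frac{1}{y_t^\top s_t}
$$

By construction the new matrix satisfies the secant condition $H_{t+1} y_t = s_t$, which means it maps the observed gradient change back to the step that produced it, just as the true inverse Hessian does on a quadratic.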
Full BFGS requires $O(n^2)$ memory (an $n \times n$ matrix for $n$ parameters). L-BFGS keeps only the last $m$ gradient/update vector pairs (default history_size=100) to approximate the inverse Hessian, reducing memory to $O(mn)$.
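Only the stored pairs are needed because L-BFGS never forms the matrix. The standard "two-loop recursion" applies the implicit inverse-Hessian approximation directly to a gradient. Below is a minimal NumPy sketch of that recursion; it is not PyTorch's internal code, and the name `lbfgs_direction` is illustrative. As a check, it uses a quadratic with Hessian diag(2, 5, 10) and three exact (s, y) pairs. In that case the recursion should recover the exact Newton direction $H^{-1} g$.

```python
import numpy as np

def lbfgs_direction(grad, s_list, y_list):
    """Approximate H^{-1} @ grad from stored (s, y) pairs, oldest first.

    Uses O(m * n) memory: the n x n matrix is never built.
    The optimizer step would then be -lbfgs_direction(...).
    """
    rhos = [1.0 / np.dot(y, s) for s, y in zip(s_list, y_list)]
    q = grad.copy()
    alphas = []
    # First loop: newest pair to oldest
    for s, y, rho in reversed(list(zip(s_list, y_list, rhos))):
        alpha = rho * np.dot(s, q)
        alphas.append(alpha)
        q = q - alpha * y
    # Initial guess H0 = gamma * I, scaled with the newest pair
    gamma = np.dot(s_list[-1], y_list[-1]) / np.dot(y_list[-1], y_list[-1])
    r = gamma * q
    # Second loop: oldest pair to newest (alphas were stored newest-first)
    for (s, y, rho), alpha in zip(zip(s_list, y_list, rhos), reversed(alphas)):
        beta = rho * np.dot(y, r)
        r = r + s * (alpha - beta)
    return r

# Quadratic with Hessian A: exact pairs satisfy y = A s
A = np.diag([2.0, 5.0, 10.0])
s_list = [np.eye(3)[i] for i in range(3)]
y_list = [A @ s for s in s_list]
d = lbfgs_direction(np.ones(3), s_list, y_list)
print(d)  # matches np.linalg.solve(A, np.ones(3))
```

With fewer pairs than parameters (the usual case, $m \ll n$), the result is only an approximation. The cost per step stays at a few dot products per stored pair.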
The Closure Pattern
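L-BFGS may evaluate the loss and gradient several times within a single optimizer.step(): once per inner iteration (up to max_iter) plus extra evaluations inside the line search. PyTorch therefore requires a closure, a callable that zeroes the gradients, recomputes the loss, calls backward(), and returns the loss, so the optimizer can re-evaluate the objective whenever it needs to: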
import torch
optimizer = torch.optim.LBFGS(
    model.parameters(), lr=1, max_iter=20,
    history_size=100, line_search_fn='strong_wolfe'
)
for x, y in dataloader:
    # step() may call the closure several times (inner iterations + line search)
    def closure():
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        return loss
    optimizer.step(closure)
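To see the closure being called repeatedly, here is a small self-contained sketch. It fits a noiseless line with full-batch L-BFGS, which is the deterministic setting the line search expects, and counts closure calls against step() calls. The toy data, the smaller history_size, and the `calls` counter are illustrative choices, not part of the original example.

```python
import torch

torch.manual_seed(0)
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 3.0 * x + 2.0  # noiseless target; the whole dataset is one batch
model = torch.nn.Linear(1, 1)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.LBFGS(model.parameters(), lr=1, max_iter=20,
                              history_size=10, line_search_fn='strong_wolfe')

calls = {'closure': 0, 'step': 0}

def closure():
    calls['closure'] += 1  # count every loss/gradient evaluation
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)
    calls['step'] += 1

final_loss = criterion(model(x), y).item()
print(calls, final_loss, model.weight.item(), model.bias.item())
```

The closure count exceeds the step count because each step runs several inner iterations and line-search evaluations. A plain `loss.backward(); optimizer.step()` loop cannot support that pattern.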
TensorFlow: No L-BFGS in Keras. TensorFlow Probability provides it:
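import tensorflow_probability as tfp
# value_and_gradients_function(x) must return (loss, gradient) for a 1-D float tensor x,
# e.g. lambda x: tfp.math.value_and_gradient(objective, x)  (objective is a placeholder)
result = tfp.optimizer.lbfgs_minimize(
    value_and_gradients_function,
    initial_position=initial_params,  # 1-D float tensor of flattened parameters
    max_iterations=50,  # num_correction_pairs (default 10) sets the history size
)
# result.position holds the solution; result.converged reports convergence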
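When to use: L-BFGS is practical for small to medium models (roughly under 1M parameters) with deterministic, full-batch objectives where you want fast convergence to a precise minimum, such as physics simulations, scientific computing, small regression problems, and neural ODE fitting. Its line search assumes the same function is evaluated each time, so noisy mini-batch gradients undermine it. Neither its memory nor its per-step cost scales to workloads like full ImageNet training.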
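A back-of-the-envelope memory calculation helps explain the rough parameter ceiling. The sketch below counts only the dominant storage, assuming float32 values. It compares the dense $n \times n$ matrix of full BFGS against `history_size` pairs of length-$n$ vectors for L-BFGS, and ignores the few extra vectors any real implementation also keeps. The helper names are hypothetical.

```python
def bfgs_matrix_bytes(n_params, bytes_per_value=4):
    # Dense n x n inverse-Hessian approximation
    return n_params * n_params * bytes_per_value

def lbfgs_history_bytes(n_params, history_size=100, bytes_per_value=4):
    # history_size pairs (s, y), each a vector of length n
    return 2 * history_size * n_params * bytes_per_value

for n in (10_000, 1_000_000, 100_000_000):
    print(f"n={n:>11,}  BFGS {bfgs_matrix_bytes(n) / 1e9:>14,.1f} GB"
          f"  L-BFGS {lbfgs_history_bytes(n) / 1e9:>8,.2f} GB")
```

At 1M parameters the default history already takes about 0.8 GB. Full BFGS would need about 4 TB for the same model.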
Param Groups Deep-Dive
Param groups let you apply different hyperparameters to different parts of the model. Every optimizer has a param_groups list you can manipulate at runtime:
optimizer = torch.optim.AdamW([
    {'params': model.backbone.parameters(), 'lr': 1e-4, 'weight_decay': 0.01},
    {'params': model.head.parameters(), 'lr': 1e-3, 'weight_decay': 0.0},
], betas=(0.9, 0.999))
# Modify LR at runtime (e.g., after an epoch)
for group in optimizer.param_groups:
    group['lr'] *= 0.1
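Here is a self-contained version showing how per-group settings override the optimizer-wide defaults. `TwoPart` is a hypothetical model that exposes `backbone` and `head` attributes. Keys a group sets explicitly (`lr`, `weight_decay`) win, and keys it omits (`betas`, `eps`, ...) are filled in from the constructor's defaults.

```python
import torch
from torch import nn

class TwoPart(nn.Module):  # hypothetical model with a backbone and a task head
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.head = nn.Linear(16, 2)

    def forward(self, x):
        return self.head(self.backbone(x))

model = TwoPart()
optimizer = torch.optim.AdamW([
    {'params': model.backbone.parameters(), 'lr': 1e-4, 'weight_decay': 0.01},
    {'params': model.head.parameters(), 'lr': 1e-3, 'weight_decay': 0.0},
], betas=(0.9, 0.999))

# Per-group values for lr/weight_decay; betas comes from the shared default
for group in optimizer.param_groups:
    print(group['lr'], group['weight_decay'], group['betas'])

# Decay every group by 10x, e.g. at the end of an epoch
for group in optimizer.param_groups:
    group['lr'] *= 0.1
print([g['lr'] for g in optimizer.param_groups])
```

If an LR scheduler is also attached, it manages `group['lr']` itself, and some schedulers recompute it from stored base values. Manual edits like the one above may then be overwritten, so it is safer to use one mechanism or the other.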
Layer-Wise Learning Rate Decay (LLRD)
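LLRD multiplies each layer's LR by a decay factor for every step you move from the output toward the input (decay^k for a layer k steps from the output). The head therefore trains at the full base LR and the earliest layers train slowest. This is common practice when fine-tuning pretrained models. Early layers encode general features (edges and textures in vision, token- and syntax-level patterns in language models) that pretraining already learned well and that large updates can overwrite. Later layers are more task-specific and need the most adaptation: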
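import torch
def get_llrd_groups(model, base_lr=1e-3, decay=0.9):
    # Assumes top-level children are registered in input -> output order;
    # parameters owned directly by `model` (not by a child) are not included.
    # Skip parameter-free modules (ReLU, Dropout) so they don't consume a decay step.
    layers = [m for m in model.children()
              if any(p.requires_grad for p in m.parameters())]
    groups = []
    # i = 0 is the output layer (full base_lr); the LR shrinks toward the input
    for i, layer in enumerate(reversed(layers)):
        lr = base_lr * (decay ** i)
        params = [p for p in layer.parameters() if p.requires_grad]
        groups.append({'params': params, 'lr': lr})
    return groups
optimizer = torch.optim.AdamW(get_llrd_groups(model))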
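`model.children()` only sees top-level modules. For a transformer whose blocks live inside an `nn.ModuleList`, the whole encoder would therefore share a single learning rate. The minimal sketch below assumes a hypothetical `TinyEncoder` with `embed`, `blocks`, and `head` attributes (real pretrained models name these differently) and assigns one depth per block. The head keeps `base_lr`, and each step toward the embeddings multiplies the LR by `decay`.

```python
import torch
from torch import nn

class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a pretrained transformer (forward omitted)."""
    def __init__(self, vocab=100, dim=16, n_blocks=3, n_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_blocks))
        self.head = nn.Linear(dim, n_classes)

def llrd_param_groups(model, base_lr=1e-3, decay=0.9):
    # Output-to-input order: head, last block, ..., first block, embeddings
    stages = [model.head, *reversed(model.blocks), model.embed]
    return [{'params': list(stage.parameters()), 'lr': base_lr * decay ** depth}
            for depth, stage in enumerate(stages)]

model = TinyEncoder()
optimizer = torch.optim.AdamW(llrd_param_groups(model), weight_decay=0.01)
for group in optimizer.param_groups:
    print(f"{group['lr']:.3e}")
```

With three blocks this yields five groups, running from 1e-3 for the head down to about 6.6e-4 for the embeddings. Every parameter appears in exactly one group.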
TensorFlow: Keras doesn't support per-variable LRs natively. Use a custom training loop with separate optimizers per layer group:
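import tensorflow as tf
def get_llrd_optimizers(model, base_lr=1e-3, decay=0.9):
    # Skip layers without trainable variables (InputLayer, Dropout, ...)
    layers = [layer for layer in model.layers if layer.trainable_variables]
    optimizers = []
    for i, layer in enumerate(reversed(layers)):
        lr = base_lr * (decay ** i)
        opt = tf.keras.optimizers.AdamW(learning_rate=lr, weight_decay=0.01)
        optimizers.append((opt, layer.trainable_variables))
    return optimizers
# Create the optimizers ONCE, outside the loop; rebuilding them every step
# would discard Adam's moment estimates.
llrd_optimizers = get_llrd_optimizers(model)
# In training loop:
with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x, training=True))  # Keras losses take (y_true, y_pred)
# One gradient call; the result mirrors the nested list of variable groups
grads = tape.gradient(loss, [vars_ for _, vars_ in llrd_optimizers])
for (opt, vars_), layer_grads in zip(llrd_optimizers, grads):
    opt.apply_gradients(zip(layer_grads, vars_))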
Gradient Clipping
Gradient clipping is applied between loss.backward() and optimizer.step().
clip_grad_norm_
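Computes the global L2 norm of all gradients (as if they were concatenated into one vector). Only when that norm exceeds max_norm does it rescale every gradient by the same factor, max_norm / total_norm, so the global norm becomes max_norm. Gradients already within the limit pass through unchanged: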
loss.backward()
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
TensorFlow:
grads = tape.gradient(loss, model.trainable_variables)
grads, total_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
# Or via the constructor: tf.keras.optimizers.Adam(global_clipnorm=1.0)
# (clipnorm=1.0 instead clips each variable's gradient norm separately)
The return value total_norm is the norm before clipping, which makes it useful for monitoring gradient health: a sudden spike often signals training instability.
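The rescaling rule is simple enough to write out by hand. The sketch below uses a hypothetical helper, `clip_by_global_norm`. It applies a single factor min(1, max_norm / (total_norm + eps)) to every gradient, where the small eps guards against division by zero, and checks the result against `clip_grad_norm_` on random, deliberately large gradients.

```python
import torch

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Sketch of global-norm clipping; returns (clipped grads, pre-clip norm)."""
    # Global L2 norm == L2 norm of the per-tensor L2 norms
    total = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g) for g in grads]))
    # The factor is capped at 1, so small gradients are never scaled up
    coef = torch.clamp(max_norm / (total + eps), max=1.0)
    return [g * coef for g in grads], total

torch.manual_seed(0)
params = [torch.nn.Parameter(torch.randn(3, 4)), torch.nn.Parameter(torch.randn(4))]
for p in params:
    p.grad = 10 * torch.randn_like(p)  # deliberately large gradients

manual, manual_norm = clip_by_global_norm([p.grad for p in params], max_norm=1.0)
torch_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)  # clips p.grad in place
print(manual_norm.item(), torch_norm.item())
print(all(torch.allclose(m, p.grad) for m, p in zip(manual, params)))
```

Because every tensor is multiplied by the same scalar, the concatenated gradient vector keeps its direction and only its length changes.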
clip_grad_value_
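Clips each gradient element independently to [-clip_value, clip_value]. This is simpler, but it changes the gradient direction: large components are truncated while small ones pass through untouched. Norm clipping instead scales every component by the same factor and therefore preserves direction: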
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
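A two-component example makes the direction change concrete. The sketch below applies both clipping functions to the same gradient [10, 0.1] and compares each result to the original using cosine similarity. The helper `clipped` is hypothetical; it just wraps a gradient in a throwaway parameter so the in-place PyTorch utilities can act on it.

```python
import torch

def clipped(clip_fn, grad, **kwargs):
    # Attach the gradient to a throwaway parameter and clip it in place
    p = torch.nn.Parameter(torch.zeros_like(grad))
    p.grad = grad.clone()
    clip_fn([p], **kwargs)
    return p.grad

g = torch.tensor([10.0, 0.1])
by_value = clipped(torch.nn.utils.clip_grad_value_, g, clip_value=0.5)
by_norm = clipped(torch.nn.utils.clip_grad_norm_, g, max_norm=0.5)

cos = torch.nn.functional.cosine_similarity
print(by_value, cos(g, by_value, dim=0).item())  # about 0.98: direction changed
print(by_norm, cos(g, by_norm, dim=0).item())    # about 1.0: same direction, shorter
```

Value clipping turns [10, 0.1] into [0.5, 0.1], so the small component now has a much larger relative weight. Norm clipping shrinks both components by the same factor.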
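Recommendation: Use clip_grad_norm_ (max_norm=1.0 is a common choice) for transformers and RNNs. Hugging Face Trainer clips to a global norm of 1.0 by default (max_grad_norm=1.0), and PyTorch Lightning enables norm clipping through Trainer(gradient_clip_val=...).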
Fused Optimizers
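PyTorch 2.x supports fused=True for Adam and AdamW, and SGD gained the option in a later 2.x release. RMSprop does not accept a fused argument. The fused implementation targets CUDA (newer releases extend it to other devices). It merges the element-wise update operations for all parameters into a single kernel, reducing kernel launch overhead and memory traffic: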
optimizer = torch.optim.AdamW(
model.parameters(), lr=1e-3, fused=True # requires CUDA
)
TensorFlow: There is no fused=True flag. Instead, compiling the whole training step with XLA via @tf.function(jit_compile=True) lets the compiler fuse the optimizer's element-wise updates along with the rest of the step:
@tf.function(jit_compile=True)
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))  # Keras losses take (y_true, y_pred)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
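Benchmarks often report 10–30% wall-clock speedups on large models, though the gain depends on the model and hardware. When fused is unavailable, foreach=True is the alternative: it vectorizes the update loop by applying each operation to all parameter tensors at once, without fusing the whole update into one kernel. PyTorch already uses it by default for CUDA tensors when fused is not set. It also runs on CPU, though the batching mainly pays off on GPU.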
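One way to apply this is a small factory that requests the fused kernel only when every trainable parameter is a floating-point CUDA tensor, and otherwise falls back to the foreach implementation. The sketch below uses a hypothetical helper name, `make_adamw`. It never passes both flags, because PyTorch rejects fused=True together with foreach=True.

```python
import torch
from torch import nn

def make_adamw(model, lr=1e-3):
    """Pick a fused or foreach AdamW depending on where the parameters live (sketch)."""
    params = [p for p in model.parameters() if p.requires_grad]
    if params and all(p.is_cuda and p.is_floating_point() for p in params):
        impl = {'fused': True}    # fused CUDA kernel for the whole update
    else:
        impl = {'foreach': True}  # batched multi-tensor update
    # Never pass fused=True and foreach=True together: PyTorch raises an error
    return torch.optim.AdamW(params, lr=lr, **impl)

model = nn.Linear(4, 2)
optimizer = make_adamw(model)
print({k: optimizer.defaults.get(k) for k in ('fused', 'foreach')})

# One ordinary step to show the optimizer works either way
before = model.weight.detach().clone()
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
print(torch.equal(before, model.weight))  # False: the weights were updated
```

Because the choice is made once at construction time, the training loop itself does not change between the fused and foreach paths.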
Optimizer State and Checkpointing
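Forgetting to save the optimizer state is a common checkpointing mistake. Adam's state holds running averages of past gradients (exp_avg) and squared gradients (exp_avg_sq), plus a step counter used for bias correction. Restoring only the model weights restarts all of these from zero. The result is a momentum reset that can destabilize resumed training or fine-tuning: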
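import torch
# Save
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),  # store a plain float, not a graph-attached tensor
}, 'checkpoint.pt')
# Load (map_location lets a GPU checkpoint load on any machine;
# load_state_dict moves the state onto the parameters' device)
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# If you use an LR scheduler, save and restore scheduler.state_dict() the same way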
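The difference between the two ways of resuming is easy to inspect directly. The sketch below trains a toy model for 10 Adam steps and checkpoints it to an in-memory buffer, so no file is written. It then resumes twice, once with the optimizer state and once without. The model, data, and `train` helper are illustrative placeholders.

```python
import io
import torch
from torch import nn

def train(model, opt, steps, x, y):
    for _ in range(steps):
        opt.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        opt.step()

torch.manual_seed(0)
x, y = torch.randn(16, 4), torch.randn(16, 1)
model = nn.Linear(4, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-2)
train(model, opt, 10, x, y)

# Checkpoint to an in-memory buffer instead of a file
buf = io.BytesIO()
torch.save({'model': model.state_dict(), 'optim': opt.state_dict()}, buf)
buf.seek(0)
ckpt = torch.load(buf)

# Resume WITH optimizer state: moments and step counter carry over
resumed = nn.Linear(4, 1)
resumed.load_state_dict(ckpt['model'])
resumed_opt = torch.optim.Adam(resumed.parameters(), lr=1e-2)
resumed_opt.load_state_dict(ckpt['optim'])

# Resume WITHOUT it: same weights, but Adam starts from empty state
reset = nn.Linear(4, 1)
reset.load_state_dict(ckpt['model'])
reset_opt = torch.optim.Adam(reset.parameters(), lr=1e-2)

print(resumed_opt.state[resumed.weight]['step'])  # counter reflects the 10 earlier steps
print(len(reset_opt.state))                       # no moments, counter restarts at 0
```

In the reset case, the first update after resuming is computed from a single gradient with fresh bias correction. That is the momentum reset described above.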