The Adam Family
- Derive Adam's bias correction terms and explain why dividing by (1 − β^t) is necessary when moment buffers are zero-initialized
- Distinguish AdamW from Adam by tracing how decoupled weight decay changes the update rule and restores uniform L2 regularization semantics
- Compare RAdam's rectification term to standard Adam and state the conditions under which RAdam falls back to SGD-like updates during early training
- Select among Adam, AdamW, Adamax, NAdam, RAdam, and SparseAdam based on gradient sparsity, weight decay requirements, and training stability needs
Adam
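Adam (Kingma & Ba 2014) combines momentum (first moment) with RMSprop's adaptive learning rate (second moment). Critically, it also adds bias correction to compensate for zero-initialization. With gradient $g_t$, learning rate $\eta$, and $m_0 = v_0 = 0$:
$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2$$
$$\hat m_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t}, \qquad \theta_t = \theta_{t-1} - \eta\, \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}$$
With $\beta_1 = 0.9$ and $\beta_2 = 0.999$, the second-moment bias-correction denominator is $1 - \beta_2^t$. At $t = 1$ this is tiny ($1 - 0.999 = 0.001$). Because $v_1 = 0.001\, g_1^2$ is just as tiny, dividing by it rescales $\hat v_1$ back to $g_1^2$ (and likewise $\hat m_1 = g_1$).

Without the correction, $\sqrt{v_t}$ would underestimate the gradient scale far more than $m_t$ does, which inflates $m_t / \sqrt{v_t}$. The correction prevents these oversized updates at the start of training, when the zero-initialized second moment estimate is near zero.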
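The bias-correction argument can be checked numerically on a single scalar parameter. In this minimal sketch, `adam_step` and `first_step_ratio` are illustrative helpers, not library functions. Each runs one Adam update from zero-initialized buffers, with and without correction. With correction, the first step has size about $\eta$. Without it, the step is inflated by $(1-\beta_1)/\sqrt{1-\beta_2} = \sqrt{10} \approx 3.16$.

```python
import math

def adam_step(theta, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
              bias_correction=True):
    """One Adam update on a scalar parameter; returns the new (theta, m, v)."""
    m = beta1 * m + (1 - beta1) * g           # first moment (momentum)
    v = beta2 * v + (1 - beta2) * g * g       # second moment (uncentered)
    if bias_correction:
        m_hat = m / (1 - beta1 ** t)          # undo the pull toward the zero init
        v_hat = v / (1 - beta2 ** t)
    else:
        m_hat, v_hat = m, v
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

def first_step_ratio(g, bias_correction):
    """Size of the very first update from theta = 0, in units of lr (1e-3)."""
    theta, _, _ = adam_step(0.0, g, 0.0, 0.0, t=1, bias_correction=bias_correction)
    return -theta / 1e-3

print(first_step_ratio(0.5, True), first_step_ratio(0.5, False))
```

The ratio does not depend on the gradient's magnitude, only on $\beta_1$ and $\beta_2$. Any nonzero `g` gives the same two numbers, up to the tiny effect of `eps`.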
optimizer = torch.optim.Adam(
model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0
)
TensorFlow:
optimizer = tf.keras.optimizers.Adam(
learning_rate=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-8
)
# AMSGrad: tf.keras.optimizers.Adam(..., amsgrad=True); PyTorch: torch.optim.Adam(..., amsgrad=True)
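AMSGrad (amsgrad=True) keeps the running maximum $v_t^{\max} = \max(v_{t-1}^{\max}, v_t)$ of all past second-moment estimates. It uses this maximum in the denominator instead of the current estimate $v_t$. Because that denominator can never shrink, each parameter's effective learning rate is non-increasing (ignoring bias correction and learning-rate schedules). This improves convergence guarantees at the cost of slower adaptation once gradients become smaller.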
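The sketch below tracks only the denominator $\sqrt{v_t}$, without bias correction. It uses a hypothetical gradient sequence with one large spike followed by small gradients. `denominators` is an illustrative helper, not a library function.

```python
import math

def denominators(grads, beta2=0.999, amsgrad=False):
    """Per-step denominators sqrt(v) (no bias correction, no eps) for a gradient sequence."""
    v, v_max, out = 0.0, 0.0, []
    for g in grads:
        v = beta2 * v + (1 - beta2) * g * g   # Adam's second moment
        v_max = max(v_max, v)                 # AMSGrad's running maximum
        out.append(math.sqrt(v_max if amsgrad else v))
    return out

grads = [10.0] + [0.1] * 2000                 # one spike, then small gradients
adam = denominators(grads)
ams = denominators(grads, amsgrad=True)
print(adam[0], adam[-1], ams[-1])
```

After the spike, plain Adam's denominator decays back toward the small-gradient scale, so its steps grow again. AMSGrad's denominator stays pinned at the value set by the spike.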
AdamW — Decoupled Weight Decay
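In Adam, weight_decay adds $\lambda\theta_{t-1}$ to the gradient before the moment updates. The decay term therefore passes through the adaptive denominator and is effectively scaled by $1/(\sqrt{\hat v_t} + \epsilon)$. Parameters with large historical gradient magnitudes (large $\hat v_t$) receive less regularization, breaking the intended uniform shrinkage.

AdamW (Loshchilov & Hutter 2017) instead applies weight decay directly to the parameters, outside the adaptive update and independently of the gradient. It is written here in the form PyTorch implements, where the decay is multiplied by the learning rate:
$$\theta_t = \theta_{t-1} - \eta\left(\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\,\theta_{t-1}\right)$$
This restores the semantics that L2 regularization has under plain SGD: every parameter is shrunk by the same fraction $\eta\lambda$ per step, regardless of its gradient history. AdamW is the common default for transformers, diffusion models, and most modern architectures.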
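A small simulation makes the difference concrete. The setup is hypothetical: a scalar parameter receives a zero-mean loss gradient that alternates in sign with magnitude `grad_scale`, so any net shrinkage comes from weight decay. The sketch measures how much further $\theta$ moves with `weight_decay=0.01` than without it, for a small and a large gradient scale. The helpers `run` and `decay_effect` are illustrative. The AdamW branch follows the PyTorch-style multiplicative decay shown above.

```python
import math

def run(grad_scale, decoupled, weight_decay, steps=200, lr=1e-3,
        beta1=0.9, beta2=0.999, eps=1e-8, theta=1.0):
    """Optimize one scalar with Adam (L2 added to the gradient) or AdamW (decoupled decay)."""
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad_scale * (-1) ** t            # hypothetical zero-mean loss gradient
        if decoupled:
            theta *= 1 - lr * weight_decay    # AdamW: shrink outside the adaptive step
        else:
            g += weight_decay * theta         # Adam: L2 term enters m and v
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta

def decay_effect(grad_scale, decoupled, weight_decay=0.01):
    """Extra distance theta moved because of weight decay, versus an identical run without it."""
    return run(grad_scale, decoupled, 0.0) - run(grad_scale, decoupled, weight_decay)

for decoupled in (False, True):
    name = "AdamW" if decoupled else "Adam+L2"
    print(name, [round(decay_effect(s, decoupled), 5) for s in (0.01, 1.0)])
```

With Adam plus L2, the small-gradient parameter is shrunk far more than the large-gradient one. With AdamW, the decay effect is essentially identical for both, because the adaptive step no longer touches the decay term.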
optimizer = torch.optim.AdamW(
model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0.01 # typical range: 0.01 – 0.1
)
TensorFlow:
optimizer = tf.keras.optimizers.AdamW(
learning_rate=1e-3, weight_decay=0.01,
beta_1=0.9, beta_2=0.999, epsilon=1e-8
)
Adamax
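Adamax is a variant of Adam that replaces the L2 norm in the second moment with the L∞ (max) norm:
$$u_t = \max\left(\beta_2\, u_{t-1},\ |g_t|\right), \qquad \theta_t = \theta_{t-1} - \frac{\eta}{1-\beta_1^t}\, \frac{m_t}{u_t}$$
Because of the max with $|g_t|$, $u_t$ is never smaller than the current gradient magnitude (at $t = 1$, $u_1 = |g_1|$ exactly). Zero initialization therefore does not bias $u_t$ toward zero, so Adamax needs no bias correction on $u_t$; only $m_t$ is corrected.

Adamax is also often described as more robust to large gradient spikes. A single large gradient raises $u_t$ immediately, and $u_t$ then decays only by a factor $\beta_2$ per step, which damps the steps that follow the spike. This can be useful for embedding layers with extreme gradient variance.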
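The sketch below runs Adamax on a scalar for a hypothetical gradient sequence with one spike, recording $u_t$ after each step. The helper `adamax` is illustrative. Adding `eps` to the denominator is an assumption made here to guard against $u_t = 0$; library implementations may place eps differently.

```python
def adamax(grads, lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-8, theta=0.0):
    """Run Adamax on a scalar; returns a list of (theta, u) after each step."""
    m = u = 0.0
    history = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g       # first moment, as in Adam
        u = max(beta2 * u, abs(g))            # L-infinity second moment: no (1 - beta2) factor
        # Only m is bias-corrected; eps avoids division by zero (assumption)
        theta -= lr / (1 - beta1 ** t) * m / (u + eps)
        history.append((theta, u))
    return history

for step, (theta, u) in enumerate(adamax([1.0, 100.0, 1.0, 1.0]), start=1):
    print(step, round(theta, 6), u)
```

The output shows $u_t$ jumping to the spike's magnitude at step 2 and then shrinking by a factor of 0.999 per step.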
optimizer = torch.optim.Adamax(
model.parameters(), lr=2e-3, betas=(0.9, 0.999), eps=1e-8
)
TensorFlow:
optimizer = tf.keras.optimizers.Adamax(
learning_rate=2e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-8
)
NAdam — Nesterov Adam
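NAdam (Dozat 2016) incorporates Nesterov momentum into Adam by using a lookahead first moment: the current gradient's contribution to the next step's momentum is included in the current update. With a constant $\beta_1$, a commonly used simplified form is:
$$\theta_t = \theta_{t-1} - \frac{\eta}{\sqrt{\hat v_t} + \epsilon}\left(\beta_1 \hat m_t + \frac{(1-\beta_1)\, g_t}{1-\beta_1^t}\right)$$
PyTorch's implementation additionally replaces the constant $\beta_1$ with a slowly increasing momentum schedule controlled by momentum_decay. In practice, NAdam is often reported to converge slightly faster than Adam on smooth objectives, and it is a low-risk drop-in replacement.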
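A minimal scalar sketch of the simplified update, assuming a constant $\beta_1$. It omits PyTorch's `momentum_decay` schedule, so it will not match `torch.optim.NAdam` step for step. The helper `nadam_step` is illustrative. The lookahead term `m_bar` equals one extra momentum update, $\beta_1 m_t + (1-\beta_1) g_t$, divided by the step-$t$ bias correction.

```python
import math

def nadam_step(theta, g, m, v, t, lr=2e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One simplified NAdam update on a scalar (constant beta1); returns (theta, m, v)."""
    m = beta1 * m + (1 - beta1) * g           # same moment updates as Adam
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Nesterov lookahead: momentum after one more step driven by the current gradient
    m_bar = beta1 * m_hat + (1 - beta1) * g / (1 - beta1 ** t)
    return theta - lr * m_bar / (math.sqrt(v_hat) + eps), m, v

# Minimize f(theta) = 0.5 * theta**2, whose gradient is theta itself
theta, m, v = 1.0, 0.0, 0.0
for t in range(1, 501):
    theta, m, v = nadam_step(theta, theta, m, v, t, lr=0.01)
print(theta)
```

Replacing `m_bar` with `m_hat` in the update line recovers plain Adam, which makes the two easy to compare on the same objective.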
optimizer = torch.optim.NAdam(
model.parameters(), lr=2e-3, betas=(0.9, 0.999),
momentum_decay=0.004  # psi in PyTorch's schedule mu_t = beta1 * (1 - 0.5 * 0.96 ** (t * psi))
)
TensorFlow:
optimizer = tf.keras.optimizers.Nadam(
learning_rate=2e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-8
)
RAdam — Rectified Adam
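RAdam (Liu et al. 2019) diagnoses Adam's instability in early training. When $t$ is small, the adaptive learning rate $1/\sqrt{\hat v_t}$ is estimated from only a few gradients, so its variance is high and the effective step size fluctuates wildly. RAdam computes an analytical rectification term from the approximate length $\rho_t$ of the simple moving average that $v_t$ represents:
$$\rho_\infty = \frac{2}{1-\beta_2} - 1, \qquad \rho_t = \rho_\infty - \frac{2t\,\beta_2^t}{1-\beta_2^t}, \qquad r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}$$
- If $\rho_t > 4$: the variance is tractable, so apply the rectified adaptive update $\theta_t = \theta_{t-1} - \eta\, r_t\, \hat m_t / (\sqrt{\hat v_t} + \epsilon)$ (PyTorch uses the threshold $\rho_t > 5$)
- Else: fall back to SGD with momentum, $\theta_t = \theta_{t-1} - \eta\, \hat m_t$ (no adaptive scaling)

This gives RAdam warmup-like behavior without an explicit warmup schedule. With $\beta_2 = 0.999$, $\rho_t$ is roughly $t$ during the first steps, so only a handful of steps are SGD-like. Afterward, $r_t$ starts well below 1 and rises toward 1 as the second moment estimate stabilizes.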
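The rectification schedule depends only on $t$ and $\beta_2$, so it can be tabulated directly. This sketch computes $\rho_t$ and $r_t$ from the formulas above. `radam_rectifier` and `first_adaptive_step` are illustrative helpers. The `threshold` argument lets you compare the paper's cutoff (4) with PyTorch's (5).

```python
import math

def radam_rectifier(t, beta2=0.999, threshold=4.0):
    """Return (rho_t, r_t) for step t; r_t is None while RAdam uses plain momentum."""
    rho_inf = 2.0 / (1.0 - beta2) - 1.0
    beta2_t = beta2 ** t
    rho_t = rho_inf - 2.0 * t * beta2_t / (1.0 - beta2_t)
    if rho_t <= threshold:
        return rho_t, None                    # variance intractable: update with m_hat only
    r_t = math.sqrt((rho_t - 4.0) * (rho_t - 2.0) * rho_inf
                    / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))
    return rho_t, r_t

def first_adaptive_step(beta2=0.999, threshold=4.0):
    """First step at which the rectified adaptive update engages."""
    t = 1
    while radam_rectifier(t, beta2, threshold)[1] is None:
        t += 1
    return t

for t in (1, 5, 10, 100, 1000, 10000):
    rho_t, r_t = radam_rectifier(t)
    print(t, round(rho_t, 3), r_t)
```

With $\beta_2 = 0.999$, the adaptive update first engages at $t = 5$ (threshold 4) or $t = 6$ (threshold 5). From there, $r_t$ starts near 0.02 and climbs toward 1 over the next several thousand steps.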
optimizer = torch.optim.RAdam(
model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8
)
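TensorFlow: there is no built-in RAdam in core TF or Keras. It is available via TensorFlow Addons (tfa.optimizers.RectifiedAdam), which is no longer actively developed, or it can be implemented from the paper. In practice, AdamW with a short linear learning-rate warmup often achieves similar early-training stability. Examples are torch.optim.lr_scheduler.LinearLR in PyTorch, or a custom LearningRateSchedule in Keras.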
SparseAdam
Standard Adam decays both moment estimates and updates every parameter at every step, even when a parameter's gradient is zero (e.g., an embedding row for a token not in the batch). Its decaying momentum therefore keeps moving rows that received no gradient. SparseAdam applies a lazy update: only the rows that appear in the sparse gradient have their moments and values updated at each step.
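The difference shows up on two hypothetical scalar "embedding rows" that both receive a gradient at step 1, after which only row 0 does. This pure-Python sketch runs dense Adam alongside a lazy, SparseAdam-style variant that touches only the rows present in each step's gradient. `run_adam` is an illustrative helper, not PyTorch's implementation. The single shared step counter is an assumption, matching one step count per parameter tensor.

```python
import math

def run_adam(num_rows, steps, lazy, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam over scalar rows; each element of steps maps row index -> gradient."""
    theta = [1.0] * num_rows
    m = [0.0] * num_rows
    v = [0.0] * num_rows
    for t, grads in enumerate(steps, start=1):
        rows = grads.keys() if lazy else range(num_rows)   # lazy: only rows with a gradient
        for i in rows:
            g = grads.get(i, 0.0)                          # missing rows have zero gradient
            m[i] = beta1 * m[i] + (1 - beta1) * g
            v[i] = beta2 * v[i] + (1 - beta2) * g * g
            m_hat = m[i] / (1 - beta1 ** t)
            v_hat = v[i] / (1 - beta2 ** t)
            theta[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta

steps = [{0: 1.0, 1: 1.0}] + [{0: 1.0} for _ in range(9)]
dense_theta = run_adam(2, steps, lazy=False)
lazy_theta = run_adam(2, steps, lazy=True)
print(dense_theta, lazy_theta)
```

In the dense run, row 1 keeps drifting after step 1 because its decaying momentum is still nonzero; in the lazy run it moves exactly once. Row 0, which sees a gradient every step, ends up identical in both runs.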
# Only valid for parameters that produce sparse gradients, e.g. nn.Embedding(..., sparse=True)
optimizer = torch.optim.SparseAdam(
[{'params': model.embedding.parameters()}],
lr=1e-3, betas=(0.9, 0.999), eps=1e-8
)
TensorFlow: tf.keras.optimizers.Adam accepts sparse gradients (tf.IndexedSlices, TF's representation of gradients for a subset of rows). However, its moment decay still touches every row, so it does not apply lazy updates. TensorFlow Addons provided a lazy variant, tfa.optimizers.LazyAdam.
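Constraint: SparseAdam only supports sparse gradients, and PyTorch's Adam and AdamW reject sparse gradients. A model that mixes a sparse embedding with dense layers therefore needs two optimizer instances: SparseAdam for the embedding and AdamW for everything else, each stepped every iteration. Param groups cannot mix optimizer types. The alternative is to create the embedding with sparse=False and use AdamW for everything.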
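A minimal PyTorch sketch of the two-optimizer pattern. `TinyModel`, the vocabulary size, and the squared-output loss are hypothetical placeholders. The relevant parts are `nn.Embedding(..., sparse=True)`, a `SparseAdam` instance for the embedding, an `AdamW` instance for the dense head, and calling `step()` on both each iteration.

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Hypothetical model: a sparse embedding followed by a dense head."""
    def __init__(self, vocab_size=100, dim=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim, sparse=True)  # sparse grads
        self.head = nn.Linear(dim, 1)                                 # dense grads

    def forward(self, tokens):
        return self.head(self.embedding(tokens).mean(dim=1))

torch.manual_seed(0)
model = TinyModel()
# One optimizer per gradient type; a single optimizer's param groups cannot mix classes
sparse_opt = torch.optim.SparseAdam(list(model.embedding.parameters()), lr=1e-3)
dense_opt = torch.optim.AdamW(model.head.parameters(), lr=1e-3, weight_decay=0.01)

tokens = torch.randint(0, 100, (4, 5))       # batch of 4 sequences, 5 token ids each
before = model.embedding.weight.detach().clone()

for opt in (sparse_opt, dense_opt):
    opt.zero_grad()
loss = model(tokens).pow(2).mean()           # placeholder loss for illustration
loss.backward()
for opt in (sparse_opt, dense_opt):
    opt.step()

changed_rows = (model.embedding.weight.detach() != before).any(dim=1).nonzero().flatten()
print(sorted(changed_rows.tolist()), sorted(set(tokens.flatten().tolist())))
```

The two printed lists show that only the embedding rows indexed by the batch were modified.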
Adam Family Summary
| Variant | Key difference | Default weight_decay (PyTorch) |
|---|---|---|
| Adam | First + second moment with bias correction | 0 |
| AdamW | Decoupled weight decay | 0.01 |
| Adamax | L∞ second moment, robust to spikes | 0 |
| NAdam | Nesterov lookahead first moment | 0 |
| RAdam | Variance-rectified warmup, auto SGD fallback | 0 |
| SparseAdam | Lazy update for sparse gradients | n/a (no weight_decay argument) |