Supplement · Loss Functions

Loss Functions in TensorFlow

Colab Notebook · ~40 min
Python · ~40 min

Lab Objectives

1. Implement the core loss families using tf.keras.losses and the tf.keras.losses.Loss base class
2. Verify numerical parity between PyTorch and TensorFlow for all shared loss functions
3. Differentiate through custom loss functions using tf.GradientTape
4. Build a custom Keras loss class (subclassing tf.keras.losses.Loss) for TripletMarginLoss
5. Train a model end-to-end with model.compile(loss=...) using both built-in and custom losses
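
As a preview of objectives 3 and 4, here is a minimal sketch of a triplet margin loss written as a tf.keras.losses.Loss subclass and differentiated with tf.GradientTape. The packing convention is an assumption of this sketch, not something Keras prescribes: y_pred is taken to have shape (batch, 3, dim) holding anchor, positive, and negative embeddings, and y_true is ignored. The margin and eps defaults mirror PyTorch's nn.TripletMarginLoss; the notebook's own implementation may differ.

```python
import tensorflow as tf


class TripletMarginLoss(tf.keras.losses.Loss):
    """Sketch of a triplet margin loss (Euclidean distance).

    Assumption of this sketch: y_pred has shape (batch, 3, dim) and stacks
    anchor, positive, negative along axis 1. y_true is unused.
    """

    def __init__(self, margin=1.0, eps=1e-6, name="triplet_margin_loss"):
        super().__init__(name=name)
        self.margin = margin
        self.eps = eps

    def call(self, y_true, y_pred):
        anchor, positive, negative = tf.unstack(y_pred, num=3, axis=1)
        # eps is added to the difference, following PyTorch's pairwise_distance.
        d_pos = tf.norm(anchor - positive + self.eps, axis=-1)
        d_neg = tf.norm(anchor - negative + self.eps, axis=-1)
        # One value per triplet; the base class applies the batch reduction.
        return tf.maximum(d_pos - d_neg + self.margin, 0.0)


# Two hand-made triplets (illustrative values only).
embeddings = tf.Variable(
    [
        [[0.0, 0.0], [3.0, 4.0], [0.0, 1.0]],  # positive far, negative near: loss > 0
        [[0.0, 0.0], [0.0, 1.0], [3.0, 4.0]],  # positive near, negative far: loss = 0
    ]
)
dummy_labels = tf.zeros((2,))  # placeholder, ignored by the loss

loss_fn = TripletMarginLoss(margin=1.0)
with tf.GradientTape() as tape:
    loss = loss_fn(dummy_labels, embeddings)

grads = tape.gradient(loss, embeddings)
print("loss:", float(loss))
print("gradient shape:", grads.shape)
```

Because call() uses only TensorFlow ops, the tape can trace it. The second triplet already satisfies the margin, so it contributes no gradient.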

Lab Overview

This notebook is the TensorFlow companion to the PyTorch lab. For every loss implemented in PyTorch, you will either use the TensorFlow/Keras equivalent or implement it from scratch when no built-in exists.

TF vs PyTorch API Differences

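Concept                         PyTorch                              TensorFlow/Keras
MSE                             nn.MSELoss()                         tf.keras.losses.MeanSquaredError()
BCE from logits                 nn.BCEWithLogitsLoss()               tf.keras.losses.BinaryCrossentropy(from_logits=True)
Categorical CE (integer labels) nn.CrossEntropyLoss()                tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
KL divergence                   nn.KLDivLoss(reduction='batchmean')  tf.keras.losses.KLDivergence()
Custom loss                     nn.Module subclass                   tf.keras.losses.Loss subclass

Two conventions differ between the columns. Keras losses are called as loss(y_true, y_pred), whereas PyTorch losses are called as loss(input, target). Also, nn.KLDivLoss expects log-probabilities as its input, while tf.keras.losses.KLDivergence expects probabilities for both arguments.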
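
The built-in rows of the table can be exercised directly. The sketch below calls each Keras loss on small hand-made tensors (illustrative values only) and compares the result with a plain NumPy reference. NumPy stands in for the PyTorch side here so that the block needs only TensorFlow and NumPy; it is not the notebook's parity check.

```python
import numpy as np
import tensorflow as tf

# Binary targets and raw logits, shape (batch, 1).
y_true = np.array([[0.0], [1.0], [1.0], [0.0]], dtype=np.float32)
logits = np.array([[-1.2], [0.3], [2.0], [-0.5]], dtype=np.float32)

# Keras argument order is (y_true, y_pred).
mse_tf = float(tf.keras.losses.MeanSquaredError()(y_true, logits))
bce_tf = float(tf.keras.losses.BinaryCrossentropy(from_logits=True)(y_true, logits))

mse_np = float(np.mean((y_true - logits) ** 2))
# Stable BCE-with-logits: max(z, 0) - z * y + log(1 + exp(-|z|))
bce_np = float(
    np.mean(np.maximum(logits, 0) - logits * y_true + np.log1p(np.exp(-np.abs(logits))))
)

# Multi-class with integer labels.
labels = np.array([2, 0, 1])
class_logits = np.array(
    [[0.1, 0.2, 3.0], [1.5, -0.3, 0.0], [0.0, 0.0, 0.0]], dtype=np.float32
)
scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
scce_tf = float(scce(labels, class_logits))

log_softmax = class_logits - np.log(np.sum(np.exp(class_logits), axis=1, keepdims=True))
scce_np = float(-np.mean(log_softmax[np.arange(len(labels)), labels]))

# KL divergence: Keras expects probabilities for both arguments.
p = np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]], dtype=np.float32)  # target distribution
q = np.array([[0.5, 0.3, 0.2], [0.2, 0.2, 0.6]], dtype=np.float32)  # predicted distribution
kl_tf = float(tf.keras.losses.KLDivergence()(p, q))
# Sum over classes, mean over the batch (the 'batchmean' convention).
kl_np = float(np.mean(np.sum(p * np.log(p / q), axis=1)))

for name, a, b in [("MSE", mse_tf, mse_np), ("BCE", bce_tf, bce_np),
                   ("SparseCCE", scce_tf, scce_np), ("KL", kl_tf, kl_np)]:
    print(f"{name}: keras={a:.6f} reference={b:.6f}")
```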

Key TF-Specific Topics

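  • from_logits=True: Prefer this in Keras whenever the model outputs raw logits (no final sigmoid or softmax layer). It fuses the activation with the cross-entropy (sigmoid+CE for BinaryCrossentropy, softmax+CE for the categorical losses) in a numerically stable way, like PyTorch's BCEWithLogitsLoss and CrossEntropyLoss.
  • GradientTape through custom losses: Custom tf.keras.losses.Loss subclasses are differentiable under tf.GradientTape as long as call() is built from differentiable TensorFlow ops (converting to NumPy inside call() breaks the gradient path).
  • Sample weighting: model.fit(sample_weight=...) applies per-sample weights to the loss, the TF counterpart of reduction='none' plus manual weighting in PyTorch.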
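
The first and third points can be illustrated with a short sketch. The tensors are made-up values chosen to expose the effect. The first half compares the fused logits path with applying sigmoid first. The second half passes sample_weight directly to a loss object, which is assumed here to mirror what model.fit(sample_weight=...) does per batch.

```python
import tensorflow as tf

# Two confidently wrong predictions: label 1 with logit -30, label 0 with logit +30.
y_true = tf.constant([[1.0], [0.0]])
logits = tf.constant([[-30.0], [30.0]])

bce_logits = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction="none")
bce_probs = tf.keras.losses.BinaryCrossentropy(from_logits=False, reduction="none")

# Fused path: computed from the logits, so each loss stays close to |logit| = 30.
loss_from_logits = bce_logits(y_true, logits).numpy()

# Probability path: Keras clips probabilities away from 0 and 1 before the log,
# so the reported loss is capped well below 30.
loss_from_probs = bce_probs(y_true, tf.sigmoid(logits)).numpy()

print("from_logits=True :", loss_from_logits)
print("from_logits=False:", loss_from_probs)

# Sample weighting: built-in weighting vs. reduction="none" plus manual weighting.
targets = tf.constant([[0.0], [1.0]])
preds = tf.constant([[0.5], [0.0]])
weights = tf.constant([1.0, 3.0])

mse_mean = tf.keras.losses.MeanSquaredError()
mse_none = tf.keras.losses.MeanSquaredError(reduction="none")

weighted_builtin = float(mse_mean(targets, preds, sample_weight=weights))
weighted_manual = float(tf.reduce_mean(mse_none(targets, preds) * weights))

print("weighted (built-in):", weighted_builtin)
print("weighted (manual)  :", weighted_manual)
```

The capped value in the probability path also means the gradient signal for badly wrong predictions is distorted, which is the practical reason to keep the final activation out of the model and let the loss handle it.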

Sections

Section  Content
1        Regression losses: MSE, MAE, Huber
2        Binary classification: BinaryCrossentropy
3        Multi-class: SparseCategoricalCrossentropy, CategoricalCrossentropy
4        KL divergence and soft targets
5        Custom tf.keras.losses.Loss: from-scratch Huber
6        Custom triplet loss with GradientTape
7        Numerical parity check: TF vs PyTorch on all shared losses
8        End-to-end: training with model.compile(loss=...)
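
To give a flavor of sections 5 and 8, the following sketch writes Huber loss from scratch as a tf.keras.losses.Loss subclass, checks it against the built-in tf.keras.losses.Huber, and passes it to model.compile. The class name ScratchHuber, the synthetic linear data, and the hyperparameters are all assumptions made for illustration; the notebook's version may be organized differently.

```python
import numpy as np
import tensorflow as tf


class ScratchHuber(tf.keras.losses.Loss):
    """Huber loss: quadratic for |error| <= delta, linear beyond it."""

    def __init__(self, delta=1.0, name="scratch_huber"):
        super().__init__(name=name)
        self.delta = delta

    def call(self, y_true, y_pred):
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        error = y_true - y_pred
        abs_error = tf.abs(error)
        quadratic = 0.5 * tf.square(error)
        linear = self.delta * abs_error - 0.5 * self.delta ** 2
        per_element = tf.where(abs_error <= self.delta, quadratic, linear)
        # Average over the feature axis; the base class reduces over the batch.
        return tf.reduce_mean(per_element, axis=-1)


# Parity with the built-in on errors inside and outside the delta band.
y_true = tf.constant([[0.0], [0.0], [0.0]])
y_pred = tf.constant([[0.5], [2.0], [-3.0]])
scratch_value = float(ScratchHuber(delta=1.0)(y_true, y_pred))
builtin_value = float(tf.keras.losses.Huber(delta=1.0)(y_true, y_pred))
print("scratch:", scratch_value, "built-in:", builtin_value)

# End-to-end: fit a one-unit linear model on synthetic data y = 3x + 0.5.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 1)).astype("float32")
y = 3.0 * x + 0.5

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),
    loss=ScratchHuber(delta=1.0),
)
history = model.fit(x, y, epochs=20, batch_size=32, verbose=0)

first_loss = history.history["loss"][0]
last_loss = history.history["loss"][-1]
print("first epoch loss:", first_loss, "last epoch loss:", last_loss)
```

Swapping loss=ScratchHuber(delta=1.0) for loss=tf.keras.losses.Huber(delta=1.0) should leave the training behavior essentially unchanged, which is a quick sanity check on a custom loss.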