Supplement · Loss Functions

Metric Learning: Triplet Losses

10 min read
By the end of this reading you will be able to:
  • State the triplet constraint: the anchor-positive distance must be smaller than the anchor-negative distance by at least the margin m
  • Apply TripletMarginLoss with a chosen p-norm and margin, and verify that a satisfied constraint contributes zero loss
  • Use TripletMarginWithDistanceLoss to plug in a custom distance function such as cosine distance
  • Explain easy, hard, and semi-hard triplet mining and why random triplet selection leads to a vanishing training signal
  • Compare triplet loss, contrastive loss, and NT-Xent in terms of the number of negatives used per update and training efficiency

What Is Metric Learning?

The goal of metric learning is to train an embedding function $f_\theta: \mathcal{X} \to \mathbb{R}^d$ such that semantically similar inputs map to nearby points in the embedding space and dissimilar inputs map far apart.

Unlike classification losses, metric learning losses operate on triplets or pairs of samples — they encode relative rather than absolute correctness.


The Triplet Constraint

A triplet consists of:

  • Anchor $a$: the reference sample
  • Positive $p$: a sample from the same class/cluster as $a$
  • Negative $n$: a sample from a different class/cluster

The embedding space should satisfy $d(a, p) < d(a, n)$: the anchor is closer to the positive than to the negative. The triplet loss enforces this with a margin $m$:

$$\ell(a, p, n) = \max\bigl(0, \; d(a, p) - d(a, n) + m\bigr)$$

The loss is zero when the anchor-negative distance exceeds the anchor-positive distance by at least $m$. Otherwise it penalises the shortfall linearly.
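As a quick worked instance of the formula, assume a margin of 1 and an anchor-positive distance of 5 (illustrative numbers, not taken from any dataset). Three choices of anchor-negative distance then give a satisfied triplet, a small shortfall, and a large shortfall:

```latex
\begin{aligned}
d(a,n) = 10:&\quad \ell = \max(0,\; 5 - 10 + 1) = 0 \\
d(a,n) = 5.5:&\quad \ell = \max(0,\; 5 - 5.5 + 1) = 0.5 \\
d(a,n) = 3:&\quad \ell = \max(0,\; 5 - 3 + 1) = 3
\end{aligned}
```

The penalty grows by exactly as much as the anchor-negative distance falls short, which is what "linearly" means here.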


nn.TripletMarginLoss

The default distance is the $L^p$ norm of the difference: $d(u, v) = \|u - v\|_p$.

import torch
import torch.nn as nn

triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)  # Euclidean distance

anchor   = torch.randn(32, 128, requires_grad=True)  # 32 samples, 128-dim
positive = torch.randn(32, 128, requires_grad=True)
negative = torch.randn(32, 128, requires_grad=True)

loss = triplet_loss(anchor, positive, negative)  # scalar: mean over the 32 triplets
# loss > 0 when d(a,p) - d(a,n) + 1 > 0 for at least one triplet in the batch
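To check that a satisfied constraint contributes zero loss, the sketch below uses hand-picked 2-D points (illustrative values, not from the original example) and `reduction="none"` so that each triplet's loss is visible separately. Results are compared approximately because PyTorch adds a small `eps` inside the distance computation.

```python
import torch
import torch.nn as nn

# Hand-picked 2-D points so the distances are easy to read off (illustrative values)
anchor = torch.tensor([[0.0, 0.0], [0.0, 0.0]])
positive = torch.tensor([[3.0, 4.0], [3.0, 4.0]])   # d(a, p) = 5 for both rows
negative = torch.tensor([[6.0, 8.0], [0.0, 5.5]])   # d(a, n) = 10 and 5.5

# reduction="none" returns one loss value per triplet instead of the batch mean
per_triplet = nn.TripletMarginLoss(margin=1.0, p=2, reduction="none")(
    anchor, positive, negative
)
batch_mean = nn.TripletMarginLoss(margin=1.0, p=2)(anchor, positive, negative)

print(per_triplet)  # approximately [0.0, 0.5]
print(batch_mean)   # approximately 0.25
```

The first triplet satisfies the margin (5 - 10 + 1 < 0) and is clamped to zero; only the second one contributes to the mean.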

The swap parameter enables the distance swap heuristic: replace $d(a,n)$ with $\min(d(a,n), d(p,n))$, so the loss uses whichever of the anchor and the positive lies closer to the negative. This tightens the constraint when the positive and negative are very close.
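A small check of the swap behaviour, using three collinear points chosen for illustration. It assumes the PyTorch implementation takes the smaller of the anchor-negative and positive-negative distances when `swap=True`; the numbers in the comments follow from that assumption.

```python
import torch
import torch.nn as nn

# Three points on a line (illustrative values): a = 0, p = 3, n = 5
anchor = torch.tensor([[0.0, 0.0]])
positive = torch.tensor([[3.0, 0.0]])   # d(a, p) = 3
negative = torch.tensor([[5.0, 0.0]])   # d(a, n) = 5, d(p, n) = 2

plain = nn.TripletMarginLoss(margin=1.0, p=2, swap=False)(anchor, positive, negative)
swapped = nn.TripletMarginLoss(margin=1.0, p=2, swap=True)(anchor, positive, negative)

# Without swap: max(0, 3 - 5 + 1) = 0, the triplet looks satisfied
# With swap:    max(0, 3 - min(5, 2) + 1) = 2, because p sits close to n
print(plain.item(), swapped.item())
```

The same triplet goes from zero loss to a loss of about 2, which shows how the swap exposes a positive that is too close to the negative.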

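When to use: Face recognition (FaceNet); image retrieval and visual search (for example, systems like Pinterest or Google Images); few-shot learning, where distance-based methods such as prototypical networks depend on a similarly well-structured embedding space.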


nn.TripletMarginWithDistanceLoss

Same triplet formulation but accepts any differentiable distance function:

$$\ell(a, p, n) = \max\bigl(0, \; d_{\text{custom}}(a, p) - d_{\text{custom}}(a, n) + m\bigr)$$

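import torch.nn as nn
import torch.nn.functional as F

# Cosine distance: 1 - cosine_similarity, computed row-wise over (N, d) inputs
def cosine_dist(u, v):
    return 1.0 - F.cosine_similarity(u, v)

triplet_loss = nn.TripletMarginWithDistanceLoss(
    distance_function=cosine_dist,
    margin=0.2
)
# anchor, positive, negative: (N, d) tensors, as in the TripletMarginLoss example
loss = triplet_loss(anchor, positive, negative)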

The custom distance function must accept two tensors of shape $(N, d)$ and return a tensor of shape $(N,)$. It must be differentiable (autograd-compatible).
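The shape and differentiability requirements can be checked directly. The sketch below uses small random tensors (sizes chosen arbitrarily for illustration) and compares the module's per-triplet output against the hinge formula written out by hand.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def cosine_dist(u, v):
    # Row-wise cosine distance: (N, d), (N, d) -> (N,)
    return 1.0 - F.cosine_similarity(u, v, dim=-1)

torch.manual_seed(0)
anchor = torch.randn(8, 16, requires_grad=True)
positive = torch.randn(8, 16)
negative = torch.randn(8, 16)

d_ap = cosine_dist(anchor, positive)
d_an = cosine_dist(anchor, negative)
print(d_ap.shape)  # one distance per row

loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=cosine_dist, margin=0.2, reduction="none"
)
per_triplet = loss_fn(anchor, positive, negative)

# The same hinge written out by hand, for comparison
manual = torch.clamp(d_ap - d_an + 0.2, min=0.0)

# Gradients flow back to the anchor through the custom distance
per_triplet.mean().backward()
```

If the custom function returned a tensor of a different shape (for example by reducing over the wrong dimension), the comparison against `manual` is where the mismatch would show up.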

When to use: Sentence embeddings (cosine distance is more natural than Euclidean); normalised embeddings on a unit hypersphere; any domain where $L^2$ distance is inappropriate.


Triplet Mining

Sampling triplets at random is inefficient: once training is under way, most random negatives are already far from the anchor, so most triplets produce zero loss and contribute no gradient. In practice, triplets are mined:

| Strategy | Condition | Effect on training |
| --- | --- | --- |
| Easy triplets | $d(a,n) > d(a,p) + m$ (constraint already satisfied) | No gradient; skip |
| Semi-hard triplets | $d(a,p) < d(a,n) < d(a,p) + m$ | Moderate gradient; recommended |
| Hard triplets | $d(a,n) < d(a,p)$ (negative is closer than positive) | Large gradient; can destabilise early training |

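FaceNet trained on semi-hard negatives: its authors reported that selecting the hardest negatives can lead to bad local minima early in training, including a collapsed model that maps every input to the same point.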
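The three categories can be computed for every valid triplet in a batch from the pairwise distance matrix. The following is a minimal sketch, not FaceNet's actual mining procedure; the function name `categorise_triplets` and the toy 1-D batch are hypothetical and chosen so the counts can be checked by hand.

```python
import torch

def categorise_triplets(embeddings, labels, margin):
    """Boolean masks of shape (B, B, B), indexed [anchor, positive, negative]."""
    dist = torch.cdist(embeddings, embeddings, p=2)           # (B, B) pairwise distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)         # same-class pairs
    not_self = ~torch.eye(len(labels), dtype=torch.bool)
    # valid[a, p, n]: p shares a's label (and p != a), n has a different label
    valid = (same & not_self).unsqueeze(2) & (~same).unsqueeze(1)
    d_ap = dist.unsqueeze(2)   # (B, B, 1), broadcast over negatives
    d_an = dist.unsqueeze(1)   # (B, 1, B), broadcast over positives
    easy = valid & (d_an > d_ap + margin)
    semi_hard = valid & (d_an > d_ap) & (d_an < d_ap + margin)
    hard = valid & (d_an < d_ap)
    return easy, semi_hard, hard

# Toy batch: 1-D embeddings, two classes (illustrative values)
embeddings = torch.tensor([[0.0], [1.0], [1.5], [0.5], [5.0]])
labels = torch.tensor([0, 0, 1, 1, 1])

easy, semi_hard, hard = categorise_triplets(embeddings, labels, margin=1.0)
print(easy.sum().item(), semi_hard.sum().item(), hard.sum().item())  # 3 4 11
```

A semi-hard mining step would then sample triplets only from the `semi_hard` mask. Materialising a (B, B, B) mask is fine for a sketch but grows cubically with batch size, so practical implementations usually select negatives per anchor-positive pair instead.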


Comparison: Triplet vs. Contrastive vs. NT-Xent

| Loss | Inputs per step | Key property |
| --- | --- | --- |
| Triplet (TripletMarginLoss) | Anchor, positive, negative | Relative ordering; easy to interpret |
| Pairwise contrastive (HingeEmbeddingLoss) | One precomputed pair distance plus a similar/dissimilar label | Simpler; only binary similar/dissimilar |
| NT-Xent (SimCLR) | Full batch, all pairs | Uses every other sample in the batch as a negative; widely used for self-supervised learning |
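To make the "all negatives in the batch" property of NT-Xent concrete, here is a minimal sketch of the loss. PyTorch has no built-in NT-Xent module, so the function `nt_xent` and its default temperature are assumptions made for illustration; it follows the usual formulation of cross-entropy over temperature-scaled cosine similarities.

```python
import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.5):
    """z1[i] and z2[i] embed two views of the same sample; both have shape (N, d)."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)        # (2N, d), unit length
    sim = z @ z.t() / temperature                              # scaled cosine similarities
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))            # a sample is not its own negative
    # The positive for row i is row i + N, and vice versa
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Each row is a (2N - 1)-way classification: 1 positive against 2N - 2 negatives
    return F.cross_entropy(sim, targets)

torch.manual_seed(0)
view1 = torch.randn(4, 8)
view2 = view1 + 0.1 * torch.randn(4, 8)   # a slightly perturbed second view
loss = nt_xent(view1, view2)
```

With a batch of N pairs, every anchor is contrasted against 2N - 2 negatives in a single update, whereas a triplet loss uses one negative per anchor.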

For supervised metric learning with labelled data, TripletMarginWithDistanceLoss with cosine distance is a strong baseline.
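A rough sketch of how such a baseline might be wired into a training loop. The encoder architecture, the synthetic data, and the optimiser settings are all placeholders chosen for illustration; in a real setup the triplets would come from labelled data, ideally with mining.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy encoder: 20-dim inputs -> 8-dim embeddings
encoder = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 8))
loss_fn = nn.TripletMarginWithDistanceLoss(
    distance_function=lambda u, v: 1.0 - F.cosine_similarity(u, v),
    margin=0.2,
)
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)

# Synthetic triplets: positives are noisy copies of anchors, negatives are unrelated draws
anchor_x = torch.randn(64, 20)
positive_x = anchor_x + 0.05 * torch.randn(64, 20)
negative_x = torch.randn(64, 20)

losses = []
for step in range(50):
    optimizer.zero_grad()
    # The same encoder embeds all three inputs, so the weights are shared
    loss = loss_fn(encoder(anchor_x), encoder(positive_x), encoder(negative_x))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
```

The key design point is that a single encoder produces all three embeddings; the loss only ever sees distances between its outputs.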