Lab: Einstein Summation in PyTorch
torch.einsum is a single function that can express virtually any multilinear operation on tensors — dot products, matrix multiplies, transposes, traces, outer products, and batched variants — using a compact notation borrowed directly from physics and mathematics.
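As a first taste of the notation, here is a minimal sketch that writes each operation named above as an einsum string and checks it numerically against the dedicated PyTorch function it replaces. The tensor sizes and the `checks` dictionary are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)

a = torch.randn(4)
b = torch.randn(4)
A = torch.randn(3, 4)
B = torch.randn(4, 5)
S = torch.randn(4, 4)      # square, so it has a trace
X = torch.randn(8, 3, 4)   # batch of 8 matrices, each 3x4
Y = torch.randn(8, 4, 5)   # batch of 8 matrices, each 4x5

# name -> (einsum result, result from the dedicated PyTorch op)
checks = {
    "dot":        (torch.einsum("i,i->", a, b),          torch.dot(a, b)),        # sum over shared i
    "matmul":     (torch.einsum("ik,kj->ij", A, B),      A @ B),                  # contract k
    "transpose":  (torch.einsum("ij->ji", A),            A.T),                    # reorder only
    "trace":      (torch.einsum("ii->", S),              torch.trace(S)),         # repeated index = diagonal
    "outer":      (torch.einsum("i,j->ij", a, b),        torch.outer(a, b)),      # nothing contracted
    "batched_mm": (torch.einsum("bik,bkj->bij", X, Y),   torch.bmm(X, Y)),        # b kept in lockstep
}

for name, (got, expected) in checks.items():
    assert got.shape == expected.shape, name
    assert torch.allclose(got, expected, atol=1e-5), name
    print(f"{name:<11} shape={tuple(got.shape)} ok")
```

The rule of thumb the comments hint at: a letter that appears in the inputs but not after `->` is summed away, and a letter that appears after `->` survives as an output axis.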
This lab builds your fluency with the notation from the ground up, then shows how the same patterns appear in modern ML architectures.
What You'll Build
- A reference sheet of the 5 canonical einsum families: unary, binary, batched, quadratic, and attention
- Verified implementations of every standard matrix operation from the prerequisite readings, translated into einsum strings and checked numerically against the equivalent PyTorch functions
- A from-scratch scaled dot-product attention forward pass that uses only torch.einsum for its tensor contractions
- A batch covariance routine that mirrors the 3DGS (3D Gaussian Splatting) covariance construction
- A performance benchmark comparing einsum against torch.bmm and torch.matmul
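For the attention item, a minimal sketch might look like the following. It assumes inputs laid out as (batch, heads, sequence, features); the function name `einsum_attention` is hypothetical. Note that the softmax is not a sum of products, so it stays a regular `torch.softmax` call; einsum handles the two contractions.

```python
import math
import torch

def einsum_attention(q, k, v):
    """Scaled dot-product attention with einsum for both contractions (hypothetical helper).

    q: (batch, heads, seq_q, d_k), k: (batch, heads, seq_k, d_k), v: (batch, heads, seq_k, d_v)
    returns: (batch, heads, seq_q, d_v)
    """
    d_k = q.shape[-1]
    # scores[b,h,i,j] = sum_d q[b,h,i,d] * k[b,h,j,d]; no explicit transpose of k needed
    scores = torch.einsum("bhid,bhjd->bhij", q, k) / math.sqrt(d_k)
    # Normalize over key positions j
    weights = torch.softmax(scores, dim=-1)
    # out[b,h,i,d] = sum_j weights[b,h,i,j] * v[b,h,j,d]
    return torch.einsum("bhij,bhjd->bhid", weights, v)


torch.manual_seed(0)
q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 7, 8)
v = torch.randn(2, 4, 7, 16)
out = einsum_attention(q, k, v)

# Reference: the same computation with matmul and an explicit transpose
ref = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1) @ v
print(out.shape)  # torch.Size([2, 4, 5, 16])
assert torch.allclose(out, ref, atol=1e-5)
```

The einsum version states the contracted axis (`d`, then `j`) directly in the string, whereas the matmul version relies on the reader tracking which axis `transpose(-2, -1)` moves.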
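In 3D Gaussian Splatting, each Gaussian's covariance is built as Sigma = R S S^T R^T, where R is a rotation obtained from a quaternion and S is a diagonal scale matrix. The sketch below batches that construction with einsum, assuming quaternions in (w, x, y, z) order and already-positive per-axis scales; `quat_to_rotmat` and `covariance_3d` are hypothetical names, not taken from any particular codebase.

```python
import torch

def quat_to_rotmat(quats):
    """Quaternions (N, 4) in (w, x, y, z) order -> rotation matrices (N, 3, 3)."""
    q = quats / quats.norm(dim=-1, keepdim=True)  # normalize so R is a proper rotation
    w, x, y, z = q.unbind(-1)
    R = torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    ], dim=-1)
    return R.reshape(-1, 3, 3)  # row-major: the 9 entries above fill rows 0, 1, 2


def covariance_3d(quats, scales):
    """Per-Gaussian Sigma = R S S^T R^T with S = diag(scales). (N, 4), (N, 3) -> (N, 3, 3)."""
    R = quat_to_rotmat(quats)
    # M = R S: right-multiplying by a diagonal matrix scales column j of R by scales[n, j]
    M = torch.einsum("nij,nj->nij", R, scales)
    # Sigma = M M^T: contract over the column index k, batched over Gaussians n
    return torch.einsum("nik,njk->nij", M, M)


torch.manual_seed(0)
N = 1000
quats = torch.randn(N, 4)
scales = torch.rand(N, 3) + 0.1
cov = covariance_3d(quats, scales)

# Reference with explicit matrices: R @ S @ S^T @ R^T
R = quat_to_rotmat(quats)
S = torch.diag_embed(scales)
ref = R @ S @ S.transpose(-1, -2) @ R.transpose(-1, -2)
assert torch.allclose(cov, ref, atol=1e-5)
assert torch.allclose(cov, cov.transpose(-1, -2), atol=1e-6)  # covariance is symmetric
```

Writing R S as `"nij,nj->nij"` avoids materializing the diagonal matrix at all, which is the kind of shortcut that is easy to see in index notation and easy to miss in a matmul chain.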
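A benchmark can start as simply as the sketch below, which times the same batched product written three ways after first confirming that all three agree. The shapes, the iteration count, and the `time_per_call` helper are arbitrary assumptions; timings depend heavily on hardware, dtype, and tensor shapes, so no ranking is implied here. For more careful measurements, `torch.utils.benchmark` handles warm-up and CUDA synchronization for you.

```python
import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
x = torch.randn(32, 64, 128, device=device)
y = torch.randn(32, 128, 64, device=device)

# Three ways to write the same batched matrix product
candidates = {
    "einsum": lambda: torch.einsum("bik,bkj->bij", x, y),
    "bmm": lambda: torch.bmm(x, y),
    "matmul": lambda: torch.matmul(x, y),
}

def time_per_call(fn, iters=20):
    """Mean wall-clock seconds per call, after one warm-up call (hypothetical helper)."""
    fn()  # warm-up: the first call can include one-time setup costs
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before reading the clock
    return (time.perf_counter() - start) / iters

reference = candidates["bmm"]()
timings = {}
for name, fn in candidates.items():
    assert torch.allclose(fn(), reference, atol=1e-4, rtol=1e-4), name  # correctness before speed
    timings[name] = time_per_call(fn)
    print(f"{name:<7} {timings[name] * 1e6:9.1f} us/call on {device}")
```

Checking correctness before timing matters: a fast einsum string with the wrong index order would otherwise look like a win.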
Key Concepts Practiced
After this lab you will be able to read any einsum string in a research codebase and immediately parse which axes are being contracted, which are being broadcast, and what the output shape will be — and write your own for novel tensor operations without resorting to explicit reshape/transpose/matmul chains.
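Reading an einsum string comes down to three questions per index letter: does it vanish after `->` (contracted), does it survive and appear in several operands (a batch axis carried in lockstep), or does it survive from a single operand (a free axis the other operands are broadcast across)? The hypothetical helper below encodes those rules as a sketch. It assumes an explicit `->`, no `...` ellipsis, and equal sizes for repeated letters (it does not model size-1 broadcasting).

```python
import torch

def describe_einsum(spec, *shapes):
    """Classify the indices of an explicit einsum spec and predict its output shape (hypothetical)."""
    inputs, output = spec.replace(" ", "").split("->")
    operands = inputs.split(",")
    if len(operands) != len(shapes):
        raise ValueError("need exactly one shape per operand")
    sizes = {}  # index letter -> axis size, in order of first appearance
    for subs, shape in zip(operands, shapes):
        if len(subs) != len(shape):
            raise ValueError(f"subscript {subs!r} does not match rank of shape {shape}")
        for idx, n in zip(subs, shape):
            if sizes.setdefault(idx, n) != n:
                raise ValueError(f"index {idx!r} has inconsistent sizes")
    if any(i not in sizes for i in output):
        raise ValueError("output uses an index that no operand has")
    in_how_many = {i: sum(i in subs for subs in operands) for i in sizes}
    return {
        "contracted": [i for i in sizes if i not in output],          # summed away
        "batch": [i for i in output if in_how_many[i] > 1],           # kept, shared by operands
        "free": [i for i in output if in_how_many[i] == 1],           # kept, from one operand
        "output_shape": tuple(sizes[i] for i in output),
    }


info = describe_einsum("bhid,bhjd->bhij", (2, 4, 5, 8), (2, 4, 7, 8))
print(info)
# {'contracted': ['d'], 'batch': ['b', 'h'], 'free': ['i', 'j'], 'output_shape': (2, 4, 5, 7)}

# Cross-check the predicted shape against torch.einsum itself
q, k = torch.randn(2, 4, 5, 8), torch.randn(2, 4, 7, 8)
assert torch.einsum("bhid,bhjd->bhij", q, k).shape == info["output_shape"]
```

Running the same function on `"ii->"` or `"i,j->ij"` is a quick way to check your own reading of an unfamiliar string before using it in a model.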