Prerequisite · Calculus Foundations

The Chain Rule

By the end of this reading you will be able to:
  • Apply the chain rule to differentiate a composite function f(g(x)), identifying the outer and inner functions
  • Extend the chain rule to chains of three or more composed functions
  • Connect the chain rule to backpropagation: explain how gradients flow backward through a neural network as a repeated application of the chain rule

Composed Functions

Most functions we care about are not simple polynomials — they are compositions: one function applied inside another. For example:

  • $h(x) = (x^2 + 1)^{10}$: raise a polynomial to a power
  • $h(x) = \sin(3x^2)$: apply sine to a polynomial
  • $h(x) = \sqrt{x^2 + 1}$: take the square root of something

In each case, $h(x) = f(g(x))$ for some outer function $f$ and inner function $g$. The rules from r2 do not handle this, so we need the chain rule.


The Chain Rule

If $h(x) = f(g(x))$, then:

$$h'(x) = f'(g(x)) \cdot g'(x)$$

In words: differentiate the outside, leave the inside alone, then multiply by the derivative of the inside.

In Leibniz notation, let $u = g(x)$ and $y = f(u)$. Then:

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}$$

This looks as if the $du$ terms cancel. They do not literally cancel, because derivatives are limits rather than fractions, but the intuition is sound and the notation makes the rule easy to remember.
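To make the rule concrete, here is a minimal Python sketch that builds $h'$ from $f'$, $g$, and $g'$ exactly as the formula prescribes, then compares it with a central finite-difference estimate. The helper names chain_rule and numerical_derivative, and the test function $h(x) = e^{\cos x}$, are illustrative choices that do not come from the text above.

```python
import math
from typing import Callable

Fn = Callable[[float], float]


def chain_rule(f_prime: Fn, g: Fn, g_prime: Fn) -> Fn:
    """Build h'(x) = f'(g(x)) * g'(x) for the composite h(x) = f(g(x))."""
    def h_prime(x: float) -> float:
        u = g(x)                        # inner value: the "inside" is left alone
        return f_prime(u) * g_prime(x)  # outer derivative at u, times inner derivative
    return h_prime


def numerical_derivative(h: Fn, x: float, eps: float = 1e-6) -> float:
    """Central finite-difference estimate of h'(x)."""
    return (h(x + eps) - h(x - eps)) / (2 * eps)


# Illustrative composite h(x) = exp(cos x): outer f(u) = exp(u), inner g(x) = cos x.
def g_prime(x: float) -> float:
    return -math.sin(x)


def h(x: float) -> float:
    return math.exp(math.cos(x))


h_prime = chain_rule(math.exp, math.cos, g_prime)  # f'(u) = exp(u) since f = exp

for x in (0.0, 0.7, 2.0):
    print(f"x={x}: chain rule {h_prime(x):+.6f}, "
          f"finite difference {numerical_derivative(h, x):+.6f}")
```

The two columns should agree to about six decimal places at every sample point.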


Worked Examples

Example 1: $h(x) = (x^2 + 1)^{10}$

Outer function: $f(u) = u^{10}$, so $f'(u) = 10u^9$
Inner function: $g(x) = x^2 + 1$, so $g'(x) = 2x$

$$h'(x) = 10(x^2+1)^9 \cdot 2x = 20x(x^2+1)^9$$

Example 2: $h(x) = \sin(3x^2)$

Outer: $f(u) = \sin u$, so $f'(u) = \cos u$
Inner: $g(x) = 3x^2$, so $g'(x) = 6x$

$$h'(x) = \cos(3x^2) \cdot 6x = 6x\cos(3x^2)$$

Example 3: $h(x) = \sqrt{x^2 + 1} = (x^2+1)^{1/2}$

Outer: $f(u) = u^{1/2}$, so $f'(u) = \frac{1}{2}u^{-1/2}$
Inner: $g(x) = x^2+1$, so $g'(x) = 2x$

$$h'(x) = \frac{1}{2}(x^2+1)^{-1/2} \cdot 2x = \frac{x}{\sqrt{x^2+1}}$$
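A quick way to check the three results above is to compare each closed-form derivative with a central finite-difference estimate, $(h(x+\varepsilon) - h(x-\varepsilon)) / 2\varepsilon$. The Python sketch below does that. The sample points, step size, and tolerance are arbitrary choices for illustration, and central_difference is a hypothetical helper name.

```python
import math


def central_difference(h, x, eps=1e-6):
    """Approximate h'(x) with a symmetric difference quotient."""
    return (h(x + eps) - h(x - eps)) / (2 * eps)


# Each entry: (label, h, closed-form h' taken from the worked examples)
examples = [
    ("(x^2+1)^10", lambda x: (x**2 + 1) ** 10, lambda x: 20 * x * (x**2 + 1) ** 9),
    ("sin(3x^2)", lambda x: math.sin(3 * x**2), lambda x: 6 * x * math.cos(3 * x**2)),
    ("sqrt(x^2+1)", lambda x: math.sqrt(x**2 + 1), lambda x: x / math.sqrt(x**2 + 1)),
]

for name, h, h_prime in examples:
    for x in (-0.8, 0.3, 1.1):
        analytic = h_prime(x)
        numeric = central_difference(h, x)
        # Relative error, guarded so that small derivatives don't blow it up
        rel_err = abs(analytic - numeric) / max(1.0, abs(analytic))
        print(f"{name:12s} x={x:+.1f}  analytic={analytic:.6f}  "
              f"numeric={numeric:.6f}  rel_err={rel_err:.1e}")
```

Every relative error should be tiny. A dropped inner-derivative factor (for example, writing $10(x^2+1)^9$ without the $2x$) would show up immediately as a large mismatch.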


Chains of Three or More

The chain rule extends to any number of composed functions. For $h(x) = f(g(k(x)))$:

$$h'(x) = f'(g(k(x))) \cdot g'(k(x)) \cdot k'(x)$$

In Leibniz notation, with $v = k(x)$, $u = g(v)$, $y = f(u)$:

$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dv} \cdot \frac{dv}{dx}$$

Each link in the chain contributes a multiplicative factor. The gradient of the output with respect to the input is the product of all the local derivatives along the path.
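The product of local derivatives translates directly into code. This Python sketch assumes an illustrative three-link chain, $k(x) = x^2$, $g(v) = \sin v$, $f(u) = e^u$, so that $h(x) = e^{\sin(x^2)}$ and $h'(x) = e^{\sin(x^2)} \cos(x^2) \cdot 2x$. The helper derivative_of_chain is hypothetical. It walks the links from the inside out, evaluates each local derivative at that link's input, and multiplies it into a running product.

```python
import math

# Each link is (function, its derivative), listed innermost first: k, then g, then f.
chain = [
    (lambda x: x**2, lambda x: 2 * x),  # k(x) = x^2
    (math.sin, math.cos),               # g(v) = sin v
    (math.exp, math.exp),               # f(u) = e^u
]


def derivative_of_chain(links, x):
    """Return (h(x), h'(x)) for h = f(g(k(x))), multiplying one local derivative per link."""
    value, grad = x, 1.0
    for fn, dfn in links:
        grad *= dfn(value)  # local derivative, evaluated at this link's input
        value = fn(value)   # this link's output becomes the next link's input
    return value, grad


x = 0.9
value, grad = derivative_of_chain(chain, x)
closed_form = math.exp(math.sin(x**2)) * math.cos(x**2) * 2 * x
print(value, grad, closed_form)
```

For scalars the order of multiplication does not matter. The sketch accumulates factors from the inside out, whereas backpropagation accumulates them from the output backward.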


The Chain Rule and Backpropagation

A feedforward neural network is literally a chain of composed functions. For a simple two-layer network:

$$\hat{y} = f_2\!\left(W_2\, f_1\!\left(W_1 x + b_1\right) + b_2\right)$$

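where $f_1$ and $f_2$ are activation functions. Naming the intermediate values $z_1 = W_1 x + b_1$, $a_1 = f_1(z_1)$, $z_2 = W_2 a_1 + b_2$, and $\hat{y} = f_2(z_2)$ makes the chain explicit. With $k(x) = W_1 x + b_1$ and $g(a) = W_2 a + b_2$, the network is the composition $h(x) = f_2(g(f_1(k(x))))$.

To train the network, we need $\partial \mathcal{L} / \partial W_1$, which measures how the loss at the output depends on the first-layer weights, the ones farthest from the output. By the chain rule: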

$$\frac{\partial \mathcal{L}}{\partial W_1} = \frac{\partial \mathcal{L}}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial z_2} \cdot \frac{\partial z_2}{\partial a_1} \cdot \frac{\partial a_1}{\partial z_1} \cdot \frac{\partial z_1}{\partial W_1}$$

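Backpropagation is exactly this: traverse the computation graph from output to input, multiplying in one local derivative at each step. Working from the output end means that partial products such as $\partial \mathcal{L} / \partial z_2$ are computed once and reused. For example, $\partial \mathcal{L} / \partial W_2 = (\partial \mathcal{L} / \partial z_2) \cdot (\partial z_2 / \partial W_2)$ shares that prefix with $\partial \mathcal{L} / \partial W_1$. For vector and matrix layers each factor is a Jacobian and the products must respect shapes, but the structure is the same. The chain rule is not merely related to backprop: backprop is the chain rule, evaluated in an order that reuses work.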
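Below is a scalar version of that computation, worked through in plain Python. It is a minimal sketch under assumptions that are not in the text above: $f_1 = \tanh$, $f_2 = \sigma$ (the sigmoid), a squared-error loss $\mathcal{L} = (\hat{y} - y)^2$, and made-up parameter values. Each line of the backward pass computes one factor of the five-factor product for $\partial \mathcal{L} / \partial W_1$. The partial product dL_dz2 is reused for the second-layer weight, and finite differences check both gradients.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def forward(w1: float, b1: float, w2: float, b2: float, x: float):
    """Scalar two-layer network y_hat = f2(w2 * f1(w1*x + b1) + b2), keeping intermediates."""
    z1 = w1 * x + b1
    a1 = math.tanh(z1)   # f1 = tanh (assumed for this sketch)
    z2 = w2 * a1 + b2
    y_hat = sigmoid(z2)  # f2 = sigmoid (assumed for this sketch)
    return z1, a1, z2, y_hat


def loss_fn(y_hat: float, y: float) -> float:
    return (y_hat - y) ** 2  # squared error (assumed)


# Hypothetical parameters and a single training example.
w1, b1, w2, b2 = 0.4, -0.1, 1.5, 0.2
x, y = 1.2, 0.0

z1, a1, z2, y_hat = forward(w1, b1, w2, b2, x)

# Backward pass: one local derivative per link, from the output toward the input.
dL_dyhat = 2 * (y_hat - y)       # dL/dy_hat
dyhat_dz2 = y_hat * (1 - y_hat)  # dy_hat/dz2 (sigmoid derivative)
dz2_da1 = w2                     # dz2/da1
da1_dz1 = 1 - a1 ** 2            # da1/dz1 (tanh derivative)
dz1_dw1 = x                      # dz1/dW1

dL_dz2 = dL_dyhat * dyhat_dz2                  # partial product, computed once
dL_dw2 = dL_dz2 * a1                           # reused for the second-layer weight
dL_dw1 = dL_dz2 * dz2_da1 * da1_dz1 * dz1_dw1  # the full five-factor product


def loss_at(w1_: float, w2_: float) -> float:
    """Loss as a function of the two weights, for finite-difference checks."""
    return loss_fn(forward(w1_, b1, w2_, b2, x)[3], y)


eps = 1e-6
num_w1 = (loss_at(w1 + eps, w2) - loss_at(w1 - eps, w2)) / (2 * eps)
num_w2 = (loss_at(w1, w2 + eps) - loss_at(w1, w2 - eps)) / (2 * eps)
print(f"dL/dW1: backprop {dL_dw1:.8f}, finite difference {num_w1:.8f}")
print(f"dL/dW2: backprop {dL_dw2:.8f}, finite difference {num_w2:.8f}")
```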


A Concrete ML Example: Differentiating Through a Neuron

Consider a single neuron with weight $w$, input $x$, bias $b$, and sigmoid activation $\sigma(z) = 1/(1 + e^{-z})$:

$$a = \sigma(wx + b)$$

The loss is the squared error $\mathcal{L} = (a - y)^2$ for target $y$. We want $\partial \mathcal{L}/\partial w$.

By the chain rule:

$$\frac{\partial \mathcal{L}}{\partial w} = \frac{\partial \mathcal{L}}{\partial a} \cdot \frac{\partial a}{\partial z} \cdot \frac{\partial z}{\partial w}$$

where $z = wx + b$.

  • $\partial \mathcal{L}/\partial a = 2(a - y)$ (power rule on the loss)
  • $\partial a/\partial z = \sigma(z)(1 - \sigma(z))$ (the sigmoid derivative, derived in r4)
  • $\partial z/\partial w = x$ (linear in $w$)

Multiplying: $\partial \mathcal{L}/\partial w = 2(a - y) \cdot \sigma(z)(1-\sigma(z)) \cdot x$

Each factor has a natural interpretation: how wrong the prediction is, how sensitive the activation is, and how strongly the input influenced the pre-activation. This three-way product structure appears throughout deep learning.
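Plugging in numbers shows the three factors at work. This plain-Python sketch uses $w = 0.5$, $x = 2$, $b = 0$, $y = 1$, the same values as the framework snippets in the next section, and checks the product against a finite difference. The function name neuron_loss is a hypothetical helper.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def neuron_loss(w: float, x: float, b: float, y: float) -> float:
    """Squared error of a single sigmoid neuron: (sigmoid(w*x + b) - y)^2."""
    return (sigmoid(w * x + b) - y) ** 2


w, x, b, y = 0.5, 2.0, 0.0, 1.0

z = w * x + b
a = sigmoid(z)

dL_da = 2 * (a - y)  # how wrong the prediction is
da_dz = a * (1 - a)  # sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)): how sensitive the activation is
dz_dw = x            # z is linear in w: how strongly the input drives the pre-activation

dL_dw = dL_da * da_dz * dz_dw

eps = 1e-6
numeric = (neuron_loss(w + eps, x, b, y) - neuron_loss(w - eps, x, b, y)) / (2 * eps)
print(f"chain rule: {dL_dw:.6f}  finite difference: {numeric:.6f}")
```

Both numbers come out at about -0.2115. The gradient is negative because the prediction $a = \sigma(1) \approx 0.731$ is below the target, so increasing $w$ would reduce the loss.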


PyTorch and TensorFlow

Autograd engines such as PyTorch's build the computation graph during the forward pass and apply the chain rule backward when .backward() is called. You can inspect individual gradients:

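```python
import torch

# A single sigmoid neuron: z = w*x + b, a = sigmoid(z), loss = (a - y)^2
w = torch.tensor(0.5, requires_grad=True)  # track gradients with respect to w
x_val = torch.tensor(2.0)
b = torch.tensor(0.0)
y = torch.tensor(1.0)

# Forward pass: autograd records each operation as it runs
z = w * x_val + b
a = torch.sigmoid(z)
loss = (a - y) ** 2

# Backward pass: apply the chain rule from loss back to w
loss.backward()
print(w.grad.item())
# PyTorch computed dL/dw = 2(a-y) · σ(z)(1-σ(z)) · x automatically
```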
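
TensorFlow records the same operations with tf.GradientTape and differentiates them with tape.gradient:

```python
import tensorflow as tf

# The same neuron in TensorFlow
w = tf.Variable(0.5)  # trainable variables are watched by the tape automatically
x_val = tf.constant(2.0)
b = tf.constant(0.0)
y = tf.constant(1.0)

# Operations executed inside the tape context are recorded for differentiation
with tf.GradientTape() as tape:
    z = w * x_val + b
    a = tf.sigmoid(z)
    loss = (a - y) ** 2

print(tape.gradient(loss, w).numpy())  # same result via the chain rule
```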

The framework never needs to know the analytic form of $\partial \mathcal{L}/\partial w$. It computes the value by walking backward through the recorded operations and applying the chain rule at each step.
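To see what walking backward through the recorded operations means mechanically, here is a toy reverse-mode sketch in plain Python. The Value class and its methods are hypothetical and were written for this illustration. It is not the API of any real library, and real frameworks are far more general (tensors, many more operations, memory management). Each operation stores its inputs and its local derivatives. backward() orders the graph so that every node is processed after all of its consumers, then pushes upstream gradient times local derivative into each input, summing contributions.

```python
import math


class Value:
    """A scalar that records how it was computed, so gradients can flow backward (toy sketch)."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = float(data)
        self.grad = 0.0
        self._parents = parents          # Values this one was computed from
        self._local_grads = local_grads  # d(self)/d(parent), one per parent

    @staticmethod
    def _wrap(other):
        return other if isinstance(other, Value) else Value(other)

    def __add__(self, other):
        other = Value._wrap(other)
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __sub__(self, other):
        other = Value._wrap(other)
        return Value(self.data - other.data, (self, other), (1.0, -1.0))

    def __mul__(self, other):
        other = Value._wrap(other)
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def __pow__(self, n):
        # n is a plain Python number; local derivative from the power rule
        return Value(self.data ** n, (self,), (n * self.data ** (n - 1),))

    def sigmoid(self):
        s = 1.0 / (1.0 + math.exp(-self.data))
        return Value(s, (self,), (s * (1.0 - s),))

    def backward(self):
        """Apply the chain rule from this node back to every node it depends on."""
        order, seen = [], set()

        def visit(node):  # post-order DFS: inputs are appended before their consumers
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0  # d(output)/d(output)
        for node in reversed(order):  # consumers before inputs
            for parent, local in zip(node._parents, node._local_grads):
                parent.grad += node.grad * local  # upstream gradient times local derivative


# The same neuron: z = w*x + b, a = sigmoid(z), L = (a - y)^2
w, x, b, y = Value(0.5), Value(2.0), Value(0.0), Value(1.0)
loss = ((w * x + b).sigmoid() - y) ** 2
loss.backward()
print(w.grad, b.grad)
```

On this neuron the sketch should reproduce the gradient from the PyTorch and TensorFlow snippets, about -0.2115 for $w$.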