Supplement · Neural Network Architectures

LSTM and GRU — Gating Solutions to Long-Range Memory

By the end of this reading you will be able to:
  • Trace the LSTM forward pass through all four gates (forget, input, output, and cell-gate) and explain the role of each gate in controlling information flow
  • Explain why the LSTM cell state is a gradient highway: trace the gradient path through the cell state and show that it is scaled element-wise by the forget gate, avoiding the repeated matrix multiplication that causes vanishing gradients in vanilla RNNs
  • Trace the GRU forward pass, identify the two gates (reset and update), and explain how the GRU achieves similar capability to the LSTM with fewer parameters
  • Compare LSTM and GRU on parameter count, training speed, and empirical performance, and state when each is typically preferred

The Core Idea: Gating

The vanilla RNN's failure on long sequences is architectural: the hidden state is rewritten at every step by the same transformation, so there is no way to preserve information across many timesteps without distortion.

The solution is gating: learned, input-dependent switches that control what information to keep, what to discard, and what to pass through. Hochreiter & Schmidhuber (1997) introduced the Long Short-Term Memory (LSTM) as the first gated RNN; the forget gate used in the modern LSTM described below was added later by Gers, Schmidhuber & Cummins (2000). Cho et al. (2014) introduced the simpler Gated Recurrent Unit (GRU).


LSTM: Long Short-Term Memory

The LSTM maintains two state vectors at each timestep:

  • $\mathbf{c}_t \in \mathbb{R}^{d_h}$: the cell state, the long-term memory, which changes slowly because it is modified only through gated element-wise updates
  • $\mathbf{h}_t \in \mathbb{R}^{d_h}$: the hidden state, the working memory, output at each step

The Four Gates

All four gates take the same input, the concatenation $[\mathbf{h}_{t-1}; \mathbf{x}_t]$, and each produces a $d_h$-dimensional vector:

$$\mathbf{f}_t = \sigma(W_f [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_f) \qquad \text{forget gate}$$
$$\mathbf{i}_t = \sigma(W_i [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_i) \qquad \text{input gate}$$
$$\tilde{\mathbf{c}}_t = \tanh(W_c [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_c) \qquad \text{cell gate (candidate)}$$
$$\mathbf{o}_t = \sigma(W_o [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_o) \qquad \text{output gate}$$

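The forget, input, and output gates use the sigmoid $\sigma$, so their entries lie in $(0, 1)$: a value near 1 means "pass through" and near 0 means "block". The candidate $\tilde{\mathbf{c}}_t$ uses $\tanh$ instead, so its entries lie in $(-1, 1)$ and can move the cell state in either direction. Each weight matrix $W_\ast$ has shape $d_h \times (d_h + d_x)$, where $d_x$ is the input dimension.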
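Because all four gates read the same concatenated input, implementations typically stack $W_f, W_i, W_c, W_o$ into one matrix and compute every pre-activation with a single multiply. The NumPy sketch below assumes the stacking order (forget, input, cell, output) and random, untrained weights; it only illustrates the shapes, the split, and the output ranges. For comparison, PyTorch's nn.LSTM stores its blocks in the order input, forget, cell, output and keeps separate input-to-hidden and hidden-to-hidden matrices, which corresponds to the split form checked at the end.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
d_h, d_x = 4, 3                                   # hidden size, input size

# Hypothetical random parameters: the four gate matrices stacked row-wise in the
# order (forget, input, cell, output); each block has shape (d_h, d_h + d_x).
W = rng.normal(scale=0.5, size=(4 * d_h, d_h + d_x))
b = np.zeros(4 * d_h)

h_prev = rng.normal(size=d_h)
x_t = rng.normal(size=d_x)
hx = np.concatenate([h_prev, x_t])                # [h_{t-1}; x_t]

# One matrix multiply yields all four pre-activations, which are then split per gate.
a_f, a_i, a_c, a_o = np.split(W @ hx + b, 4)
f_t, i_t, o_t = sigmoid(a_f), sigmoid(a_i), sigmoid(a_o)
c_tilde = np.tanh(a_c)

# The same forget gate with its weight block split into an h-part and an x-part:
# W_f [h; x] = W_f[:, :d_h] h + W_f[:, d_h:] x.
W_f = W[:d_h]
f_split = sigmoid(W_f[:, :d_h] @ h_prev + W_f[:, d_h:] @ x_t + b[:d_h])
```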

Cell State Update

$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$

  • $\mathbf{f}_t \odot \mathbf{c}_{t-1}$: keep part of the previous cell state; the forget gate decides, element by element, how much to retain
  • $\mathbf{i}_t \odot \tilde{\mathbf{c}}_t$: write new candidate information, scaled by the input gate's decision of how much to accept
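A tiny numeric example of the cell update, using three hand-picked units to show the keep, overwrite, and blend regimes. The gate values 0 and 1 are idealized for readability (a sigmoid only approaches them), and all numbers are illustrative.

```python
import numpy as np

# c_t = f_t * c_{t-1} + i_t * c~_t, with all products element-wise.
c_prev  = np.array([2.0, 2.0, 2.0])
f_t     = np.array([1.0, 0.0, 0.5])   # keep fully, erase, keep half
i_t     = np.array([0.0, 1.0, 0.5])   # write nothing, write fully, write half
c_tilde = np.array([0.5, 0.5, 0.5])   # candidate values

c_t = f_t * c_prev + i_t * c_tilde
print(c_t.tolist())  # [2.0, 0.5, 1.25]
```

The first unit preserves its memory untouched, the second replaces it with the candidate, and the third blends the two.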

Hidden State Update

$$\mathbf{h}_t = \mathbf{o}_t \odot \tanh(\mathbf{c}_t)$$

The output gate controls how much of the squashed cell state, $\tanh(\mathbf{c}_t)$, is exposed as the hidden state, which is the LSTM's actual output at each step.
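Putting the gate equations, the cell update, and the hidden update together gives one full LSTM step. The sketch below is a minimal NumPy version under the same assumptions as before (gate blocks stacked as forget, input, cell, output; random untrained weights; no batch dimension). PyTorch's nn.LSTM computes the same equations, although it stores the blocks in a different order and uses two bias vectors per block. Here `hs` plays the role of PyTorch's `out`, and `(h_T, c_T)` the role of `(h_n, c_n)` for a single layer.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM step. W: (4*d_h, d_h + d_x), rows stacked as (f, i, c~, o)."""
    hx = np.concatenate([h_prev, x_t])
    a_f, a_i, a_c, a_o = np.split(W @ hx + b, 4)
    f_t, i_t, o_t = sigmoid(a_f), sigmoid(a_i), sigmoid(a_o)
    c_tilde = np.tanh(a_c)
    c_t = f_t * c_prev + i_t * c_tilde      # cell state update
    h_t = o_t * np.tanh(c_t)                # hidden state, the step's output
    return h_t, c_t

def lstm_forward(xs, W, b, d_h):
    """Run over a (T, d_x) sequence from zero initial states."""
    h = np.zeros(d_h)
    c = np.zeros(d_h)
    hs = []
    for x_t in xs:
        h, c = lstm_step(x_t, h, c, W, b)
        hs.append(h)
    return np.stack(hs), h, c

rng = np.random.default_rng(1)
d_x, d_h, T = 3, 5, 10
W = rng.normal(scale=0.3, size=(4 * d_h, d_h + d_x))   # hypothetical weights
b = np.zeros(4 * d_h)
xs = rng.normal(size=(T, d_x))

hs, h_T, c_T = lstm_forward(xs, W, b, d_h)
print(hs.shape)  # (10, 5)
```

Every entry of `hs` lies strictly in $(-1, 1)$ because it is a sigmoid output times a $\tanh$; the cell state `c_T` carries no such bound.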


Why the Cell State Is a Gradient Highway

The cell state update is:

$$\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t$$

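Considering only the direct path through the cell state (the full gradient also has terms that flow through $\mathbf{h}_{t-1}$ and the gates computed from it), the gradient of the loss with respect to $\mathbf{c}_{t-1}$ is:

$$\frac{\partial \mathcal{L}}{\partial \mathbf{c}_{t-1}} = \frac{\partial \mathcal{L}}{\partial \mathbf{c}_t} \odot \mathbf{f}_t$$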

No weight-matrix multiplication appears on this path: the gradient is scaled element-wise by the forget gate. If the forget gate is near 1 ("remember everything"), $\partial \mathcal{L}/\partial \mathbf{c}_{t-1} \approx \partial \mathcal{L}/\partial \mathbf{c}_t$ and the gradient passes through nearly unchanged. In a vanilla RNN, by contrast, every step back multiplies the gradient by the recurrent weight matrix and the $\tanh$ derivative.
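Unrolling both recurrences over $k$ steps makes the contrast explicit. For the vanilla RNN this assumes the standard form $\mathbf{h}_t = \tanh(W_{hh}\mathbf{h}_{t-1} + W_{xh}\mathbf{x}_t + \mathbf{b})$ (notation assumed here); for the LSTM it again keeps only the direct cell-state path. Squares are element-wise.

$$\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-k}}\bigg|_{\text{direct}} = \operatorname{diag}\!\left(\mathbf{f}_t \odot \mathbf{f}_{t-1} \odot \cdots \odot \mathbf{f}_{t-k+1}\right)$$

$$\frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-k}} = \operatorname{diag}\!\left(1 - \mathbf{h}_t^{2}\right) W_{hh}\,\operatorname{diag}\!\left(1 - \mathbf{h}_{t-1}^{2}\right) W_{hh} \cdots \operatorname{diag}\!\left(1 - \mathbf{h}_{t-k+1}^{2}\right) W_{hh}$$

$$\left\lVert \frac{\partial \mathbf{h}_t}{\partial \mathbf{h}_{t-k}} \right\rVert_2 \le \lVert W_{hh} \rVert_2^{\,k}$$

Since every entry of $1 - \mathbf{h}^2$ lies in $(0, 1]$, the vanilla RNN's Jacobian shrinks at least geometrically whenever $\lVert W_{hh}\rVert_2 < 1$. The LSTM's direct-path factor is instead a product of learned, input-dependent gates, which can stay close to 1.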

This is why the cell state is called a gradient highway, analogous to the residual connection in ResNets. As long as the forget gates stay close to 1, information (and gradients) can flow across hundreds of timesteps with little degradation.
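A quick numerical comparison of the two gradient paths over 100 steps. This is a sketch under loudly hypothetical settings: a vanilla RNN whose recurrent matrix is 0.9 times a random orthogonal matrix, driven by random inputs, versus an LSTM direct cell-state path with every forget gate fixed at 0.99. It measures the spectral norm of each accumulated Jacobian.

```python
import numpy as np

rng = np.random.default_rng(0)
d_h, T = 16, 100

# Vanilla RNN h_t = tanh(W_hh h_{t-1} + W_xh x_t) with W_hh = 0.9 * (random orthogonal
# matrix), so every singular value of W_hh is 0.9. All weights here are hypothetical.
Q, _ = np.linalg.qr(rng.normal(size=(d_h, d_h)))
W_hh = 0.9 * Q
W_xh = rng.normal(scale=0.5, size=(d_h, d_h))

h = np.zeros(d_h)
J = np.eye(d_h)                                # running Jacobian dh_t / dh_0
for _ in range(T):
    h = np.tanh(W_hh @ h + W_xh @ rng.normal(size=d_h))
    J = np.diag(1.0 - h**2) @ W_hh @ J         # chain rule: newest factor on the left
rnn_scale = np.linalg.norm(J, 2)               # largest singular value after T steps

# Direct LSTM cell-state path with every forget gate held at 0.99 (assumed value):
# the Jacobian is diag(f_T * ... * f_1), whose spectral norm is 0.99**T.
forget = np.full((T, d_h), 0.99)
lstm_scale = np.prod(forget, axis=0).max()

print(f"vanilla RNN: {rnn_scale:.1e}   LSTM cell path: {lstm_scale:.3f}")
```

The LSTM figure is $0.99^{100} \approx 0.366$, while the vanilla RNN's is bounded above by $0.9^{100} \approx 2.7 \times 10^{-5}$ and is usually smaller still, because the $\tanh$ derivative is below 1.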


GRU: Gated Recurrent Unit

The GRU (Cho et al., 2014) simplifies the LSTM by merging the cell state and hidden state into a single state vector, and using only two gates:

$$\mathbf{r}_t = \sigma(W_r [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_r) \qquad \text{reset gate}$$
$$\mathbf{z}_t = \sigma(W_z [\mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_z) \qquad \text{update gate}$$
$$\tilde{\mathbf{h}}_t = \tanh(W_h [\mathbf{r}_t \odot \mathbf{h}_{t-1}; \mathbf{x}_t] + \mathbf{b}_h) \qquad \text{candidate state}$$
$$\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t$$
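A minimal NumPy sketch of one GRU step following exactly the equations above, with random, hypothetical weights. Conventions differ between sources: PyTorch's nn.GRU, for example, computes $\mathbf{h}_t = (1 - \mathbf{z}_t) \odot \mathbf{n}_t + \mathbf{z}_t \odot \mathbf{h}_{t-1}$ (so its update gate plays the opposite role) and applies the reset gate after multiplying $\mathbf{h}_{t-1}$ by its weight matrix. The sketch does not reproduce those details.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W_r, W_z, W_h, b_r, b_z, b_h):
    """One GRU step in this section's convention; each W_*: (d_h, d_h + d_x)."""
    hx = np.concatenate([h_prev, x_t])
    r_t = sigmoid(W_r @ hx + b_r)                 # reset gate
    z_t = sigmoid(W_z @ hx + b_z)                 # update gate
    # The reset gate scales h_{t-1} before it enters the candidate computation.
    h_tilde = np.tanh(W_h @ np.concatenate([r_t * h_prev, x_t]) + b_h)
    return (1.0 - z_t) * h_prev + z_t * h_tilde   # interpolate old state and candidate

rng = np.random.default_rng(2)
d_x, d_h = 3, 4
W_r, W_z, W_h = (rng.normal(scale=0.5, size=(d_h, d_h + d_x)) for _ in range(3))
b_r = b_z = b_h = np.zeros(d_h)

h = np.zeros(d_h)
for x_t in rng.normal(size=(6, d_x)):
    h = gru_step(x_t, h, W_r, W_z, W_h, b_r, b_z, b_h)
print(h.shape)  # (4,)
```

Starting from a zero state, every $\mathbf{h}_t$ stays in $(-1, 1)$: each step is a convex combination of the previous state and a $\tanh$ output.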

Gate Roles

Reset gate $\mathbf{r}_t$: Controls how much of the past state is used when computing the candidate. Near 0 → ignore the past (write new); near 1 → condition on the past (extend memory).

Update gate $\mathbf{z}_t$: Controls the interpolation between the old state $\mathbf{h}_{t-1}$ and the new candidate $\tilde{\mathbf{h}}_t$. Near 0 → keep the old state (long memory); near 1 → adopt the candidate (short memory).
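A three-unit numeric example of the update-gate interpolation $\mathbf{h}_t = (1-\mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t$. The values are illustrative, and $z = 0$ and $z = 1$ are idealized limits of the sigmoid.

```python
import numpy as np

h_prev  = np.array([1.0, 1.0, 1.0])
h_tilde = np.array([-1.0, -1.0, -1.0])
z_t     = np.array([0.0, 1.0, 0.25])   # keep old, adopt candidate, mostly keep old

h_t = (1.0 - z_t) * h_prev + z_t * h_tilde
print(h_t.tolist())  # [1.0, -1.0, 0.5]
```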

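The final update $\mathbf{h}_t = (1-\mathbf{z}_t) \odot \mathbf{h}_{t-1} + \mathbf{z}_t \odot \tilde{\mathbf{h}}_t$ is an element-wise linear interpolation. Along this direct path the gradient to $\mathbf{h}_{t-1}$ is scaled element-wise by $1 - \mathbf{z}_t$, which plays the role of the LSTM's forget gate: the same gradient-highway structure as the LSTM cell state, but without a separate cell vector.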


LSTM vs. GRU

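| | LSTM | GRU |
|---|---|---|
| State vectors | 2 ($\mathbf{h}$ and $\mathbf{c}$) | 1 ($\mathbf{h}$) |
| Gates | 4 (forget, input, output, cell/candidate) | 2 (reset, update), plus the candidate state |
| Weight blocks | 4 | 3 |
| Parameters per layer (input size $d_x$, hidden size $d_h$, one bias per block) | $4\,[d_h(d_h + d_x) + d_h]$ | $3\,[d_h(d_h + d_x) + d_h]$ |
| Example: $d_x = d_h$, biases ignored | $8d_h^2$ | $6d_h^2$ (25% fewer) |
| Typical performance | Sometimes marginally better on long sequences | Competitive on many tasks; roughly 3/4 of the LSTM's matrix-multiply cost per step, so often faster to train |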
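The parameter formulas in the table can be checked with a few lines of plain Python. The `n_bias` argument is an illustrative parameter of these helper functions: set it to 2 to mimic PyTorch, which keeps separate input-to-hidden and hidden-to-hidden bias vectors for every block.

```python
def lstm_params(d_x: int, d_h: int, n_bias: int = 1) -> int:
    """One LSTM layer: 4 blocks (f, i, c~, o), each a (d_h, d_h + d_x) weight
    matrix acting on [h_{t-1}; x_t] plus n_bias bias vectors of length d_h."""
    return 4 * (d_h * (d_h + d_x) + n_bias * d_h)

def gru_params(d_x: int, d_h: int, n_bias: int = 1) -> int:
    """One GRU layer: 3 blocks (reset, update, candidate) of the same shape."""
    return 3 * (d_h * (d_h + d_x) + n_bias * d_h)

# d_x = d_h = 64 with biases ignored: 8 * 64**2 vs 6 * 64**2.
print(lstm_params(64, 64, n_bias=0), gru_params(64, 64, n_bias=0))  # 32768 24576

# input_size=32, hidden_size=64, one layer, two bias vectors per block as in PyTorch.
print(lstm_params(32, 64, n_bias=2), gru_params(32, 64, n_bias=2))  # 25088 18816
```

For a single-layer `nn.LSTM(32, 64)`, `sum(p.numel() for p in lstm.parameters())` gives the same 25,088. Whatever the sizes, the GRU-to-LSTM ratio is exactly 3/4.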

When to prefer LSTM: Long sequences where finer-grained memory control (a separate cell state plus an output gate) may help, such as language modeling, music generation, and protein sequences.

When to prefer GRU: When training speed and parameter efficiency matter more, or on smaller datasets where overfitting is a risk. For many tasks the two perform comparably in practice.

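Since around 2017, transformers have largely supplanted both in NLP. LSTMs and GRUs nevertheless remain common choices for:

  • Time series forecasting, including variants adapted to irregular sampling intervals
  • Low-latency online sequence processing (streaming), where a fixed-size state is updated one step at a time

Gated recurrence has also resurfaced in structured state-space models such as Mamba, whose selective SSM gates its recurrent state with input-dependent parameters; this can be viewed as loosely analogous to a GRU's update gate.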

PyTorch and TensorFlow

PyTorch: nn.LSTM and nn.GRU:

import torch
import torch.nn as nn

# LSTM: returns (output, (h_n, c_n))
lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2,
               batch_first=True, dropout=0.2)  # dropout between stacked layers only

x              = torch.randn(8, 20, 32)       # (B, T, input_size)
out, (h_n, c_n) = lstm(x)
# out:  (8, 20, 64)  hidden state of the top layer at every timestep
# h_n:  (2, 8, 64)   final hidden state, one per layer (batch_first does not apply to h_n)
# c_n:  (2, 8, 64)   final cell state, one per layer

# Passing initial hidden/cell state (e.g. for stateful inference)
h0 = torch.zeros(2, 8, 64)   # (num_layers, batch, hidden_size)
c0 = torch.zeros(2, 8, 64)
out, (h_n, c_n) = lstm(x, (h0, c0))

# GRU: same interface but only returns (output, h_n) — no cell state
gru       = nn.GRU(32, 64, num_layers=2, batch_first=True, dropout=0.2)
out, h_n  = gru(x)            # h_n: (2, 8, 64)

# Bidirectional LSTM for sequence labeling
bi_lstm = nn.LSTM(32, 64, batch_first=True, bidirectional=True)
out, (h_n, c_n) = bi_lstm(x)  # out: (8, 20, 128); h_n: (2, 8, 64), one per direction

# Sequence classifier: build a fixed-length summary from h_n, not out[:, -1, :].
# At the last timestep the backward direction has seen only the final token;
# h_n[-1] is its state after reading the whole sequence right to left.
last_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (8, 128): [forward; backward]
head        = nn.Linear(128, 10)
logits      = head(last_hidden)                      # (8, 10)
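For a bidirectional layer, the last timestep of `out` is not a whole-sequence summary in both directions: its backward half is the reverse pass after reading only the last token. PyTorch's nn.LSTM documentation makes the same point, noting that `h_n` holds the final forward and reverse states. The toy below shows this without PyTorch by swapping the LSTM for a running sum over one-hot tokens (a deliberately simplified stand-in, not an LSTM), so each state literally records which tokens it has seen.

```python
import numpy as np

def run(xs):
    """Toy stand-in recurrence h_t = h_{t-1} + x_t (a running sum); returns every h_t."""
    return np.cumsum(xs, axis=0)

T = 5
xs = np.eye(T)                               # row t is a one-hot "token" t

fwd = run(xs)                                # fwd[t] has seen tokens 0..t
bwd = run(xs[::-1])[::-1]                    # bwd[t] has seen tokens t..T-1
out = np.concatenate([fwd, bwd], axis=1)     # per-step output, like a bidirectional layer

# What h_n would hold: each direction's state after reading the full sequence.
h_fwd_final = fwd[-1]                        # left to right, ends at t = T-1
h_bwd_final = bwd[0]                         # right to left, ends at t = 0

print(out[-1, T:].tolist())  # [0.0, 0.0, 0.0, 0.0, 1.0]: backward half saw one token
print(h_bwd_final.tolist())  # [1.0, 1.0, 1.0, 1.0, 1.0]: backward final saw them all
```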

TensorFlow / Keras:

import tensorflow as tf

# LSTM
lstm = tf.keras.layers.LSTM(units=64, return_sequences=True,
                            return_state=True, dropout=0.2)  # dropout on inputs, training only
x             = tf.random.normal((8, 20, 32))
out, h_n, c_n = lstm(x)      # out: (8,20,64)  h_n,c_n: (8,64)

# GRU — no cell state
gru      = tf.keras.layers.GRU(64, return_sequences=True, return_state=True)
out, h_n = gru(x)            # out: (8,20,64)  h_n: (8,64)

# Bidirectional LSTM
bi_lstm = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(64, return_sequences=True)
)
out = bi_lstm(x)              # (8, 20, 128): forward and backward outputs concatenated