LSTM and GRU — Gating Solutions to Long-Range Memory
- Trace the LSTM forward pass through all four gates (forget, input, output, and cell-gate) and explain the role of each gate in controlling information flow
- Explain why the LSTM cell state is a gradient highway: trace the gradient path through the cell state and show that the forget gate avoids the repeated matrix multiplication that causes vanishing gradients in vanilla RNNs
- Trace the GRU forward pass, identify the two gates (reset and update), and explain how the GRU achieves similar capability to the LSTM with fewer parameters
- Compare LSTM and GRU on parameter count, training speed, and empirical performance, and state when each is typically preferred
The Core Idea: Gating
The vanilla RNN's failure on long sequences is architectural: the hidden state is rewritten at every step by the same transformation, so there is no mechanism for preserving information across many timesteps without distortion.
The solution is gating — learned, input-dependent switches that control what information to keep, what to discard, and what to pass through. Hochreiter & Schmidhuber (1997) introduced the Long Short-Term Memory (LSTM) as the first gated RNN. Cho et al. (2014) introduced the simpler Gated Recurrent Unit (GRU).
LSTM: Long Short-Term Memory
The LSTM maintains two state vectors at each timestep:
- $c_t \in \mathbb{R}^{d_h}$: the cell state, the long-term memory, updated slowly
- $h_t \in \mathbb{R}^{d_h}$: the hidden state, the working memory, output at each step
The Four Gates
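All four gates take the same input, the concatenation $[h_{t-1}, x_t]$, and produce a $d_h$-dimensional vector:
$$f_t = \sigma(W_f [h_{t-1}, x_t] + b_f) \quad \text{(forget gate)}$$
$$i_t = \sigma(W_i [h_{t-1}, x_t] + b_i) \quad \text{(input gate)}$$
$$o_t = \sigma(W_o [h_{t-1}, x_t] + b_o) \quad \text{(output gate)}$$
$$\tilde{c}_t = \tanh(W_c [h_{t-1}, x_t] + b_c) \quad \text{(cell gate, the candidate)}$$
The forget, input, and output gates take values in $(0, 1)$ via the sigmoid $\sigma$. A value near 1 means "pass through"; near 0 means "block". The cell gate instead uses $\tanh$, so the candidate $\tilde{c}_t$ lies in $(-1, 1)$ and can push the cell state up or down.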
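In code, the four gates are usually evaluated with a single matrix multiply on the concatenated input followed by a split. Below is a minimal NumPy sketch of that idea. The stacked matrix `W`, the helper name `lstm_gates`, and the row order (f, i, o, c̃) are assumptions made for illustration; libraries differ (PyTorch, for instance, stores its gate blocks in the order i, f, g, o).

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_gates(x_t, h_prev, W, b):
    """Compute the four LSTM gate vectors from [h_{t-1}, x_t].

    W: (4*d_h, d_h + d_x) stacked weights, rows ordered f, i, o, c~ (assumed order).
    b: (4*d_h,) stacked biases.
    """
    z = W @ np.concatenate([h_prev, x_t]) + b   # all four pre-activations at once
    zf, zi, zo, zc = np.split(z, 4)             # one d_h-sized slice per gate
    f = sigmoid(zf)         # forget gate, in (0, 1)
    i = sigmoid(zi)         # input gate, in (0, 1)
    o = sigmoid(zo)         # output gate, in (0, 1)
    c_tilde = np.tanh(zc)   # cell gate (candidate), in (-1, 1)
    return f, i, o, c_tilde

rng = np.random.default_rng(0)
d_x, d_h = 3, 4
W = rng.normal(size=(4 * d_h, d_h + d_x))
b = np.zeros(4 * d_h)
f, i, o, c_tilde = lstm_gates(rng.normal(size=d_x), np.zeros(d_h), W, b)
```

Stacking the weights is purely an efficiency choice: it computes exactly the same four affine maps as separate $W_f, W_i, W_o, W_c$ matrices.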
Cell State Update
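$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
- $f_t \odot c_{t-1}$: forget, i.e. decide how much of the previous cell state to retain (element-wise multiplication by the forget gate)
- $i_t \odot \tilde{c}_t$: write new candidate information, scaled element-wise by how much the input gate accepts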
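The element-wise update is easiest to see at the extremes. This NumPy sketch (the helper `update_cell` and all numbers are hypothetical) shows that $f_t = 1, i_t = 0$ carries the cell state over unchanged, $f_t = 0, i_t = 1$ overwrites it with the candidate, and intermediate values blend old memory and new candidate independently in each dimension:

```python
import numpy as np

def update_cell(c_prev, f, i, c_tilde):
    # c_t = f * c_prev + i * c_tilde, all products element-wise
    return f * c_prev + i * c_tilde

c_prev = np.array([0.5, -1.0, 2.0])
c_tilde = np.array([0.9, 0.1, -0.3])

# Forget gate open, input gate closed: the cell state is carried over unchanged.
keep = update_cell(c_prev, f=np.ones(3), i=np.zeros(3), c_tilde=c_tilde)

# Forget gate closed, input gate open: the cell state is replaced by the candidate.
overwrite = update_cell(c_prev, f=np.zeros(3), i=np.ones(3), c_tilde=c_tilde)

# Mixed gates: each dimension gets its own blend of old memory and new candidate.
mixed = update_cell(c_prev,
                    f=np.array([1.0, 0.5, 0.0]),
                    i=np.array([0.0, 0.5, 1.0]),
                    c_tilde=c_tilde)   # [0.5, -0.45, -0.3]
```

Because the forget and input gates are separate, the LSTM can also keep old memory and add new information at the same time ($f_t \approx 1$ and $i_t \approx 1$).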
Hidden State Update
$$h_t = o_t \odot \tanh(c_t)$$
The output gate $o_t$ controls what portion of the (tanh-squashed) cell state is exposed as the hidden state $h_t$, the actual output of the LSTM.
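Putting the gates, the cell update, and the hidden update together gives one LSTM timestep. The following is a minimal NumPy sketch under stated assumptions: a single stacked weight matrix with rows ordered (f, i, o, c̃), one bias vector per gate, a zero initial state, and hypothetical helper names `lstm_step` and `lstm_forward`. It is not the exact parameter layout of any particular library.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM timestep. W: (4*d_h, d_h + d_x), rows ordered f, i, o, c~ (assumed)."""
    z = W @ np.concatenate([h_prev, x_t]) + b
    zf, zi, zo, zc = np.split(z, 4)
    f, i, o = sigmoid(zf), sigmoid(zi), sigmoid(zo)
    c_tilde = np.tanh(zc)
    c_t = f * c_prev + i * c_tilde    # cell state update
    h_t = o * np.tanh(c_t)            # output gate exposes part of the cell
    return h_t, c_t

def lstm_forward(xs, W, b, d_h):
    """Unroll over a (T, d_x) sequence from a zero state; return all h_t and the final c."""
    h, c = np.zeros(d_h), np.zeros(d_h)
    hs = []
    for x_t in xs:
        h, c = lstm_step(x_t, h, c, W, b)
        hs.append(h)
    return np.stack(hs), c

rng = np.random.default_rng(0)
T, d_x, d_h = 20, 3, 4
W = 0.5 * rng.normal(size=(4 * d_h, d_h + d_x))
b = np.zeros(4 * d_h)
hs, c_T = lstm_forward(rng.normal(size=(T, d_x)), W, b, d_h)
```

Since $o_t \in (0, 1)$ and $\tanh(c_t) \in (-1, 1)$, every entry of $h_t$ stays strictly inside $(-1, 1)$, whereas $c_t$ is not squashed and can grow by at most 1 per step. PyTorch's `nn.LSTM` performs the same per-step computation, although it orders its gate blocks differently and keeps two bias vectors per gate.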
Why the Cell State Is a Gradient Highway
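The cell state update is:
$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$
Along this direct cell-state path, the gradient of the loss with respect to $c_{t-1}$ is:
$$\frac{\partial L}{\partial c_{t-1}} = \frac{\partial L}{\partial c_t} \odot f_t$$
(The full gradient also includes terms that flow through $h_{t-1}$ and the gates; the direct term above is always present alongside them.)
No matrix multiplication: the gradient flows through the cell state by element-wise multiplication with the forget gate. If the forget gate is near 1 ("remember everything"), $\partial L / \partial c_{t-1} \approx \partial L / \partial c_t$ and the gradient passes through essentially unchanged. Unrolled over $k$ steps, the direct-path contribution is
$$\frac{\partial L}{\partial c_{t-k}} = \frac{\partial L}{\partial c_t} \odot \prod_{j=t-k+1}^{t} f_j,$$
a per-dimension product of learned, input-dependent gate values, instead of the vanilla RNN's repeated multiplication by the same recurrent weight matrix $W_{hh}^\top$ (interleaved with $\operatorname{diag}(\tanh')$ factors), which tends to shrink or blow up gradients.
This is the cell state as a gradient highway, analogous to the residual connection in ResNets. As long as the forget gate stays close to 1, information (and gradients) can flow across hundreds of timesteps with little degradation.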
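A small numeric sketch of this argument, under simplifying assumptions: only the direct path is tracked, every LSTM forget gate is fixed at 0.99, and the vanilla RNN's per-step backward factor $W_{hh}^\top \operatorname{diag}(\tanh')$ is replaced by a hypothetical fixed matrix whose singular values all equal 0.9. Over 100 steps the LSTM gradient is scaled by $0.99^{100}$, while the RNN gradient is multiplied by the same matrix 100 times:

```python
import numpy as np

T, d_h = 100, 4
upstream = np.ones(d_h)   # hypothetical gradient arriving at the last step

# LSTM direct path: dL/dc_{t-1} = dL/dc_t * f_t (element-wise).
# Assumption: every forget gate equals 0.99 at every step.
grad_lstm = upstream.copy()
for _ in range(T):
    f_t = np.full(d_h, 0.99)
    grad_lstm = grad_lstm * f_t

# Vanilla RNN: one matrix multiply per step. A fixed matrix with every
# singular value equal to 0.9 stands in for W_hh^T diag(tanh').
rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(d_h, d_h)))   # random orthogonal matrix
J = 0.9 * Q                                        # scales every vector norm by 0.9
grad_rnn = upstream.copy()
for _ in range(T):
    grad_rnn = J @ grad_rnn

ratio_lstm = np.linalg.norm(grad_lstm) / np.linalg.norm(upstream)  # 0.99**100, about 0.37
ratio_rnn = np.linalg.norm(grad_rnn) / np.linalg.norm(upstream)    # 0.9**100, about 2.7e-5
```

With singular values above 1 the same RNN loop explodes instead. The forget gate avoids both regimes on this path because it is learned and input-dependent, so it can sit near 1 exactly where long memory is needed.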
GRU: Gated Recurrent Unit
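The GRU (Cho et al., 2014) simplifies the LSTM by merging the cell state and hidden state into a single state vector $h_t$ and using only two gates, reset $r_t$ and update $z_t$:
$$r_t = \sigma(W_r [h_{t-1}, x_t] + b_r)$$
$$z_t = \sigma(W_z [h_{t-1}, x_t] + b_z)$$
$$\tilde{h}_t = \tanh(W_h [r_t \odot h_{t-1}, x_t] + b_h)$$
$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$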
Gate Roles
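Reset gate $r_t$: controls how much of the past state is used when computing the candidate $\tilde{h}_t$. Near 0 → ignore the past (write new); near 1 → condition on the past (extend memory).
Update gate $z_t$: controls the interpolation between the old state $h_{t-1}$ and the new candidate $\tilde{h}_t$. Near 0 → keep the old state (long memory); near 1 → adopt the candidate (short memory).
The final update $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$ is a linear interpolation. Along its direct path, $\partial L / \partial h_{t-1} = \partial L / \partial h_t \odot (1 - z_t)$: an element-wise gate with no matrix multiplication, the same gradient highway structure as the LSTM cell state with $1 - z_t$ playing the role of $f_t$, but without the separate cell vector. Unlike the LSTM's independent forget and input gates, the GRU ties the two amounts together ($1 - z_t$ kept, $z_t$ written). Conventions differ between sources: Cho et al. (2014) and PyTorch's `nn.GRU` swap the roles of $z_t$ and $1 - z_t$.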
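A minimal NumPy sketch of one GRU step, following the convention in which $z_t$ near 0 keeps $h_{t-1}$. The names `gru_step`, `Wr`, `Wz`, `Wh` are hypothetical, and library implementations differ in details; PyTorch's `nn.GRU`, for example, swaps the roles of $z_t$ and $1 - z_t$ and applies the reset gate after the recurrent matrix multiply. Extreme gate biases are used to make the two gate roles visible:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, Wr, Wz, Wh, br, bz, bh):
    """One GRU step; convention: z near 0 keeps h_prev, z near 1 adopts the candidate."""
    hx = np.concatenate([h_prev, x_t])
    r = sigmoid(Wr @ hx + br)                                        # reset gate
    z = sigmoid(Wz @ hx + bz)                                        # update gate
    h_tilde = np.tanh(Wh @ np.concatenate([r * h_prev, x_t]) + bh)   # candidate
    return (1 - z) * h_prev + z * h_tilde                            # interpolation

rng = np.random.default_rng(0)
d_x, d_h = 3, 4
Wr, Wz, Wh = (rng.normal(size=(d_h, d_h + d_x)) for _ in range(3))
x_t, h_prev = rng.normal(size=d_x), rng.normal(size=d_h)
zeros = np.zeros(d_h)

h_t = gru_step(x_t, h_prev, Wr, Wz, Wh, zeros, zeros, zeros)

# A very negative update-gate bias drives z toward 0: the old state is kept.
h_keep = gru_step(x_t, h_prev, Wr, Wz, Wh, zeros, np.full(d_h, -50.0), zeros)

# A very negative reset-gate bias drives r toward 0, so the candidate ignores h_prev;
# with z pushed toward 1, the new state is (almost exactly) that fresh candidate.
h_fresh = gru_step(x_t, h_prev, Wr, Wz, Wh, np.full(d_h, -50.0), np.full(d_h, 50.0), zeros)
expected_fresh = np.tanh(Wh @ np.concatenate([zeros, x_t]))
```

The GRU needs three weight matrices (two gates plus the candidate) where the LSTM needs four, which is the source of its parameter savings.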
LSTM vs. GRU
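| | LSTM | GRU |
|---|---|---|
| State vectors | 2 ($h$ and $c$) | 1 ($h$) |
| Gates | 3 gates + cell candidate (4 weight blocks) | 2 gates + candidate (3 weight blocks) |
| Parameters (input $d_x$, hidden $d_h$) | $4\,d_h(d_h + d_x + 1)$ | $3\,d_h(d_h + d_x + 1)$ |
| Parameter ratio | 1 | 3/4 of the LSTM (25% fewer) |
| Typical performance | Often marginally better on long sequences | Competitive on most tasks; faster to train |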
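Parameter counts follow directly from the number of weight blocks: each block has a $d_h \times (d_h + d_x)$ weight matrix plus a bias, and the LSTM has four blocks to the GRU's three. A short sketch of the arithmetic (helper names hypothetical); passing `n_bias=2` mimics PyTorch's layout, which keeps separate input-hidden and hidden-hidden bias vectors:

```python
def lstm_params(d_x, d_h, n_bias=1):
    # 4 blocks (f, i, o, c~): W is d_h x (d_h + d_x), plus n_bias bias vectors of size d_h
    return 4 * (d_h * (d_h + d_x) + n_bias * d_h)

def gru_params(d_x, d_h, n_bias=1):
    # 3 blocks (r, z, candidate) with the same per-block shape
    return 3 * (d_h * (d_h + d_x) + n_bias * d_h)

lstm_n, gru_n = lstm_params(32, 64), gru_params(32, 64)              # 24832, 18624
lstm_torch = lstm_params(32, 64, n_bias=2)                           # 25088
gru_torch = gru_params(32, 64, n_bias=2)                             # 18816
ratio = gru_n / lstm_n                                               # 0.75
```

For a single-layer `nn.LSTM(32, 64)`, summing `p.numel()` over `parameters()` should give the `n_bias=2` figure of 25,088. The 3/4 ratio holds regardless of the bias convention, since both models use the same per-block shape.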
When to prefer LSTM: Long sequences where fine-grained memory control matters (language modeling, music generation, protein sequences).
When to prefer GRU: when training speed and parameter efficiency matter more, or on smaller datasets where overfitting is a risk. In practice, the two perform comparably on most tasks.
Since about 2017, transformers have largely supplanted both for NLP. They remain common choices, however, for:
- Time series forecasting with irregular intervals
- Low-latency online sequence processing (streaming), where a fixed-size recurrent state is updated one step at a time
Their gating idea also lives on in structured state-space models such as Mamba, whose selective (input-dependent) state update can be viewed as a close relative of GRU-style gated interpolation.
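The streaming use case rests on the fact that a recurrent model summarizes its entire history in a fixed-size state. A sequence can therefore be processed chunk by chunk, carrying the state forward, with the same result as processing it whole; this is the same idea as passing an initial `(h0, c0)` to `nn.LSTM`. A minimal NumPy sketch, assuming a bias-free GRU step and random, purely illustrative weights (the helpers `gru_step` and `run` are hypothetical):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, Wr, Wz, Wh):
    # Bias-free GRU step (simplifying assumption); z near 0 keeps h_prev.
    hx = np.concatenate([h_prev, x_t])
    r, z = sigmoid(Wr @ hx), sigmoid(Wz @ hx)
    h_tilde = np.tanh(Wh @ np.concatenate([r * h_prev, x_t]))
    return (1 - z) * h_prev + z * h_tilde

def run(xs, h, params):
    """Process a chunk of inputs; return the per-step outputs and the carried state."""
    outs = []
    for x_t in xs:
        h = gru_step(x_t, h, *params)
        outs.append(h)
    return np.stack(outs), h

rng = np.random.default_rng(0)
d_x, d_h = 3, 4
params = tuple(rng.normal(size=(d_h, d_h + d_x)) for _ in range(3))
xs = rng.normal(size=(12, d_x))

# Offline: the whole sequence in one pass.
full, _ = run(xs, np.zeros(d_h), params)

# Streaming: chunks arrive one at a time; only the fixed-size state h is carried over.
h = np.zeros(d_h)
streamed = []
for chunk in np.split(xs, 4):   # four chunks of 3 timesteps each
    out, h = run(chunk, h, params)
    streamed.append(out)
streamed = np.concatenate(streamed)
```

Memory per step stays constant no matter how long the stream runs, which is what makes recurrent models attractive for low-latency online processing.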
PyTorch and TensorFlow
PyTorch — nn.LSTM and nn.GRU:
```python
import torch
import torch.nn as nn

# LSTM: returns (output, (h_n, c_n))
lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2,
               batch_first=True, dropout=0.2)  # dropout applies between stacked layers
x = torch.randn(8, 20, 32)        # (B, T, input_size)
out, (h_n, c_n) = lstm(x)
# out: (8, 20, 64) - top-layer hidden state at every timestep
# h_n: (2, 8, 64) - final hidden state, one per layer
# c_n: (2, 8, 64) - final cell state, one per layer

# Passing an initial hidden/cell state (e.g. for stateful inference)
h0 = torch.zeros(2, 8, 64)        # (num_layers, batch, hidden_size), even with batch_first
c0 = torch.zeros(2, 8, 64)
out, (h_n, c_n) = lstm(x, (h0, c0))

# GRU: same interface, but returns (output, h_n) with no cell state
gru = nn.GRU(32, 64, num_layers=2, batch_first=True, dropout=0.2)
out, h_n = gru(x)                 # h_n: (2, 8, 64)

# Bidirectional LSTM for sequence labeling
bi_lstm = nn.LSTM(32, 64, batch_first=True, bidirectional=True)
out, (h_n, c_n) = bi_lstm(x)      # out: (8, 20, 128); h_n: (2, 8, 64)

# Sequence classifier: concatenate the final forward and backward hidden states.
# (out[:, -1, :] would give the backward direction after seeing only the last token.)
last_hidden = torch.cat([h_n[-2], h_n[-1]], dim=-1)  # (8, 128)
classifier = nn.Linear(128, 10)
logits = classifier(last_hidden)  # (8, 10)
```
TensorFlow / Keras:
```python
import tensorflow as tf

# LSTM: with return_state=True it returns [output, h_n, c_n]
lstm = tf.keras.layers.LSTM(units=64, return_sequences=True,
                            return_state=True, dropout=0.2)
x = tf.random.normal((8, 20, 32))  # (B, T, input_size)
out, h_n, c_n = lstm(x)            # out: (8, 20, 64); h_n, c_n: (8, 64)

# GRU: no cell state, returns [output, h_n]
gru = tf.keras.layers.GRU(64, return_sequences=True, return_state=True)
out, h_n = gru(x)                  # out: (8, 20, 64); h_n: (8, 64)

# Bidirectional LSTM: forward and backward outputs are concatenated by default
bi_lstm = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(64, return_sequences=True)
)
out = bi_lstm(x)                   # (8, 20, 128)
```