Shrinkage & Threshold Functions
- Explain the shrinkage principle: setting small-magnitude inputs to zero while shifting larger inputs toward zero by a threshold
- Distinguish Hardshrink (hard threshold, discontinuous) from Softshrink (soft threshold, continuous) in terms of their effect on near-zero inputs
- Apply Softplus as a smooth differentiable approximation to ReLU and explain when continuity at zero matters for gradient-based optimisation
- Connect Softshrink to the proximal operator of the L1 norm and explain its role in sparse signal recovery
The Shrinkage Principle
Shrinkage functions selectively suppress activations based on magnitude. They are connected to sparsity-inducing regularization: Hardshrink corresponds to hard thresholding, Softshrink is the proximal operator of the L1 norm (appearing in LASSO), and Tanhshrink is a smooth alternative. These functions are most useful in sparse autoencoders, compressed sensing models, and signal denoising networks.
Hardshrink
Hardshrink zeroes out all values within $[-\lambda, \lambda]$ and passes values outside this band through unchanged. The output range is $(-\infty, -\lambda) \cup \{0\} \cup (\lambda, \infty)$, so no output has a magnitude in $(0, \lambda]$.
Gradient: $1$ for $|x| > \lambda$, $0$ for $|x| < \lambda$. Non-differentiable at $|x| = \lambda$, where the function jumps by $\lambda$. This hard discontinuity can cause gradient issues, but the sparsity it induces is exact: activations either survive completely or are fully zeroed.
Use case: Sparse coding, denoising autoencoders where you want exactly zero activations for small features.
PyTorch:
import torch
import torch.nn as nn
import torch.nn.functional as F
x = torch.tensor([-2., -0.3, 0., 0.3, 2.])
print(nn.Hardshrink(lambd=0.5)(x)) # tensor([-2., 0., 0., 0., 2.])
print(F.hardshrink(x, lambd=0.5)) # identical
TensorFlow:
import tensorflow as tf
# No Hardshrink in core tf.nn; implement with tf.where
x = tf.constant([-2., -0.3, 0., 0.3, 2.])
lambd = 0.5
hardshrink = lambda x, l: tf.where(tf.abs(x) > l, x, tf.zeros_like(x))
print(hardshrink(x, lambd)) # [-2. 0. 0. 0. 2.]
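To make the discontinuity concrete, the following dependency-free sketch re-implements the Hardshrink rule on plain Python floats and measures the jump at the band edge and the slope on either side of it. The helper names `hardshrink` and `slope` and the probe points are illustrative choices, not part of any library.

```python
def hardshrink(x: float, lambd: float = 0.5) -> float:
    """Illustrative scalar Hardshrink: keep x if |x| > lambd, else return 0."""
    return x if abs(x) > lambd else 0.0


lambd = 0.5
eps = 1e-6

# The band edge itself is zeroed; a value just outside passes through unchanged.
at_edge = hardshrink(lambd, lambd)
just_outside = hardshrink(lambd + eps, lambd)
jump = just_outside - at_edge  # approximately lambd: the size of the discontinuity


def slope(x: float, h: float = 1e-3) -> float:
    """Central-difference estimate of the derivative at x (hypothetical helper)."""
    return (hardshrink(x + h, lambd) - hardshrink(x - h, lambd)) / (2 * h)


print(at_edge, just_outside)   # 0.0 and a value just above 0.5
print(slope(0.1), slope(2.0))  # zero inside the band, close to 1 outside
```

An arbitrarily small change in the input near $|x| = \lambda$ changes the output by about $\lambda$, which is the behaviour a gradient-based optimiser cannot see through the local slope alone.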
Softshrink — Soft Thresholding
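Softshrink is the proximal operator of the L1 norm: minimizing $\tfrac{1}{2}(z - x)^2 + \lambda |z|$ with respect to $z$ yields $z^* = \operatorname{sign}(x)\max(|x| - \lambda, 0)$. This connects directly to LASSO regression and compressed sensing.
Gradient: $1$ for $|x| > \lambda$, $0$ for $|x| < \lambda$. Non-differentiable at $|x| = \lambda$, although the function itself is continuous there.
Difference from Hardshrink: Hardshrink passes values as-is when their magnitude exceeds $\lambda$. Softshrink shifts them toward zero by $\lambda$. Softshrink is therefore continuous but biased toward zero: every surviving value is smaller in magnitude by exactly $\lambda$.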
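The proximal-operator claim can be checked with a short case analysis. The derivation below is a sketch that assumes the standard scalar definition of the proximal operator, with objective $\tfrac{1}{2}(z - x)^2 + \lambda |z|$ and $\lambda > 0$; the notation $\operatorname{prox}_{\lambda|\cdot|}$ is introduced here for convenience.

```latex
\begin{aligned}
\operatorname{prox}_{\lambda|\cdot|}(x)
  &= \arg\min_{z}\; \tfrac{1}{2}(z - x)^2 + \lambda |z| \\[4pt]
z > 0 &:\quad (z - x) + \lambda = 0 \;\Rightarrow\; z = x - \lambda
  \quad (\text{consistent only if } x > \lambda) \\
z < 0 &:\quad (z - x) - \lambda = 0 \;\Rightarrow\; z = x + \lambda
  \quad (\text{consistent only if } x < -\lambda) \\
z = 0 &:\quad 0 \in -x + \lambda\,[-1, 1] \;\Leftrightarrow\; |x| \le \lambda \\[4pt]
\Rightarrow\quad \operatorname{prox}_{\lambda|\cdot|}(x)
  &= \operatorname{sign}(x)\,\max(|x| - \lambda,\, 0)
\end{aligned}
```

The three cases are mutually exclusive and cover every $x$, so the minimiser is unique and coincides with the Softshrink formula.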
PyTorch:
x = torch.tensor([-2., -0.3, 0., 0.3, 2.])
print(nn.Softshrink(lambd=0.5)(x)) # tensor([-1.5, 0.0, 0.0, 0.0, 1.5])
# Values that survive are shifted toward zero by lambd
TensorFlow:
# No Softshrink in core tf.nn; implement as proximal operator of L1
x = tf.constant([-2., -0.3, 0., 0.3, 2.])
lambd = 0.5
softshrink = lambda x, l: tf.math.sign(x) * tf.nn.relu(tf.abs(x) - l)
print(softshrink(x, lambd)) # [-1.5 0. 0. 0. 1.5]
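The link between Softshrink and the L1 proximal operator can also be checked numerically. The sketch below, in plain Python with no framework dependency, compares the closed-form soft threshold against a brute-force grid minimisation of $\tfrac{1}{2}(z - x)^2 + \lambda |z|$. The function names, the grid range $[-3, 3]$ and the step size are assumptions chosen for illustration.

```python
def softshrink(x: float, lambd: float) -> float:
    """Closed-form soft threshold, equivalent to sign(x) * max(|x| - lambd, 0)."""
    if x > lambd:
        return x - lambd
    if x < -lambd:
        return x + lambd
    return 0.0


def prox_l1_brute_force(x: float, lambd: float, step: float = 1e-3) -> float:
    """Minimise 0.5*(z - x)**2 + lambd*|z| over a uniform grid of z in [-3, 3]."""
    n = int(round(6.0 / step))
    grid = [-3.0 + i * step for i in range(n + 1)]
    return min(grid, key=lambda z: 0.5 * (z - x) ** 2 + lambd * abs(z))


lambd = 0.5
inputs = [-2.0, -0.3, 0.0, 0.3, 2.0]
closed_form = [softshrink(x, lambd) for x in inputs]
brute_force = [prox_l1_brute_force(x, lambd) for x in inputs]

# The two should agree up to the grid resolution.
max_gap = max(abs(a - b) for a, b in zip(closed_form, brute_force))
print(closed_form)  # [-1.5, 0.0, 0.0, 0.0, 1.5]
print(max_gap)
```

The grid search knows nothing about thresholds, yet it lands on the same values, including exact zeros for inputs inside the band.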
Tanhshrink
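Tanhshrink computes $f(x) = x - \tanh(x)$, the residual between the identity and Tanh. Since $\tanh(x) \approx x$ near the origin, $f(x) \approx 0$ for small $|x|$ (to leading order, $f(x) \approx x^3/3$), so the function shrinks small values without a hard threshold. For large $|x|$, $\tanh(x) \to \operatorname{sign}(x)$, so $f(x) \approx x - \operatorname{sign}(x)$.
Gradient: $f'(x) = 1 - \operatorname{sech}^2(x) = \tanh^2(x)$
Smooth everywhere; gradient is always in $[0, 1)$ and is zero only at $x = 0$. No dead zones, but relative shrinkage is strongest near the origin.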
PyTorch:
x = torch.tensor([-2., -1., 0., 1., 2.])
print(nn.Tanhshrink()(x)) # tensor([-1.0360, -0.2384, 0.0000, 0.2384, 1.0360])
# = x - tanh(x); smooth; no hard threshold
TensorFlow:
x = tf.constant([-2., -1., 0., 1., 2.])
tanhshrink = lambda x: x - tf.math.tanh(x)
print(tanhshrink(x)) # [-1.0360 -0.2384 0. 0.2384 1.0360]
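The gradient identity $1 - \operatorname{sech}^2(x) = \tanh^2(x)$ is easy to confirm with finite differences. The following is a minimal sketch using only the standard library; the helper names and the step size `h` are illustrative assumptions.

```python
import math


def tanhshrink(x: float) -> float:
    """Residual between the identity and tanh."""
    return x - math.tanh(x)


def tanhshrink_grad(x: float) -> float:
    """Analytic derivative: 1 - sech^2(x) = tanh^2(x)."""
    return math.tanh(x) ** 2


h = 1e-5
points = [-2.0, -1.0, 0.0, 1.0, 2.0]

# Central differences approximate the derivative without using the formula.
numeric = [(tanhshrink(x + h) - tanhshrink(x - h)) / (2 * h) for x in points]
analytic = [tanhshrink_grad(x) for x in points]
max_err = max(abs(n - a) for n, a in zip(numeric, analytic))

# Near the origin the output is approximately cubic in x.
small = 0.1
cubic_gap = abs(tanhshrink(small) - small ** 3 / 3)

print(max_err, cubic_gap)
```

Both gaps come out tiny, and every analytic gradient value lies in $[0, 1)$, with zero attained only at the origin.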
Threshold
Threshold is the most general of these functions: values strictly above the threshold pass through unchanged, and values at or below it are replaced with a specified constant. Unlike Hardshrink, the cut is one-sided (it tests $x$, not $|x|$) and the replacement value can be anything, not necessarily 0.
Example: nn.Threshold(0.1, 20) replaces all $x \le 0.1$ with 20. This is an unusual but valid setup for binary indicator logic.
Gradient: $1$ for $x > \text{threshold}$, $0$ for $x < \text{threshold}$. The function is always non-differentiable at the threshold itself.
PyTorch:
x = torch.tensor([-1., 0., 0.1, 0.5, 2.])
# Values <= 0.1 replaced with 0.0
print(nn.Threshold(threshold=0.1, value=0.0)(x)) # tensor([0.0, 0.0, 0.0, 0.5, 2.0])
TensorFlow:
x = tf.constant([-1., 0., 0.1, 0.5, 2.])
threshold, value = 0.1, 0.0
print(tf.where(x > threshold, x, tf.fill(tf.shape(x), value)))
# [0. 0. 0. 0.5 2. ]
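The effect of a non-zero replacement value is easiest to see side by side. This plain-Python sketch applies the rule described above (keep $x$ if $x >$ threshold, otherwise substitute the constant) with two different constants; the scalar helper `threshold` is a hypothetical stand-in for the framework calls, not a library function.

```python
def threshold(x: float, thresh: float, value: float) -> float:
    """Return x if x > thresh, otherwise the replacement constant `value`."""
    return x if x > thresh else value


xs = [-1.0, 0.0, 0.1, 0.5, 2.0]

zero_fill = [threshold(x, 0.1, 0.0) for x in xs]
const_fill = [threshold(x, 0.1, 20.0) for x in xs]

# With thresh=0 and value=0 the rule reduces to ReLU.
relu_like = [threshold(x, 0.0, 0.0) for x in xs]

print(zero_fill)   # [0.0, 0.0, 0.0, 0.5, 2.0]
print(const_fill)  # [20.0, 20.0, 20.0, 0.5, 2.0]
print(relu_like)   # [0.0, 0.0, 0.1, 0.5, 2.0]
```

Note that the input exactly equal to the threshold (0.1) is replaced, because the comparison is strict.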
Softplus — Smooth ReLU
Softplus, defined as $\text{Softplus}(x) = \frac{1}{\beta}\log(1 + e^{\beta x})$, is a smooth, always-positive approximation of ReLU. As $\beta \to \infty$, Softplus converges to ReLU. The default $\beta = 1$ gives a smooth curve that crosses $\log 2 \approx 0.693$ at $x = 0$ and grows linearly for large $x$.
Always positive: Unlike ReLU, Softplus is strictly $> 0$ for all $x$. This makes it ideal for network outputs that must be positive, such as:
- Variance parameters in VAEs: var = F.softplus(raw_var)
- Scale parameters in normalizing flows
- Poisson rate parameters in count models
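Note on threshold parameter: For $\beta x > \text{threshold}$ (default 20), PyTorch falls back to the linear approximation and returns $x$ directly to avoid overflow: $\frac{1}{\beta}\log(1 + e^{\beta x}) \approx x$ for large $\beta x$.
Gradient: $\frac{d}{dx}\text{Softplus}(x) = \sigma(\beta x)$. For $\beta = 1$, the gradient of Softplus is exactly the sigmoid.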
PyTorch:
x = torch.tensor([-2., -1., 0., 1., 2.])
print(nn.Softplus(beta=1)(x)) # tensor([0.1269, 0.3133, 0.6931, 1.3133, 2.1269])
# Always positive; approaches ReLU as beta → ∞
# Use for variance outputs: var = F.softplus(raw_var)
TensorFlow:
x = tf.constant([-2., -1., 0., 1., 2.])
print(tf.keras.activations.softplus(x)) # [0.1269 0.3133 0.6931 1.3133 2.1269]
print(tf.nn.softplus(x)) # identical
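The overflow guard, the convergence to ReLU, and the sigmoid gradient can all be illustrated in a few lines. The sketch below is a plain-Python approximation of the behaviour described above, not the framework implementation; the function names and the probe values are assumptions for illustration.

```python
import math


def softplus(x: float, beta: float = 1.0, threshold: float = 20.0) -> float:
    """(1/beta) * log(1 + exp(beta * x)), returning x once beta * x > threshold."""
    if beta * x > threshold:
        return x  # linear regime: skips exp(), which would overflow for large inputs
    return math.log1p(math.exp(beta * x)) / beta


def relu(x: float) -> float:
    return max(x, 0.0)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


xs = [-2.0, -1.0, 0.0, 1.0, 2.0]

# Largest deviation from ReLU on the probe points, for increasing beta.
gaps = {beta: max(abs(softplus(x, beta) - relu(x)) for x in xs)
        for beta in (1.0, 10.0, 100.0)}

# Central-difference derivative at x = 0.5 with beta = 1, compared to the sigmoid.
h = 1e-5
numeric_grad = (softplus(0.5 + h) - softplus(0.5 - h)) / (2 * h)

print(softplus(1000.0))  # 1000.0, where math.exp(1000.0) alone would raise OverflowError
print(gaps)
print(numeric_grad, sigmoid(0.5))
```

On these points the largest gap to ReLU sits at $x = 0$ and equals $\log(2)/\beta$, so it shrinks in proportion to $1/\beta$.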
Comparison Table
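| Function | Formula | Gradient at small $\lvert x \rvert$ | Smoothness |
|---|---|---|---|
| Hardshrink | $x$ if $\lvert x \rvert > \lambda$, else $0$ | 0 | Discontinuous |
| Softshrink | $\operatorname{sign}(x)\max(\lvert x \rvert - \lambda, 0)$ | 0 | Continuous, with kinks at $\pm\lambda$ |
| Tanhshrink | $x - \tanh(x)$ | $\tanh^2(x) \approx x^2$ | Smooth |
| Threshold | $x$ if $x > \text{threshold}$, else value | 0 (for $x$ below the threshold) | Discontinuous |
| Softplus | $\frac{1}{\beta}\log(1 + e^{\beta x})$ | $\sigma(\beta x) \approx \tfrac{1}{2}$ | Smooth |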