Key Distributions for Machine Learning
- State the PMF or PDF, mean, and variance for the Bernoulli, Binomial, Categorical, and Multinomial distributions and identify the ML context where each appears
- Explain the Gaussian distribution's central role via the Central Limit Theorem and extend it to the multivariate case, identifying what the covariance matrix encodes
- Distinguish the Beta distribution (distribution over a single probability) from the Dirichlet distribution (distribution over a probability vector) and explain their role as conjugate priors
- Identify the appropriate distribution family for a given ML modeling assumption: binary outcomes, class probabilities, continuous observations, count data, or categorical mixing weights
Why Distributions Matter
Every probabilistic ML model is built from named distributions. Understanding which distribution captures which kind of uncertainty — and what its parameters control — is essential for reading papers, designing models, and debugging training.
Discrete Distributions
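Bernoulli($p$)
The simplest random variable: a single binary outcome $X \in \{0, 1\}$.
- Parameters: $p \in [0, 1]$, the probability of $X = 1$
- PMF: $P(X = x) = p^x (1 - p)^{1 - x}$ for $x \in \{0, 1\}$
- Mean: $p$
- Variance: $p(1 - p)$, maximized at $p = 1/2$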
- ML uses: binary classification output, individual pixel in a binary image model, gate activation
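The Bernoulli mean $p$ and variance $p(1 - p)$ can be checked directly from the PMF $p^x (1 - p)^{1 - x}$ by summing over the two outcomes. The sketch below uses only the Python standard library; `bernoulli_pmf`, `bernoulli_moments`, and `sample_bernoulli` are hypothetical helper names, and $p = 0.3$ is an arbitrary illustrative value.

```python
import random


def bernoulli_pmf(x: int, p: float) -> float:
    """P(X = x) = p^x (1 - p)^(1 - x) for x in {0, 1}; zero elsewhere."""
    if x not in (0, 1):
        return 0.0
    return p ** x * (1.0 - p) ** (1 - x)


def bernoulli_moments(p: float) -> tuple[float, float]:
    """Exact mean and variance obtained by summing over both outcomes."""
    mean = sum(x * bernoulli_pmf(x, p) for x in (0, 1))
    var = sum((x - mean) ** 2 * bernoulli_pmf(x, p) for x in (0, 1))
    return mean, var


def sample_bernoulli(p: float, rng: random.Random) -> int:
    """Return 1 with probability p and 0 otherwise."""
    return 1 if rng.random() < p else 0


p = 0.3
rng = random.Random(0)
draws = [sample_bernoulli(p, rng) for _ in range(100_000)]
print(bernoulli_moments(p))        # approximately (0.3, 0.21), i.e. p and p(1 - p)
print(sum(draws) / len(draws))     # empirical mean, close to p
```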
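Binomial($n$, $p$)
Sum of $n$ independent Bernoulli($p$) trials: the number of successes $Y \in \{0, 1, \dots, n\}$.
- Parameters: $n$ (number of trials), $p \in [0, 1]$ (success probability)
- PMF: $P(Y = k) = \binom{n}{k} p^k (1 - p)^{n - k}$
- Mean: $np$
- Variance: $np(1 - p)$
- Bernoulli is the special case $n = 1$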
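To see that the Binomial mean $np$ and variance $np(1 - p)$ follow from the PMF $\binom{n}{k} p^k (1 - p)^{n - k}$, a small sketch can sum over every $k$ exactly. It assumes Python 3.9+ (`math.comb` and built-in generic annotations); the helper names and the values $n = 10$, $p = 0.3$ are illustrative.

```python
from math import comb


def binomial_pmf(k: int, n: int, p: float) -> float:
    """P(Y = k) = C(n, k) p^k (1 - p)^(n - k); zero outside 0..n."""
    if not 0 <= k <= n:
        return 0.0
    return comb(n, k) * p ** k * (1.0 - p) ** (n - k)


def binomial_moments(n: int, p: float) -> tuple[float, float, float]:
    """Total probability, mean, and variance computed by summing the PMF."""
    probs = [binomial_pmf(k, n, p) for k in range(n + 1)]
    total = sum(probs)
    mean = sum(k * q for k, q in enumerate(probs))
    var = sum((k - mean) ** 2 * q for k, q in enumerate(probs))
    return total, mean, var


total, mean, var = binomial_moments(10, 0.3)
print(total, mean, var)   # approximately 1.0, np = 3.0, np(1 - p) = 2.1
```

Setting $n = 1$ in `binomial_pmf` recovers the Bernoulli PMF.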
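Categorical($\boldsymbol{\pi}$)
Generalization of Bernoulli to $K$ outcomes. The probability of outcome $k$ is $P(X = k) = \pi_k$.
- Parameters: probability vector $\boldsymbol{\pi} = (\pi_1, \dots, \pi_K)$ with $\pi_k \ge 0$, $\sum_{k=1}^{K} \pi_k = 1$
- ML uses: multi-class classification (the softmax output defines $\boldsymbol{\pi}$), sampling tokens from a language model, action selection in RL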
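In a classifier or language model, the Categorical parameters usually come from a softmax over logits. A minimal sketch, assuming hypothetical logits for $K = 3$ classes: `softmax` (a hypothetical helper) maps logits to a probability vector, and the standard-library `random.Random.choices` draws class indices with those weights.

```python
import math
import random
from collections import Counter


def softmax(logits: list[float]) -> list[float]:
    """Map real-valued logits to a probability vector (nonnegative, sums to 1)."""
    m = max(logits)                           # shift by the max to avoid overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_categorical(pi: list[float], rng: random.Random) -> int:
    """Draw one class index k with probability pi[k]."""
    return rng.choices(range(len(pi)), weights=pi, k=1)[0]


logits = [2.0, 1.0, 0.1]                      # hypothetical outputs for K = 3 classes
pi = softmax(logits)
rng = random.Random(0)
n_samples = 50_000
counts = Counter(sample_categorical(pi, rng) for _ in range(n_samples))
freqs = [counts[k] / n_samples for k in range(len(pi))]
print(pi)
print(freqs)                                  # empirical frequencies, close to pi
```

Subtracting the maximum logit leaves the softmax output unchanged (the factor $e^{-m}$ cancels in the ratio) but keeps `math.exp` from overflowing on large logits.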
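Multinomial($n$, $\boldsymbol{\pi}$)
Generalization of Binomial: $n$ independent draws from a Categorical($\boldsymbol{\pi}$). Records the count vector $\mathbf{x} = (x_1, \dots, x_K)$, where $x_k$ is the number of times outcome $k$ occurred, so $\sum_k x_k = n$.
- PMF: $P(\mathbf{x}) = \dfrac{n!}{x_1! \cdots x_K!} \prod_{k=1}^{K} \pi_k^{x_k}$
- Mean: $\mathbb{E}[x_k] = n\pi_k$; variance: $\mathrm{Var}(x_k) = n\pi_k(1 - \pi_k)$
- ML uses: word count models (bag-of-words), topic models (Latent Dirichlet Allocation)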
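A Multinomial count vector is just the tally of $n$ Categorical draws. The sketch below samples counts that way and evaluates the PMF $\frac{n!}{x_1! \cdots x_K!} \prod_k \pi_k^{x_k}$, using exact integer arithmetic for the coefficient. The helper names are hypothetical and the probabilities are illustrative.

```python
import math
import random
from collections import Counter


def multinomial_pmf(x: list[int], pi: list[float]) -> float:
    """P(x) = n! / (x_1! ... x_K!) * prod_k pi_k^x_k, where n = sum(x)."""
    coeff = math.factorial(sum(x))
    for xk in x:
        coeff //= math.factorial(xk)          # exact: every partial quotient is an integer
    prob = float(coeff)
    for xk, pk in zip(x, pi):
        prob *= pk ** xk
    return prob


def sample_multinomial(n: int, pi: list[float], rng: random.Random) -> list[int]:
    """Tally n independent Categorical(pi) draws into a count vector."""
    tally = Counter(rng.choices(range(len(pi)), weights=pi, k=n))
    return [tally[k] for k in range(len(pi))]


pi = [0.2, 0.3, 0.5]                          # illustrative category probabilities
rng = random.Random(0)
x = sample_multinomial(20, pi, rng)
print(x, sum(x))                              # a count vector that sums to n = 20
print(multinomial_pmf(x, pi))
```

With $K = 2$ the PMF reduces to the Binomial PMF, since $\frac{n!}{x_1!\,(n - x_1)!} = \binom{n}{x_1}$.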
Continuous Distributions
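Gaussian (Normal): $\mathcal{N}(\mu, \sigma^2)$
The most important distribution in ML and statistics.
$$p(x) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)$$
- Parameters: mean $\mu \in \mathbb{R}$ (location), variance $\sigma^2 > 0$ (spread)
- Mean: $\mu$; variance: $\sigma^2$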
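Why it appears everywhere: The Central Limit Theorem states that the sum of many independent random variables with finite variance, once centered and scaled, converges in distribution to a Gaussian as the number of terms grows. In the i.i.d. case with mean $\mu$ and variance $\sigma^2$, $\frac{1}{\sigma\sqrt{n}}\left(\sum_{i=1}^{n} X_i - n\mu\right) \to \mathcal{N}(0, 1)$ as $n \to \infty$. This is why measurement errors, additive noise, and empirical averages, which aggregate many small independent effects, are often well modeled by Gaussians.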
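A quick simulation makes the Central Limit Theorem concrete. This sketch assumes i.i.d. Uniform(0, 1) terms (mean $1/2$, variance $1/12$), standardizes their sum, and compares the result with a standard normal via the standard-library `statistics.NormalDist`. The choice of $n = 30$ terms and 20,000 repetitions is arbitrary, and `standardized_uniform_sum` is a hypothetical helper.

```python
import random
from statistics import NormalDist, fmean, pvariance


def standardized_uniform_sum(n: int, rng: random.Random) -> float:
    """Sum n i.i.d. Uniform(0, 1) draws (mean 1/2, variance 1/12 each),
    then subtract the mean n/2 and divide by the standard deviation sqrt(n/12)."""
    s = sum(rng.random() for _ in range(n))
    return (s - n / 2) / (n / 12) ** 0.5


rng = random.Random(42)
zs = [standardized_uniform_sum(30, rng) for _ in range(20_000)]
within_one = sum(abs(z) < 1 for z in zs) / len(zs)
std_normal = NormalDist()                     # N(0, 1) for comparison
print(fmean(zs), pvariance(zs))               # close to 0 and 1
print(within_one, std_normal.cdf(1) - std_normal.cdf(-1))   # both about 0.683
```

Each individual term is flat, not bell-shaped, yet the standardized sum already matches the Gaussian's one-standard-deviation mass closely at $n = 30$.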
Properties useful for ML:
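- Closed under linear transformations: if $X \sim \mathcal{N}(\mu, \sigma^2)$ and $a \neq 0$, then $aX + b \sim \mathcal{N}(a\mu + b, a^2\sigma^2)$
- Sum of independent Gaussians is Gaussian: if $X_1 \sim \mathcal{N}(\mu_1, \sigma_1^2)$ and $X_2 \sim \mathcal{N}(\mu_2, \sigma_2^2)$ are independent, then $X_1 + X_2 \sim \mathcal{N}(\mu_1 + \mu_2, \sigma_1^2 + \sigma_2^2)$
- Maximum entropy: among all continuous distributions on $\mathbb{R}$ with a given mean and variance, the Gaussian has the largest differential entropy, so it assumes the least beyond those two moments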
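Python's standard-library `statistics.NormalDist` implements the two closure properties directly: multiplying by or adding a constant applies the linear-transformation rule, and adding two `NormalDist` instances treats them as independent, adding means and variances. A short sketch with illustrative parameters:

```python
from statistics import NormalDist

X = NormalDist(mu=1.0, sigma=2.0)            # N(1, 4): sigma is the standard deviation

# Linear transformation: aX + b ~ N(a*mu + b, a^2 * sigma^2), here a = 3, b = 4
Y = X * 3 + 4
print(Y.mean, Y.variance)                     # 7.0 36.0

# Sum of independent Gaussians: means add and variances add (9 + 16 = 25)
Z = NormalDist(2.0, 3.0) + NormalDist(1.0, 4.0)
print(Z.mean, Z.stdev)                        # 3.0 5.0

# Density at the mean: 1 / sqrt(2 * pi * sigma^2)
print(X.pdf(1.0))
```

Note that `NormalDist` is parameterized by the standard deviation `sigma`, not the variance $\sigma^2$, so the sum above has standard deviation $\sqrt{3^2 + 4^2} = 5$.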
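Multivariate Gaussian: $\mathcal{N}(\boldsymbol{\mu}, \boldsymbol{\Sigma})$
Extension to a vector $\mathbf{x} \in \mathbb{R}^d$:
$$p(\mathbf{x}) = \frac{1}{(2\pi)^{d/2} \lvert\boldsymbol{\Sigma}\rvert^{1/2}} \exp\left(-\frac{1}{2} (\mathbf{x} - \boldsymbol{\mu})^\top \boldsymbol{\Sigma}^{-1} (\mathbf{x} - \boldsymbol{\mu})\right)$$
- Parameters: mean vector $\boldsymbol{\mu} \in \mathbb{R}^d$, covariance matrix $\boldsymbol{\Sigma} \in \mathbb{R}^{d \times d}$ (symmetric, positive definite)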
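- The covariance matrix encodes the shape and orientation of the distribution's ellipsoidal contours: the diagonal entries $\Sigma_{ii}$ are the variances of the individual coordinates, the off-diagonal entries $\Sigma_{ij}$ are the covariances between coordinates $i$ and $j$, and the eigenvectors of $\boldsymbol{\Sigma}$ give the directions of the ellipsoid's axes while the eigenvalues give the variance along each axis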
ML uses: Gaussian noise models, VAE latent prior $\mathcal{N}(\mathbf{0}, \mathbf{I})$, Gaussian process priors, 3D Gaussian Splatting (Gaussians are MVN with learnable covariance)
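One way to see what $\boldsymbol{\Sigma}$ encodes is to sample from it. The standard construction draws $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$ and sets $\mathbf{x} = \boldsymbol{\mu} + \mathbf{L}\mathbf{z}$, where $\mathbf{L}\mathbf{L}^\top = \boldsymbol{\Sigma}$ is the Cholesky factor, so $\mathrm{Cov}(\mathbf{x}) = \mathbf{L}\mathbf{I}\mathbf{L}^\top = \boldsymbol{\Sigma}$. The sketch hard-codes the $2 \times 2$ case in plain Python to stay self-contained; the helper names, mean, and covariance are illustrative, and in practice `numpy.linalg.cholesky` would do the factorization.

```python
import math
import random


def cholesky_2x2(S: list[list[float]]) -> list[list[float]]:
    """Lower-triangular L with L L^T = S for a symmetric positive definite 2x2 S."""
    l11 = math.sqrt(S[0][0])
    l21 = S[1][0] / l11
    l22 = math.sqrt(S[1][1] - l21 ** 2)
    return [[l11, 0.0], [l21, l22]]


def sample_mvn_2d(mu: list[float], L: list[list[float]], rng: random.Random) -> list[float]:
    """x = mu + L z with z ~ N(0, I), where L is the Cholesky factor of Sigma."""
    z0, z1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    return [mu[0] + L[0][0] * z0,
            mu[1] + L[1][0] * z0 + L[1][1] * z1]


def sample_covariance(xs: list[list[float]]) -> list[list[float]]:
    """Unbiased 2x2 sample covariance matrix."""
    n = len(xs)
    m = [sum(x[i] for x in xs) / n for i in range(2)]
    return [[sum((x[i] - m[i]) * (x[j] - m[j]) for x in xs) / (n - 1)
             for j in range(2)] for i in range(2)]


mu = [0.0, 1.0]
Sigma = [[2.0, 0.8],                          # Var(x1) = 2, Cov(x1, x2) = 0.8
         [0.8, 1.0]]                          # Var(x2) = 1; positive covariance tilts the ellipse
L = cholesky_2x2(Sigma)
rng = random.Random(0)
xs = [sample_mvn_2d(mu, L, rng) for _ in range(50_000)]
print(sample_covariance(xs))                  # close to Sigma
```

Setting the off-diagonal entry to 0 would make the coordinates independent and the contours axis-aligned.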
Beta and Dirichlet: Priors Over Probabilities
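Beta($\alpha$, $\beta$)
The Beta distribution is a distribution over the interval $[0, 1]$, making it the natural prior over a probability parameter $p$:
$$f(p; \alpha, \beta) = \frac{p^{\alpha - 1}(1 - p)^{\beta - 1}}{B(\alpha, \beta)},$$
where $B(\alpha, \beta) = \frac{\Gamma(\alpha)\Gamma(\beta)}{\Gamma(\alpha + \beta)}$ is the normalizing constant.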
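- Parameters: $\alpha, \beta > 0$, which can be thought of as pseudo-counts of successes and failures
- Mean: $\frac{\alpha}{\alpha + \beta}$
- Shape: $\alpha = \beta = 1$ is Uniform; $\alpha, \beta > 1$ is unimodal; $\alpha, \beta < 1$ is U-shaped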
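The Beta density can be evaluated stably in log space with `math.lgamma`, since $\log B(\alpha, \beta) = \log\Gamma(\alpha) + \log\Gamma(\beta) - \log\Gamma(\alpha + \beta)$. This sketch (hypothetical helper names, illustrative parameters) checks normalization and the mean $\alpha/(\alpha + \beta)$ numerically with a midpoint rule, and probes the shape regimes.

```python
import math


def beta_pdf(p: float, alpha: float, beta: float) -> float:
    """f(p) = p^(alpha-1) (1-p)^(beta-1) / B(alpha, beta) on 0 < p < 1."""
    if not 0.0 < p < 1.0:
        return 0.0
    log_b = math.lgamma(alpha) + math.lgamma(beta) - math.lgamma(alpha + beta)
    log_f = (alpha - 1) * math.log(p) + (beta - 1) * math.log1p(-p) - log_b
    return math.exp(log_f)


def midpoint_integral(f, n: int = 20_000) -> float:
    """Midpoint-rule approximation of the integral of f over [0, 1]."""
    h = 1.0 / n
    return h * sum(f((i + 0.5) * h) for i in range(n))


a, b = 2.0, 5.0                               # illustrative unimodal case
print(midpoint_integral(lambda p: beta_pdf(p, a, b)))        # about 1.0
print(midpoint_integral(lambda p: p * beta_pdf(p, a, b)))    # about a / (a + b) = 2/7
# Shape regimes: flat for (1, 1), U-shaped for (0.5, 0.5)
print(beta_pdf(0.2, 1, 1), beta_pdf(0.8, 1, 1))
print(beta_pdf(0.05, 0.5, 0.5) > beta_pdf(0.5, 0.5, 0.5))    # True
```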
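Conjugate prior for Bernoulli: If $p \sim \text{Beta}(\alpha, \beta)$ and you observe $n_1$ successes and $n_0$ failures, the posterior is $\text{Beta}(\alpha + n_1, \beta + n_0)$: the same family, just with updated counts. This is what conjugacy means: a prior is conjugate to a likelihood when the resulting posterior has the same distributional form as the prior.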
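The conjugate update reduces Bayesian inference to counting. A minimal sketch, assuming a uniform Beta(1, 1) prior and a hypothetical sequence of eight coin flips; `beta_bernoulli_update` and `beta_mean` are illustrative helper names.

```python
def beta_bernoulli_update(alpha: float, beta: float, outcomes: list[int]) -> tuple[float, float]:
    """Posterior Beta(alpha + n1, beta + n0) after observing 0/1 outcomes."""
    n1 = sum(outcomes)               # successes
    n0 = len(outcomes) - n1          # failures
    return alpha + n1, beta + n0


def beta_mean(alpha: float, beta: float) -> float:
    """Mean of Beta(alpha, beta)."""
    return alpha / (alpha + beta)


prior = (1.0, 1.0)                   # Beta(1, 1) = uniform prior over p
data = [1, 0, 1, 1, 0, 1, 1, 1]      # hypothetical flips: 6 successes, 2 failures
post = beta_bernoulli_update(*prior, data)
print(post)                          # (7.0, 3.0)
print(beta_mean(*post))              # 0.7

# Sequential updating, one observation at a time, reaches the same posterior
seq = prior
for outcome in data:
    seq = beta_bernoulli_update(*seq, [outcome])
print(seq == post)                   # True
```

Because updating one observation at a time gives the same posterior as a single batch update, conjugate models are convenient for online updating.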
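Dirichlet($\boldsymbol{\alpha}$)
The Dirichlet is the multivariate generalization of the Beta: a distribution over the probability simplex $\Delta^{K-1} = \{\boldsymbol{\pi} \in \mathbb{R}^K : \pi_k \ge 0, \sum_{k} \pi_k = 1\}$, i.e., over $K$-dimensional probability vectors.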
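- Parameters: concentration vector $\boldsymbol{\alpha} = (\alpha_1, \dots, \alpha_K)$ with $\alpha_k > 0$
- Mean: $\mathbb{E}[\pi_k] = \alpha_k / \alpha_0$, where $\alpha_0 = \sum_{k} \alpha_k$
- Symmetric case: $\alpha_k = \alpha$ for all $k$; small $\alpha$ ($\alpha < 1$) → sparse draws (mass near the corners of the simplex); $\alpha = 1$ → uniform over the simplex; large $\alpha$ → concentrated near the center $(1/K, \dots, 1/K)$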
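A standard way to sample from a Dirichlet is to draw independent $\text{Gamma}(\alpha_k, 1)$ variables and normalize them to sum to 1. The sketch below uses the standard-library `random.Random.gammavariate` for that; it checks the mean $\alpha_k / \alpha_0$ for an illustrative $\boldsymbol{\alpha} = (2, 3, 5)$ and measures sparsity in the symmetric case by the average largest component. The helper names, $K = 5$, and the sample sizes are arbitrary choices.

```python
import random


def sample_dirichlet(alpha: list[float], rng: random.Random) -> list[float]:
    """Draw pi ~ Dirichlet(alpha) by normalizing independent Gamma(alpha_k, 1) draws."""
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    return [x / total for x in g]


def average_max(a: float, K: int = 5, n: int = 2_000, seed: int = 1) -> float:
    """Average largest component of symmetric Dirichlet(a, ..., a) draws:
    near 1 means sparse draws, near 1/K means draws close to the simplex center."""
    rng = random.Random(seed)
    return sum(max(sample_dirichlet([a] * K, rng)) for _ in range(n)) / n


rng = random.Random(0)
alpha = [2.0, 3.0, 5.0]                       # illustrative concentration vector, alpha_0 = 10
samples = [sample_dirichlet(alpha, rng) for _ in range(20_000)]
est_mean = [sum(col) / len(samples) for col in zip(*samples)]
print(est_mean)                               # close to [0.2, 0.3, 0.5] = alpha_k / alpha_0

for a in (0.1, 1.0, 10.0):
    print(a, average_max(a))                  # decreases as a grows
```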
Conjugate prior for Categorical/Multinomial: Observe counts $\mathbf{x} = (x_1, \dots, x_K)$; the posterior is $\text{Dirichlet}(\boldsymbol{\alpha} + \mathbf{x})$.
ML uses: LDA (Latent Dirichlet Allocation) uses Dirichlet priors over each document's topic proportions (and, in the common smoothed variant, over each topic's word distribution); Think Bayes Ch. 18 covers the Dirichlet-multinomial model in detail.
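The Dirichlet-multinomial update is the $K$-category version of the Beta-Bernoulli one: add the observed counts to the concentration vector. A minimal sketch with a hypothetical symmetric prior and hypothetical counts; the helper names are illustrative.

```python
def dirichlet_multinomial_update(alpha: list[float], counts: list[int]) -> list[float]:
    """Posterior concentration alpha + x after observing category counts x."""
    return [a + x for a, x in zip(alpha, counts)]


def dirichlet_mean(alpha: list[float]) -> list[float]:
    """E[pi_k] = alpha_k / alpha_0."""
    alpha0 = sum(alpha)
    return [a / alpha0 for a in alpha]


prior = [1.0, 1.0, 1.0]                      # symmetric Dirichlet(1, 1, 1): uniform on the simplex
counts = [5, 2, 3]                           # hypothetical observed counts, n = 10
post = dirichlet_multinomial_update(prior, counts)
print(post)                                  # [6.0, 3.0, 4.0]
print(dirichlet_mean(post))                  # approximately [0.4615, 0.2308, 0.3077]
```

With $\alpha_k = 1$ for every $k$, the posterior mean $(x_k + 1)/(n + K)$ is exactly add-one (Laplace) smoothing of the empirical frequencies; with $K = 2$ the update matches the Beta-Bernoulli rule.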
Quick Reference Table
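| Distribution | Type | Parameters | Mean | Variance | ML use |
|---|---|---|---|---|---|
| Bernoulli($p$) | Discrete | $p \in [0, 1]$ | $p$ | $p(1 - p)$ | Binary classification |
| Categorical($\boldsymbol{\pi}$) | Discrete | $\boldsymbol{\pi}$ on simplex | — | — | Multi-class output |
| Gaussian($\mu$, $\sigma^2$) | Continuous | $\mu$, $\sigma^2$ | $\mu$ | $\sigma^2$ | Regression, noise, VAE latent |
| MVN($\boldsymbol{\mu}$, $\boldsymbol{\Sigma}$) | Continuous | $\boldsymbol{\mu}$, $\boldsymbol{\Sigma}$ (mean + covariance) | $\boldsymbol{\mu}$ | $\boldsymbol{\Sigma}$ (covariance) | Gaussian processes, 3DGS |
| Beta($\alpha$, $\beta$) | Continuous | $\alpha, \beta > 0$ | $\frac{\alpha}{\alpha + \beta}$ | — | Prior over $p$ |
| Dirichlet($\boldsymbol{\alpha}$) | Continuous | $\boldsymbol{\alpha}$ with $\alpha_k > 0$ | $\alpha_k / \alpha_0$ | — | Prior over $\boldsymbol{\pi}$, LDA |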