"LLM Core Study (5/6) — Math Intuition: Softmax, CE, KL, Gradient, LayerNorm"
Training a transformer depends on five equations behaving well at the same time: softmax must not saturate, cross-entropy gradients must stay informative, KL must give a meaningful distance signal, gradients must neither vanish nor explode, and LayerNorm must keep activation variance in check. This lecture walks through the intuition, derivation, and failure mode of each.
0. Learning Objectives
- Sketch softmax and explain how temperature changes the distribution.
- Derive why softmax + cross-entropy yields the clean gradient \(p - y\).
- Explain KL divergence's asymmetry and its consequences for learning.
- Connect perplexity to cross-entropy.
- Diagnose vanishing/exploding gradients and prescribe residuals, normalisation, and initialisation.
- Write LayerNorm and RMSNorm equations by hand.
1. Key Summary
- Softmax: \(\mathrm{softmax}(z)_i = e^{z_i}/\sum_j e^{z_j}\). Temperature enters as \(z/T\).
- Softmax + cross-entropy: gradient with respect to the logits = \(p - y\).
- KL\((p\|q) = \sum_i p_i \log (p_i/q_i)\); asymmetric; diverges when \(q_i = 0\) for some \(i\) with \(p_i > 0\).
- Vanishing/exploding gradient: long products of Jacobians push norms to 0 or \(\infty\). Residuals + LayerNorm + correct initialisation fix it.
- LayerNorm normalises per-token over feature dimension; RMSNorm skips the mean.
2. Softmax
2.1 Definition
$$ \mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}. $$
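Outputs lie in \((0, 1)\) and sum to 1. Adding the same constant to every \(z_i\) leaves the output unchanged (the factor \(e^c\) cancels between numerator and denominator), so implementations subtract \(\max_j z_j\) first: the largest exponent becomes \(e^0 = 1\) and nothing overflows.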
2.2 Temperature and saturation
Replace \(z\) with \(z/T\):
- \(T \to 0\): concentrates on argmax.
- \(T \to \infty\): uniform distribution.
- Large \(z\) entries saturate softmax; gradients vanish. This is why Lecture 1's \(\sqrt{d_k}\) scaling matters.
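To see the two temperature limits concretely, here is a minimal pure-Python sketch (standard library only; the logits `[2.0, 1.0, 0.1]` and the temperature grid are arbitrary illustrative choices) that applies softmax to \(z/T\):

```python
import math

def softmax(z, T=1.0):
    """Softmax of logits z at temperature T: softmax(z / T), max-shifted for stability."""
    scaled = [x / T for x in z]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.1]
for T in (0.1, 1.0, 10.0, 1000.0):
    probs = softmax(logits, T)
    print(f"T={T:>7}: " + ", ".join(f"{p:.3f}" for p in probs))
```

Small \(T\) drives the output toward one-hot on the largest logit; large \(T\) flattens it toward \(1/3\) per class. Dividing by a small \(T\) has the same saturating effect as large logits.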
2.3 Gradient
For \(p = \mathrm{softmax}(z)\):
$$ \frac{\partial p_i}{\partial z_j} = p_i (\delta_{ij} - p_j). $$
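If one \(p_i\) is close to 1, every other \(p_j\) is close to 0, so both the diagonal terms \(p_i (1 - p_i)\) and the off-diagonal terms \(-p_i p_j\) approach 0: the whole Jacobian collapses and almost no gradient flows back through the softmax. Initialisation, normalisation, and \(\sqrt{d_k}\) all exist to keep softmax inputs in a healthy range.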
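A quick way to trust the Jacobian formula and to watch saturation happen is to compare it with finite differences and then scale the logits up. A minimal pure-Python sketch (the logits and scale factors are arbitrary illustrative choices):

```python
import math

def softmax(z):
    """Max-shifted softmax of a list of logits."""
    m = max(z)
    e = [math.exp(x - m) for x in z]
    s = sum(e)
    return [x / s for x in e]

def softmax_jacobian(z):
    """Analytic Jacobian J[i][j] = dp_i/dz_j = p_i * (delta_ij - p_j)."""
    p = softmax(z)
    n = len(p)
    return [[p[i] * ((1.0 if i == j else 0.0) - p[j]) for j in range(n)]
            for i in range(n)]

def numeric_jacobian(z, h=1e-6):
    """Central finite differences: column j is (p(z + h e_j) - p(z - h e_j)) / 2h."""
    n = len(z)
    J = [[0.0] * n for _ in range(n)]
    for j in range(n):
        z_plus, z_minus = list(z), list(z)
        z_plus[j] += h
        z_minus[j] -= h
        p_plus, p_minus = softmax(z_plus), softmax(z_minus)
        for i in range(n):
            J[i][j] = (p_plus[i] - p_minus[i]) / (2 * h)
    return J

def max_abs(J):
    return max(abs(x) for row in J for x in row)

z = [1.0, 0.5, -0.3]
J, J_num = softmax_jacobian(z), numeric_jacobian(z)
gap = max(abs(a - b) for ra, rb in zip(J, J_num) for a, b in zip(ra, rb))
print(f"analytic vs numeric gap: {gap:.1e}")

# Multiplying the logits pushes softmax toward one-hot and shrinks every entry.
for scale in (1, 10, 100):
    J_s = softmax_jacobian([scale * x for x in z])
    print(f"scale={scale:>3}: max |dp_i/dz_j| = {max_abs(J_s):.2e}")
```

Each row of the Jacobian sums to zero because the probabilities must keep summing to 1. At scale 100 every entry is below \(10^{-20}\), which is the vanishing-gradient regime described above.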
2.4 Stable implementation
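```python
import torch

def softmax_stable(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Shift so the largest logit is 0: exp() can no longer overflow,
    # and the result is unchanged because softmax is shift-invariant.
    z = z - z.max(dim=dim, keepdim=True).values
    e = z.exp()
    return e / e.sum(dim=dim, keepdim=True)
```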
3. Cross-Entropy
3.1 Definition
For one-hot \(y\) and predicted distribution \(p\):
$$ \mathrm{CE}(y, p) = -\sum_i y_i \log p_i = -\log p_{i^*}. $$
3.2 Softmax + CE gradient
With \(\mathcal{L} = -\log p_{i^*}\),
$$ \mathcal{L} = -z_{i^*} + \log \sum_j e^{z_j}, $$
$$ \frac{\partial \mathcal{L}}{\partial z_j} = -\delta_{ji^*} + \frac{e^{z_j}}{\sum_k e^{z_k}} = p_j - y_j. \ \square $$
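This clean result is the reason softmax + CE is the default classification head. Computing the softmax first and then taking \(\log\) separately can underflow to \(\log 0\) for very negative logits, so libraries fuse the two: they evaluate \(\log p_{i^*} = z_{i^*} - \log \sum_j e^{z_j}\) with the log-sum-exp trick and return the simple \(p - y\) gradient directly. This is why PyTorch's F.cross_entropy takes raw logits rather than probabilities.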
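The derivation can be checked numerically. A minimal pure-Python sketch (standard library only; the 4-class logits and target index are arbitrary choices) computes the loss through log-softmax, as fused implementations do, and compares the analytic \(p - y\) gradient with central finite differences:

```python
import math

def log_softmax(z):
    """log p_j = z_j - logsumexp(z), with the max shift for stability."""
    m = max(z)
    lse = m + math.log(sum(math.exp(x - m) for x in z))
    return [x - lse for x in z]

def ce_loss(z, target):
    """Cross-entropy for a one-hot target: -log p_target, computed from logits."""
    return -log_softmax(z)[target]

def ce_grad(z, target):
    """Analytic gradient with respect to the logits: p_j - y_j."""
    p = [math.exp(lp) for lp in log_softmax(z)]
    return [p_j - (1.0 if j == target else 0.0) for j, p_j in enumerate(p)]

def ce_grad_numeric(z, target, h=1e-6):
    """Central finite-difference gradient of ce_loss, for checking ce_grad."""
    grad = []
    for j in range(len(z)):
        z_plus, z_minus = list(z), list(z)
        z_plus[j] += h
        z_minus[j] -= h
        grad.append((ce_loss(z_plus, target) - ce_loss(z_minus, target)) / (2 * h))
    return grad

z = [0.5, -1.2, 2.0, 0.3]   # logits for 4 classes
target = 1                  # index of the correct class
analytic = ce_grad(z, target)
numeric = ce_grad_numeric(z, target)
print("p - y:", [round(g, 4) for g in analytic])
print("max gap to finite differences:", max(abs(a - b) for a, b in zip(analytic, numeric)))

# Huge logits stay finite because the log is taken of the shifted sum,
# whereas a naive math.exp(1000.0) would raise OverflowError.
print(ce_loss([1000.0, 0.0, -1000.0], 1))
```

The gradient entries sum to zero, and only the target entry is negative: the update raises the correct logit and lowers the others in proportion to their probabilities.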
3.3 LLM training loss
Language modelling is per-position classification:
$$ \mathcal{L}_{\text{LM}} = -\frac{1}{T} \sum_{t=1}^{T} \log p_t(y_t). $$
3.4 Label smoothing
$$ \tilde{y}_j = (1 - \epsilon) y_j + \epsilon / K. $$
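Reduces overconfidence: with \(\epsilon > 0\) the loss is minimised at finite logit margins instead of pushing the target logit toward infinity. This often improves calibration, although models trained with label smoothing can make worse teachers for distillation.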
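The "finite margin" effect is easy to see numerically. A minimal pure-Python sketch, assuming \(K = 4\) classes, \(\epsilon = 0.1\), and logits of the form \([\text{margin}, 0, 0, 0]\) (all illustrative choices):

```python
import math

def smooth_labels(target, K, eps):
    """Label-smoothed target: (1 - eps) * one_hot + eps / K."""
    return [(1 - eps) * (1.0 if j == target else 0.0) + eps / K for j in range(K)]

def soft_ce(y, z):
    """Cross-entropy -sum_j y_j log p_j for a (possibly soft) target y and logits z."""
    m = max(z)
    lse = m + math.log(sum(math.exp(x - m) for x in z))
    return -sum(yj * (zj - lse) for yj, zj in zip(y, z))

hard = smooth_labels(target=0, K=4, eps=0.0)    # plain one-hot
smooth = smooth_labels(target=0, K=4, eps=0.1)  # 0.925 on the target, 0.025 elsewhere
for margin in (2.0, 5.0, 20.0):
    z = [margin, 0.0, 0.0, 0.0]
    print(f"margin={margin:>4}: hard CE={soft_ce(hard, z):.3f}  smoothed CE={soft_ce(smooth, z):.3f}")
```

With the hard target the loss keeps falling as the margin grows; with the smoothed target it is smallest when the softmax matches \(\tilde{y}\), i.e. at margin \(\log(0.925/0.025) = \log 37 \approx 3.6\), and rises again beyond that.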
4. Perplexity
$$ \mathrm{PPL} = \exp(\mathcal{L}_{\text{LM}}). $$
Interpretation: the average effective branching factor when the model picks the next token. Caveat: perplexity depends on the tokenizer, so it is not portable across vocabularies. It also says nothing about factuality.
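Since \(\mathcal{L}_{\text{LM}}\) is a mean of per-token negative log-probabilities, perplexity can be computed directly from the probabilities a model assigned to the observed tokens. A minimal sketch; the probability lists are made-up inputs, not outputs of any real model:

```python
import math

def perplexity(token_probs):
    """PPL = exp(mean negative log-likelihood) of the probabilities assigned to the
    actual next tokens (natural log, matching L_LM above)."""
    nll = -sum(math.log(p) for p in token_probs) / len(token_probs)
    return math.exp(nll)

# A model that is uniform over 10 candidates at every step has PPL = 10.
print(perplexity([0.1] * 8))
# Mixed confidence across four positions.
print(perplexity([0.5, 0.25, 0.9, 0.05]))
```

Equivalently, perplexity is the inverse geometric mean of the per-token probabilities, which is why a single very low-probability token can raise it sharply.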
5. KL Divergence
5.1 Definition
$$ \mathrm{KL}(p \,\|\, q) = \sum_i p_i \log \frac{p_i}{q_i}. $$
- \(\ge 0\), equal to 0 iff \(p = q\).
- Asymmetric: \(\mathrm{KL}(p\|q) \neq \mathrm{KL}(q\|p)\).
- Diverges if \(p_i > 0\) and \(q_i = 0\).
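All three properties can be checked on small hand-picked distributions. A minimal pure-Python sketch (natural log; the distributions are arbitrary examples):

```python
import math

def kl(p, q):
    """KL(p || q) = sum_i p_i log(p_i / q_i), natural log.
    Terms with p_i = 0 contribute 0; p_i > 0 with q_i = 0 gives +inf."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi == 0.0:
            continue
        if qi == 0.0:
            return math.inf
        total += pi * math.log(pi / qi)
    return total

p = [0.7, 0.2, 0.1]
q = [0.4, 0.4, 0.2]
print(kl(p, q), kl(q, p))                      # both positive, but not equal
print(kl(p, p))                                # 0.0
print(kl([0.5, 0.5, 0.0], [1.0, 0.0, 0.0]))    # inf: q has no mass where p does
print(kl([1.0, 0.0, 0.0], [0.5, 0.5, 0.0]))    # finite (log 2) in the other direction
```

The last two lines show the asymmetry at its sharpest: swapping the arguments turns an infinite divergence into \(\log 2\).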
5.2 Relationship to cross-entropy
$$ \mathrm{CE}(p, q) = H(p) + \mathrm{KL}(p \,\|\, q). $$
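Because \(H(p)\) does not depend on \(q\), minimising CE over \(q\) is the same as minimising KL for any fixed \(p\). If \(p\) is one-hot, \(H(p) = 0\) and the two quantities are equal.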
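A numerical check of the identity, assuming natural logs and hand-picked distributions (\(q\) has no zeros, so every term is finite):

```python
import math

def entropy(p):
    """H(p) = sum_i p_i log(1 / p_i), natural log; 0 * log(1/0) is taken as 0."""
    return sum(pi * math.log(1.0 / pi) for pi in p if pi > 0)

def cross_entropy(p, q):
    """CE(p, q) = sum_i p_i log(1 / q_i); assumes q_i > 0 wherever p_i > 0."""
    return sum(pi * math.log(1.0 / qi) for pi, qi in zip(p, q) if pi > 0)

def kl(p, q):
    """KL(p || q) = sum_i p_i log(p_i / q_i); same support assumption as above."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.6, 0.3, 0.1]     # soft target, e.g. a teacher distribution
q = [0.5, 0.25, 0.25]   # model distribution
print(cross_entropy(p, q), entropy(p) + kl(p, q))   # equal up to rounding

one_hot = [0.0, 1.0, 0.0]
print(entropy(one_hot), cross_entropy(one_hot, q), kl(one_hot, q))  # 0.0, log 4, log 4
```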
5.3 LLM applications
- Distillation (Lecture 2): KL(teacher‖student).
- RLHF/DPO: KL penalty keeps the policy near a reference model:
$$ \mathcal{L}_{\text{RLHF}} = -\mathbb{E}[r(y)] + \beta \mathrm{KL}(\pi_\theta \,\|\, \pi_{\text{ref}}). $$
5.4 Forward vs reverse KL
- Forward \(\mathrm{KL}(p \| q)\) is mode-covering (\(q\) must spread to cover all of \(p\)).
- Reverse \(\mathrm{KL}(q \| p)\) is mode-seeking (\(q\) may collapse on one peak of \(p\)).
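This is why "matching the reference distribution" in RLHF, a reverse KL \(\mathrm{KL}(\pi_\theta \| \pi_{\text{ref}})\) with the trained model first, behaves differently from "matching a teacher distribution" in distillation, a forward \(\mathrm{KL}(\text{teacher} \| \text{student})\), even though both are KL minimisations.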
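A toy comparison makes the two behaviours visible. The sketch below does not optimise anything; it assumes a hand-picked bimodal target \(p\) over five outcomes and two hand-picked candidates, one spread thinly over both modes and one sitting on a single mode, and asks which candidate each direction of KL prefers:

```python
import math

def kl(a, b):
    """KL(a || b) in nats; assumes b_i > 0 wherever a_i > 0."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

p = [0.45, 0.04, 0.02, 0.04, 0.45]               # bimodal target
candidates = {
    "spread": [0.2, 0.2, 0.2, 0.2, 0.2],         # covers both modes thinly
    "one_mode": [0.9, 0.04, 0.02, 0.02, 0.02],   # sits on the left mode only
}
for name, q in candidates.items():
    print(f"{name:>8}: forward KL(p||q) = {kl(p, q):.3f}   reverse KL(q||p) = {kl(q, p):.3f}")

best_forward = min(candidates, key=lambda n: kl(p, candidates[n]))
best_reverse = min(candidates, key=lambda n: kl(candidates[n], p))
print("forward KL prefers:", best_forward)   # spread
print("reverse KL prefers:", best_reverse)   # one_mode
```

Forward KL heavily penalises the single-mode candidate for putting only 0.02 on the right-hand mode where \(p\) has 0.45. Reverse KL only evaluates \(\log(q/p)\) where \(q\) has mass, so ignoring a mode costs little.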
6. Vanishing & Exploding Gradients
6.1 Origin
The gradient through \(L\) layers is a product of Jacobians:
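$$ \frac{\partial \mathcal{L}}{\partial h^{(0)}} = \frac{\partial \mathcal{L}}{\partial h^{(L)}} \cdot \frac{\partial h^{(L)}}{\partial h^{(L-1)}} \cdots \frac{\partial h^{(1)}}{\partial h^{(0)}} \quad \text{(gradients as row vectors)}. $$
If the Jacobians consistently shrink vectors (singular values below 1), the norm of this product decays exponentially in \(L\) toward 0: vanishing gradients. If they consistently stretch them (singular values above 1), it grows exponentially toward \(\infty\): exploding gradients.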
6.2 Fix 1 — residual connections (He 2016)
\(h^{(\ell)} = h^{(\ell-1)} + f^{(\ell)}(h^{(\ell-1)})\) gives Jacobian \(I + f'\). The identity component lets gradients pass through unchanged.
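The effect of the identity term can be simulated. The sketch below is a toy, not a real network: it assumes each layer's Jacobian is a random Gaussian matrix with entries \(\mathcal{N}(0, \text{gain}^2/d)\) (a hypothetical stand-in) and pushes a unit-norm gradient back through 30 such layers, with and without the residual identity:

```python
import math
import random

def matvec_T(W, g):
    """Compute W^T g for a square matrix W (list of rows) and a vector g."""
    d = len(g)
    return [sum(W[i][j] * g[i] for i in range(d)) for j in range(d)]

def random_jacobian(d, gain, rng):
    """Hypothetical layer Jacobian: i.i.d. N(0, gain^2 / d) entries, which scale
    vector norms by roughly `gain` on average."""
    std = gain / math.sqrt(d)
    return [[rng.gauss(0.0, std) for _ in range(d)] for _ in range(d)]

def backprop_norm(num_layers, d, gain, residual, seed=0):
    """Push a unit-norm upstream gradient back through `num_layers` Jacobians.
    Plain layer: J = A.  Residual layer: J = I + A (identity plus branch)."""
    rng = random.Random(seed)
    g = [1.0 / math.sqrt(d)] * d  # ||g|| = 1
    for _ in range(num_layers):
        A = random_jacobian(d, gain, rng)
        branch = matvec_T(A, g)
        g = [gi + bi for gi, bi in zip(g, branch)] if residual else branch
    return math.sqrt(sum(x * x for x in g))

for gain in (0.5, 2.0):
    print(f"plain, gain={gain}: ||grad|| = {backprop_norm(30, 16, gain, residual=False):.2e}")
print(f"residual, branch gain=0.05: ||grad|| = {backprop_norm(30, 16, 0.05, residual=True):.3f}")
```

With gain 0.5 the plain stack shrinks the gradient by roughly \(0.5^{30}\) (about \(10^{-9}\)), with gain 2 it grows by roughly \(2^{30}\), while the residual stack with a small branch stays close to norm 1.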
6.3 Fix 2 — normalisation
LayerNorm / RMSNorm enforce stable activation variance per token (see §7).
6.4 Fix 3 — initialisation
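- He (Kaiming) initialisation: \(W \sim \mathcal{N}(0, 2/n_{\text{in}})\) for ReLU; the factor 2 compensates for ReLU zeroing half of its inputs.
- Xavier (Glorot): \(W \sim \mathcal{N}(0, 2/(n_{\text{in}} + n_{\text{out}}))\) for tanh/sigmoid, which reduces to \(1/n_{\text{in}}\) when \(n_{\text{in}} = n_{\text{out}}\).
- GPT-2-style Transformers use \(\sigma = 0.02\) and additionally scale the residual output projections by \(1/\sqrt{2L}\) for \(L\) blocks, since each block adds two branches to the residual stream ("scaled init").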
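The three rules reduce to simple standard-deviation formulas. A small sketch, assuming GPT-2-small-like sizes (\(d_{\text{model}} = 768\), an MLP hidden size of 3072, 12 blocks) purely for illustration:

```python
import math

def he_std(n_in):
    """He/Kaiming normal init for ReLU: Var(W) = 2 / n_in."""
    return math.sqrt(2.0 / n_in)

def xavier_std(n_in, n_out):
    """Xavier/Glorot normal init: Var(W) = 2 / (n_in + n_out)."""
    return math.sqrt(2.0 / (n_in + n_out))

def scaled_residual_std(num_blocks, base_std=0.02):
    """GPT-2-style scaled init for residual output projections:
    base_std / sqrt(2 * num_blocks), two residual branches per block."""
    return base_std / math.sqrt(2 * num_blocks)

d_model = 768
print(f"He     (768 -> 3072): {he_std(d_model):.4f}")
print(f"Xavier (768 -> 3072): {xavier_std(d_model, 4 * d_model):.4f}")
print(f"Scaled residual, 12 blocks: {scaled_residual_std(12):.5f}")
```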
7. LayerNorm & RMSNorm
7.1 LayerNorm (Ba 2016)
Normalise per token over the feature dimension:
$$ \mu = \frac{1}{d}\sum_i h_i,\ \ \sigma^2 = \frac{1}{d}\sum_i (h_i - \mu)^2, $$
$$ \mathrm{LN}(h)_i = \gamma_i \frac{h_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i. $$
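Statistics are computed within each token, so the output does not depend on batch size or on the other sequences in the batch, which suits token-level sequence models.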
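The equations translate line by line into code. A minimal pure-Python sketch for a single token vector, assuming \(\epsilon = 10^{-5}\) (a common default, not specified above) and a hypothetical helper `layer_norm_tokens` that makes the per-token independence explicit:

```python
import math

def layer_norm(h, gamma, beta, eps=1e-5):
    """LayerNorm of ONE token vector h over its d features (formulas above)."""
    d = len(h)
    mu = sum(h) / d
    var = sum((x - mu) ** 2 for x in h) / d   # 1/d (biased) variance, as in sigma^2
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (x - mu) * inv_std + b for x, g, b in zip(h, gamma, beta)]

def layer_norm_tokens(tokens, gamma, beta, eps=1e-5):
    """Hypothetical helper: normalise each token separately; nothing is shared
    across tokens or across the batch."""
    return [layer_norm(h, gamma, beta, eps) for h in tokens]

h = [2.0, 4.0, 6.0, 8.0]
ones, zeros = [1.0] * 4, [0.0] * 4
out = layer_norm(h, ones, zeros)
print([round(x, 4) for x in out])   # zero mean, unit variance (up to eps)
```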
7.2 Pre-Norm vs Post-Norm
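- Post-Norm (original Transformer): \(\mathrm{LN}(h + f(h))\). Every residual path passes through a LayerNorm, and deep stacks are hard to train without careful learning-rate warmup.
- Pre-Norm (modern default): \(h + f(\mathrm{LN}(h))\). The residual stream itself is never normalised inside the block, so gradients flow along an identity path from the output back to early layers.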
7.3 RMSNorm (Zhang & Sennrich 2019)
Skip the mean:
$$ \mathrm{RMS}(h) = \sqrt{\frac{1}{d}\sum_i h_i^2 + \epsilon},\ \ \mathrm{RMSNorm}(h)_i = \gamma_i \frac{h_i}{\mathrm{RMS}(h)}. $$
Cheaper than LayerNorm (Zhang & Sennrich report running-time reductions of 7% to 64% depending on the model) with negligible quality loss. Default in LLaMA, Mistral, and Qwen.
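For comparison with LayerNorm, a minimal pure-Python RMSNorm sketch on one token vector, assuming \(\epsilon = 10^{-6}\) as an illustrative value; note there is no mean subtraction and no \(\beta\):

```python
import math

def rms_norm(h, gamma, eps=1e-6):
    """RMSNorm of one token vector: gamma_i * h_i / RMS(h), with eps inside the sqrt."""
    rms = math.sqrt(sum(x * x for x in h) / len(h) + eps)
    return [g * x / rms for x, g in zip(h, gamma)]

h = [2.0, 4.0, 6.0, 8.0]
ones = [1.0] * 4
out = rms_norm(h, ones)
print([round(x, 4) for x in out])            # unit RMS, but the mean is NOT removed
print(rms_norm([2 * x for x in h], ones))    # essentially the same: scale-invariant
```

Unlike LayerNorm, the output is not zero-mean, but multiplying \(h\) by a constant leaves it essentially unchanged: RMSNorm keeps the re-scaling invariance and gives up re-centring.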
7.4 Diagram — one pre-norm block
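The pre-norm wiring can also be written out as code. Below is a minimal pure-Python sketch, assuming a plain LayerNorm without learned \(\gamma, \beta\) and hypothetical `attn`/`mlp` callables standing in for the real attention and MLP sublayers; it contrasts the pre-norm block with the original post-norm block:

```python
import math

def layer_norm(h, eps=1e-5):
    """LayerNorm without learned gamma/beta, over one token vector."""
    mu = sum(h) / len(h)
    var = sum((x - mu) ** 2 for x in h) / len(h)
    return [(x - mu) / math.sqrt(var + eps) for x in h]

def add(a, b):
    return [x + y for x, y in zip(a, b)]

def pre_norm_block(h, attn, mlp, norm=layer_norm):
    """Pre-Norm: h + f(Norm(h)) for each sublayer. The residual stream h itself
    is never normalised, so input -> output always has an identity path."""
    h = add(h, attn(norm(h)))
    h = add(h, mlp(norm(h)))
    return h

def post_norm_block(h, attn, mlp, norm=layer_norm):
    """Post-Norm (original Transformer): Norm(h + f(h)) for each sublayer."""
    h = norm(add(h, attn(h)))
    h = norm(add(h, mlp(h)))
    return h

def zero_sublayer(v):
    """Hypothetical stand-in for attention/MLP that outputs zeros."""
    return [0.0] * len(v)

h = [1.0, -2.0, 3.0, 0.5]
print(pre_norm_block(h, zero_sublayer, zero_sublayer))   # [1.0, -2.0, 3.0, 0.5]
print(post_norm_block(h, zero_sublayer, zero_sublayer))  # re-normalised, not h
```

With sublayers that contribute nothing, the pre-norm block is exactly the identity, while the post-norm block still rewrites the stream. In a deep pre-norm stack, that identity path is what lets gradients reach early layers directly.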
8. AdamW intuition
Standard equations (also in Lecture 2):
$$ m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t,\ \ v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2, $$
$$ \hat{m}_t = m_t/(1-\beta_1^t),\ \ \hat{v}_t = v_t/(1-\beta_2^t), $$
$$ \theta_t = \theta_{t-1} - \eta \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \eta \lambda \theta_{t-1}. $$
- Momentum \(m\): EMA of gradients; denoises.
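- Second moment \(v\): EMA of squared gradients (an uncentred variance); dividing by \(\sqrt{\hat{v}_t}\) gives each parameter its own effective step size.
- Decoupled weight decay (the "W"): applied to the parameter directly, not folded into the gradient. With plain Adam + L2, the decay gradient \(\lambda\theta\) would also be divided by \(\sqrt{\hat{v}_t}\), so parameters with a large gradient history would be regularised less; decoupling shrinks every parameter by the same factor \(1 - \eta\lambda\) per step.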
Typical LLM-scale settings (close to those reported for LLaMA): \(\beta_1 = 0.9, \beta_2 = 0.95, \lambda = 0.1\), \(\eta = 3 \times 10^{-4}\) (smaller for the largest models), with linear warmup followed by cosine decay.
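The difference between decoupled weight decay and L2 regularisation shows up already on the first step. The sketch below implements both update rules for a single scalar parameter, assuming the \(\beta_1, \beta_2, \lambda\) values listed above but a deliberately large learning rate \(\eta = 0.1\) so the effect is visible:

```python
import math

def adamw_step(theta, g, m, v, t, lr, beta1=0.9, beta2=0.95, wd=0.1, eps=1e-8):
    """One AdamW step for a scalar parameter (decoupled weight decay)."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps) - lr * wd * theta
    return theta, m, v

def adam_l2_step(theta, g, m, v, t, lr, beta1=0.9, beta2=0.95, wd=0.1, eps=1e-8):
    """One Adam step with L2 regularisation: wd * theta is added to the gradient,
    so it is rescaled by 1 / sqrt(v_hat) like everything else."""
    g = g + wd * theta
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

theta_w, _, _ = adamw_step(1.0, g=0.5, m=0.0, v=0.0, t=1, lr=0.1)
theta_l2, _, _ = adam_l2_step(1.0, g=0.5, m=0.0, v=0.0, t=1, lr=0.1)
print(theta_w, theta_l2)  # approximately 0.89 and 0.90
```

With zero-initialised moments, Adam's first normalised step has size \(\eta\) whatever the gradient is, so under L2 the decay term is absorbed into that step and \(\theta\) ends at about 0.90. AdamW applies the extra shrinkage \(\eta\lambda\theta = 0.01\) on top, giving about 0.89.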
9. Quick Recap — Answer Before You Peek
Five core questions this article answered. Cover the answers, give a one-line response yourself, then check.
Q1. Why is the gradient of softmax + cross-entropy exactly \(\hat{p} - y\) (one-hot)?
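Answer Writing the loss in terms of logits gives \(\mathcal{L} = -z_{i^*} + \log \sum_j e^{z_j}\). The first term contributes \(-y_j\), and the derivative of the log-sum-exp term is exactly \(\mathrm{softmax}(z)_j = \hat{p}_j\). The \(\log\) in CE undoes the \(\exp\) in softmax, so the gradient is \(\hat{p} - y\) with no softmax Jacobian left to multiply.
Why This elegance is why softmax + CE became the classification standard. PyTorch's cross_entropy() takes raw logits and fuses log-softmax with NLL, which keeps the computation numerically stable and exploits the same simplification in the backward pass. (Sections §2·§3.)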
Q2. What is the intuition for KL Divergence being asymmetric?
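Answer \(\mathrm{KL}(P \| Q) \neq \mathrm{KL}(Q \| P)\). \(\mathrm{KL}(P \| Q)\) averages \(\log(P/Q)\) under \(P\), so it is expensive wherever \(P\) is large but \(Q\) is small and ignores regions where \(P\) is near zero. Swapping the arguments swaps whose mass is being checked.
Why Forward KL, \(\mathrm{KL}(P \| Q)\) with \(P\) the data and \(Q\) the model: mass-covering, so \(Q\) spreads to cover every mode of \(P\). Reverse KL, \(\mathrm{KL}(Q \| P)\) with \(Q\) the model and \(P\) the target (the direction used in variational inference): mode-seeking, so \(Q\) may concentrate on a single mode of \(P\). RLHF's penalty \(\mathrm{KL}(\pi_\theta \| \pi_{\text{ref}})\) puts the trained policy first, i.e. it uses the reverse direction. (Section §5.)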
Q3. Connect perplexity to cross-entropy in one line.
Answer Perplexity = \(e^{\mathrm{CE}}\) or \(2^{\mathrm{CE}}\), depending on whether CE is measured in nats or bits. Lower CE → lower perplexity → better next-token prediction.
Why Perplexity's intuition: "the effective number of choices the model is weighing for the next token." PPL = 10 means the model is, on average, as uncertain as if it were choosing uniformly among 10 candidates. It is the standard intrinsic evaluation metric for language models. (Section §4.)
Q4. How do residual connections preserve gradient flow? LayerNorm vs RMSNorm formula difference?
Answer Residual: \(y = F(x) + x\). Backprop gives \(\frac{\partial y}{\partial x} = \frac{\partial F}{\partial x} + I\), so the identity path always carries gradient back, even when \(\partial F / \partial x\) is small. LayerNorm: \(\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta\). RMSNorm: \(\frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}} \cdot \gamma\), with no mean subtraction and no bias.
Why Residual connections are the central insight of ResNet (He 2016) and are standard in every Transformer block. RMSNorm drops mean subtraction to save compute, on the hypothesis (Zhang & Sennrich 2019) that LayerNorm's re-scaling, not its re-centring, is what matters; LLaMA adopted it. (Sections §6·§7.)
Q5. AdamW's weight decay vs L2 regularization — what's the difference?
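Answer L2 adds \(\frac{\lambda}{2}\|w\|^2\) to the loss, so its gradient \(\lambda w\) enters the optimizer together with the data gradient and is rescaled by Adam's adaptive moments. Weight decay applies \(w \leftarrow w - \eta \lambda w\) directly in the optimizer step, decoupled from the adaptive moments.
Why Under Adam, the L2 gradient is divided by \(\sqrt{\hat{v}}\), so parameters with large historical gradients are regularized less (Loshchilov & Hutter 2019). AdamW decouples weight decay so that every parameter shrinks by the same factor \(1 - \eta\lambda\) per step. It is the standard choice for Transformer training. (Section §8.)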
If four or five came out as one-liners, the softmax-CE · KL · perplexity · residual · normalization core of LLM math is in place.
10. Practice
- Implement softmax + CE in NumPy and verify the analytic gradient \(p - y\) for a 5-class example with target class 2.
- Compute \(\mathrm{KL}(p\|q)\) and \(\mathrm{KL}(q\|p)\) for hand-picked \(p, q\) and quantify the asymmetry.
- Train a 30-layer MLP with Pre-Norm vs Post-Norm and compare the layer-wise gradient norms.
11. Further reading
- Goodfellow, Bengio, Courville, 2016, Deep Learning. Chapters 6 and 8.
- He et al., 2015, Delving Deep into Rectifiers. arXiv:1502.01852.
- He et al., 2016, Deep Residual Learning for Image Recognition. arXiv:1512.03385.
- Ba et al., 2016, Layer Normalization. arXiv:1607.06450.
- Zhang & Sennrich, 2019, Root Mean Square Layer Normalization (RMSNorm). arXiv:1910.07467.
- Loshchilov & Hutter, 2019, Decoupled Weight Decay Regularization (AdamW). arXiv:1711.05101.
- Kullback & Leibler, 1951, On Information and Sufficiency. Original KL paper.
- 3Blue1Brown, Neural networks video series. The best visual intuition.
Part 5 of 6 in the LLM Core Study series. Part 6 closes with the learning roadmap: 12-week course, 10 papers, 10 repositories.
Series overview: Series index