Training is a loop that nudges millions of numbers downhill on an error surface — forward, loss, backward, step, repeat — and once you can picture that loop, words like gradient, learning rate, and AdamW stop being magic.
Companion code: ai-eng-wiki/examples/ml/train_loop.py

"Training" is just an optimization loop: you measure how wrong the model is with a single number (the loss), figure out which direction to nudge each of the model's numbers (its parameters) to make that number smaller, take a small step in that direction, and repeat millions of times until the loss stops dropping. The direction is the gradient; stepping against it is gradient descent.
If you've ever written a feedback controller, a hill-climbing search, or a Newton-Raphson root-finder, you already have the mental model: define an objective, compute its slope, step toward the optimum, loop until convergence. A neural network is the same idea with a few billion knobs instead of one, and a clever trick (backpropagation) for computing the slope of all of them in a single pass.
You will rarely write a training loop from scratch on the job — but every concept in this lesson is load-bearing vocabulary for the rest of AI engineering, and interviews assume you own it cold.
Interviewers use this topic as a filter. If you can explain why cross-entropy and not MSE for classification, what the learning rate trades off, and what backprop actually computes, you sound like someone who has trained a model. If you can't, no amount of API experience hides it.
The words first.
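- Loss: one number that scores how wrong the model is; 0 would be perfect.
- Batch: a small group of examples processed together (e.g. 32 images at once).

Step by step.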
Remember this: Training just repeats one cycle — predict, measure how wrong, find which way each weight should move, take a small step — millions of times.
A model is a function with tunable parameters: it takes an input x and produces a prediction ŷ ("y-hat"). To improve it, we first need to score it. A loss function maps a prediction and the true answer y to a single non-negative number: 0 means perfect, larger means worse. Training = making the average loss over your data as small as possible.
Two losses cover almost everything you'll meet.
Regression → Mean Squared Error (MSE). When the target is a continuous number (price, temperature), penalize the squared gap between prediction and truth:

$$\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2$$
Every symbol: n is the count of examples, i picks out one example, ŷ_i is the model's prediction for example i, and y_i is the true value. You subtract truth from prediction, square it (to make every error positive), then average across all examples.
Concrete example: Suppose you have three examples with true values y = [2, 4, 6] and the model predicts ŷ = [2.5, 3.5, 6.2]. The errors are [0.5, -0.5, 0.2]. Squared: [0.25, 0.25, 0.04]. Average: (0.25 + 0.25 + 0.04) / 3 ≈ 0.18. That's the MSE — a single number measuring how far off the predictions were on average. Bigger errors hurt more: being off by 2 costs 4 in the sum, being off by 1 costs only 1.
This loss squashes all wrongness into one scalar that the optimizer can drive toward zero.
Here n is the number of examples, i indexes them, ŷ_i is the model's prediction for example i, and y_i is the true value. Squaring makes every error positive and punishes big misses far more than small ones (being off by 4 costs 16, off by 2 costs 4). It measures distance.
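The arithmetic of the three-example case can be checked in a few lines. This is a minimal sketch in plain Python; the helper name `mse` is chosen here for illustration and is not taken from the companion code.

```python
def mse(predictions, targets):
    """Mean squared error: the average of the squared (prediction - target) gaps."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return sum(squared_errors) / len(squared_errors)


y_true = [2.0, 4.0, 6.0]   # true values from the concrete example
y_pred = [2.5, 3.5, 6.2]   # model predictions from the concrete example

loss = mse(y_pred, y_true)
print(f"MSE = {loss:.2f}")  # prints: MSE = 0.18
```

Squaring happens before averaging, which is why a single large miss dominates the result.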
Classification → Cross-Entropy. When the target is a category (cat/dog, or one of 50,000 possible next tokens), the model outputs a probability for each class, and we want to reward it for putting high probability on the correct class. Cross-entropy for a single example is:

$$\mathrm{CE} = -\log(p_{\text{correct}})$$
Every symbol: p_correct is the probability the model assigned to the true class, and log is the natural logarithm (base e).
Concrete example: The model outputs probabilities for three classes; suppose the true class is class 2, and the model assigned p_2 = 0.8. Then CE = -log(0.8) ≈ -(-0.223) ≈ 0.22 — a small loss because the model was confident and correct. Now suppose the model were very wrong: p_correct = 0.01. Then CE = -log(0.01) ≈ 4.6 — huge penalty. If the model outputs p_correct = 0.5 (uncertain), CE = -log(0.5) ≈ 0.69 — medium loss. Notice: the worse the model is (lower probability on the true answer), the larger the loss. And the gradient of cross-entropy stays large exactly when the model is wrong — perfect for learning.
where p_correct is the probability the model assigned to the true class. If the model is confident and right (p = 0.99), the loss is -log(0.99) ≈ 0.01 — tiny. If it's confident and wrong (p = 0.01 on the true class), the loss is -log(0.01) ≈ 4.6 — huge. It measures surprise: how shocked the model should be by the right answer.
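To see how steeply the penalty grows as the model puts less probability on the right answer, here is a small sketch using only the standard library. The function name `cross_entropy` is an illustrative choice, and it handles a single example only.

```python
import math


def cross_entropy(p_correct):
    """Loss for one example: -log of the probability given to the true class."""
    if not 0.0 < p_correct <= 1.0:
        raise ValueError("p_correct must be in (0, 1]")
    return -math.log(p_correct)  # natural logarithm, as in the text


# Probabilities used in the examples above, from confident-right to confident-wrong.
for p in (0.99, 0.8, 0.5, 0.01):
    print(f"p_correct = {p:<4}  loss = {cross_entropy(p):.2f}")
```

The losses come out near 0.01, 0.22, 0.69 and 4.61: halving the probability always adds the same amount (log 2 ≈ 0.69) to the loss, so the penalty is unbounded as p_correct approaches 0.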
Two terms you'll hear constantly live right here. Logits are the raw, unbounded scores the network outputs per class (e.g. [2.0, 1.0, 0.1]). To turn logits into probabilities that sum to 1, you apply softmax: exponentiate each, then normalize.

$$\mathrm{softmax}(z)_k = \frac{e^{z_k}}{\sum_j e^{z_j}}$$
Every symbol: z_k is the logit (raw model output) for class k, e is Euler's number (~2.718), and the denominator sums the exponentials of all logits to normalize the distribution.
Concrete example: Logits from the model are [2.0, 1.0, 0.1] (three classes). First, exponentiate: e^2.0 ≈ 7.39, e^1.0 ≈ 2.72, e^0.1 ≈ 1.11. Sum them: 7.39 + 2.72 + 1.11 ≈ 11.22. Now divide each by the sum: p_1 = 7.39 / 11.22 ≈ 0.66, p_2 = 2.72 / 11.22 ≈ 0.24, p_3 = 1.11 / 11.22 ≈ 0.10. Notice the probabilities sum to 1.0 and the biggest logit gets the biggest probability, with smooth attenuation. This transformation turns raw scores into a valid probability distribution that cross-entropy can consume.
z_k is the logit for class k; the denominator sums over all classes j so the outputs form a valid probability distribution. Worked micro-example: logits [2.0, 1.0, 0.1] exponentiate to [7.39, 2.72, 1.11], sum 11.2, giving probabilities [0.66, 0.24, 0.10]. If the true class is the first, cross-entropy is -log(0.66) ≈ 0.42. If the true class were the last, it'd be -log(0.10) ≈ 2.3 — the same prediction is "good" or "bad" depending only on which class was correct.
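The softmax and cross-entropy steps chain together naturally in code. The sketch below is one common way to write them in plain Python; subtracting the largest logit before exponentiating is a standard numerical-stability trick that is not mentioned above (it leaves the probabilities unchanged), and the function names are illustrative.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the max avoids overflow; the result is the same
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy_from_logits(logits, true_class):
    """-log of the softmax probability assigned to the true class index."""
    return -math.log(softmax(logits)[true_class])


logits = [2.0, 1.0, 0.1]
probs = softmax(logits)
print([round(p, 2) for p in probs])  # [0.66, 0.24, 0.1]

loss_if_first = cross_entropy_from_logits(logits, 0)  # true class is the first
loss_if_last = cross_entropy_from_logits(logits, 2)   # true class is the last
print(f"{loss_if_first:.2f} {loss_if_last:.2f}")
```

The same logits give a loss of about 0.42 or about 2.32 depending only on which class was correct.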
Why not MSE for classification? MSE on probabilities barely punishes confident wrong answers and produces tiny, flat gradients when the model is very wrong (the place you most need a strong push). Cross-entropy's gradient stays large exactly when the model is badly wrong, so it learns faster and more stably. This pairing — softmax + cross-entropy — is the workhorse loss of every LLM, where "classification" means "predict the next token out of the vocabulary."
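The flat-gradient claim is easy to check numerically. The sketch below assumes the simplest setting, a two-class problem with a sigmoid output and true label 1, and compares the gradient of each loss with respect to the logit z. The formulas in the docstrings follow from the chain rule; the setup and function names are illustrative assumptions, not taken from the lesson's code.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def grad_mse_wrt_logit(z, y):
    """d/dz of (sigmoid(z) - y)^2 = 2 * (p - y) * p * (1 - p)."""
    p = sigmoid(z)
    return 2.0 * (p - y) * p * (1.0 - p)


def grad_ce_wrt_logit(z, y):
    """d/dz of binary cross-entropy on sigmoid(z); it simplifies to p - y."""
    return sigmoid(z) - y


y = 1.0  # the true label
for z in (-6.0, -2.0, 0.0, 2.0):  # from confidently wrong to fairly right
    print(
        f"z = {z:+.1f}  p = {sigmoid(z):.4f}  "
        f"MSE grad = {grad_mse_wrt_logit(z, y):+.4f}  "
        f"CE grad = {grad_ce_wrt_logit(z, y):+.4f}"
    )
```

At z = -6 the model is almost certain of the wrong class. The cross-entropy gradient is close to -1, while the MSE gradient is shrunk by the extra p(1 - p) factor to roughly -0.005, which is the weak push described above.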
Now we have a number (loss) we want to minimize by adjusting parameters. Picture the loss as a landscape: the parameters are your coordinates, and the height is the loss. You're standing somewhere in the fog and want to reach the lowest valley. The strategy: feel the slope under your feet and step downhill.
The gradient, written ∇L ("grad L"), is the vector of partial derivatives of the loss with respect to every parameter. Each component answers: "if I increase this one parameter a hair, does the loss go up or down, and how steeply?" The gradient points in the direction of steepest increase, so to go downhill we step in the negative gradient direction. The update rule for a parameter θ ("theta") is:

$$\theta \leftarrow \theta - \eta \, \frac{\partial L}{\partial \theta}$$
Every symbol: θ is one parameter (one of the millions of weights), ∂L/∂θ is the gradient — how much the loss changes if you nudge this parameter slightly — and η (eta) is the learning rate, controlling step size. The arrow ← means "replace the old value with the new value."
Concrete example: Say a weight θ = 5.0, the gradient ∂L/∂θ = 0.8 (positive, meaning increasing this weight raises the loss), and learning rate η = 0.1. Then θ ← 5.0 - 0.1 * 0.8 = 5.0 - 0.08 = 4.92. The weight moved slightly downhill: since the gradient was positive (uphill), we subtract it to go downhill. If the gradient were negative (e.g. ∂L/∂θ = -2.0), then θ ← 5.0 - 0.1 * (-2.0) = 5.0 + 0.2 = 5.2: we'd increase the weight, because increasing it is the downhill direction. Every parameter gets nudged: small steps, many times, in the direction that shrinks the loss.
θ is a parameter, ∂L/∂θ is the gradient component for it (the slope), and η ("eta") is the learning rate — how big a step you take. The ← means "assign." That's the entire algorithm: subtract a scaled slope from every parameter, repeat.
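Worked micro-example. Fit ŷ = w·x + b to three points (1,2), (2,4), (3,6) (the true rule is y = 2x). Start at w = 0, b = 0, so every prediction is 0 and the errors (ŷ − y) are (−2, −4, −6). The MSE gradients are:

$$\frac{\partial L}{\partial w} = \frac{2}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)\, x_i, \qquad \frac{\partial L}{\partial b} = \frac{2}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)$$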
Plugging in: ∂L/∂w = (2/3)[(−2)(1)+(−4)(2)+(−6)(3)] = (2/3)(−28) ≈ −18.7 and ∂L/∂b = (2/3)(−12) = −8. Both gradients are negative — meaning increasing w and b would lower the loss. With learning rate η = 0.1:
w ← 0 − 0.1·(−18.7) = 1.87 and b ← 0 − 0.1·(−8) = 0.8.
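One step moved w from 0 to 1.87, already close to the true value of 2, though b has overshot to 0.8 and has to drift back toward 0. Keep repeating and both settle: the loss drops quickly, and after a few hundred steps w and b sit at about 2 and 0. That's gradient descent.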
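The hand calculation above can be reproduced and then continued in code. This is a small sketch in plain Python using the same three points and learning rate; the helper name `mse_gradients` and the checkpoint steps are illustrative choices.

```python
def mse_gradients(w, b, xs, ys):
    """Gradients of the mean squared error for the model y_hat = w * x + b."""
    n = len(xs)
    errors = [(w * x + b) - y for x, y in zip(xs, ys)]   # (y_hat - y) per example
    grad_w = (2.0 / n) * sum(e * x for e, x in zip(errors, xs))
    grad_b = (2.0 / n) * sum(errors)
    return grad_w, grad_b


xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]   # the three points; true rule is y = 2x
w, b, lr = 0.0, 0.0, 0.1
first_step = None

for step in range(1, 501):
    grad_w, grad_b = mse_gradients(w, b, xs, ys)
    w -= lr * grad_w   # theta <- theta - eta * gradient
    b -= lr * grad_b
    if step == 1:
        first_step = (w, b)   # should match the hand calculation
    if step in (1, 50, 500):
        print(f"step {step:3d}  w = {w:.4f}  b = {b:.4f}")
```

The first printed line matches the hand-computed update (w ≈ 1.87, b = 0.8); the later checkpoints show how the remaining error in b shrinks more slowly than the error in w.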
The learning rate η controls step size, and it's the single most consequential hyperparameter. Too small and training crawls — thousands of wasted steps inching toward the valley. Too large and you overshoot the bottom on every step, bouncing up the far wall and diverging (loss climbs to infinity or NaN). There's a Goldilocks band, and in practice people don't keep it fixed: they warm up (start tiny so early steps don't blow up) then decay it over training (smaller steps near the minimum for fine settling).
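The three regimes show up even on a one-parameter toy problem. The sketch below runs gradient descent on L(θ) = θ², whose gradient is 2θ and whose minimum is at 0; the loss function, starting point and the three learning rates are made-up values chosen only to make each regime visible.

```python
def descend(lr, theta=5.0, steps=50):
    """Run gradient descent on L(theta) = theta**2, whose gradient is 2 * theta."""
    for _ in range(steps):
        theta -= lr * 2.0 * theta
    return theta


too_small = descend(lr=0.001)   # crawls: each step shrinks theta by only 0.2%
about_right = descend(lr=0.1)   # each step shrinks theta by 20%
too_large = descend(lr=1.1)     # each step flips the sign and grows theta by 20%

print(f"lr=0.001 -> theta = {too_small:.3f}")
print(f"lr=0.1   -> theta = {about_right:.6f}")
print(f"lr=1.1   -> theta = {too_large:.1f}")
```

With these settings the smallest rate has barely moved after 50 steps, the middle one is essentially at the minimum, and the largest has grown by several orders of magnitude. On this particular function any rate above 1.0 diverges, because every update multiplies θ by (1 − 2η).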
The gradient formulas above sum over all n examples. With millions of examples, computing the full-dataset gradient for every single step is absurdly slow. The fix is Stochastic Gradient Descent (SGD): estimate the gradient from a small random minibatch of examples (say 32 or 256) instead of the whole dataset. Each estimate is noisy but unbiased, and you get to take many cheap steps instead of one expensive one.
The noise is a feature, not just a tax. The jitter from random batches helps the optimizer skip past shallow bad spots and tends to find flatter minima that generalize better — a mild, free regularization effect. (Regularization = anything that fights overfitting, i.e. memorizing the training data instead of learning the pattern.)
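A small experiment makes "noisy but unbiased" concrete. The sketch below uses a hypothetical eight-example dataset and a one-parameter model ŷ = w·x (no bias, to keep it short), splits one shuffled pass into minibatches of two, and compares each minibatch gradient with the full-dataset gradient. The dataset, seed and batch size are illustrative assumptions.

```python
import random


def grad_w(w, batch):
    """MSE gradient with respect to w for y_hat = w * x, over a list of (x, y) pairs."""
    return (2.0 / len(batch)) * sum((w * x - y) * x for x, y in batch)


data = [(float(x), 2.0 * x) for x in range(1, 9)]   # true rule is y = 2x
w = 0.0
full_gradient = grad_w(w, data)                      # uses all eight examples

rng = random.Random(0)                               # fixed seed for repeatability
shuffled = data[:]
rng.shuffle(shuffled)

batch_size = 2
batches = [shuffled[i:i + batch_size] for i in range(0, len(shuffled), batch_size)]
batch_gradients = [grad_w(w, batch) for batch in batches]

print("full gradient:      ", full_gradient)
print("minibatch gradients:", batch_gradients)
print("their average:      ", sum(batch_gradients) / len(batch_gradients))
```

Individual minibatch gradients scatter widely around the full gradient, but their average over one pass equals it (here the batches are all the same size). Each one costs a quarter of the work of the full computation.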
This gives you the three timing words that trip up beginners:
| Term | Definition |
|---|---|
| Batch size | How many examples per gradient estimate (e.g. 256). |
| Step (iteration) | One forward+backward+update on one batch. One step = one parameter update. |
| Epoch | One full pass over the entire training set. steps_per_epoch = dataset_size / batch_size. |
So "we trained for 3 epochs with batch size 256 on 1M examples" means 3 × (1,000,000 / 256) ≈ 11,700 parameter updates. LLM pretraining is usually described in steps or tokens, not epochs, because it sees each example roughly once.
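The bookkeeping is a one-liner, with one wrinkle: when the dataset size is not a multiple of the batch size, the last batch of each epoch is smaller, and whether it is kept or dropped changes the count slightly. A sketch follows; the function and the `drop_last` flag are hypothetical names for illustration.

```python
import math


def total_steps(dataset_size, batch_size, epochs, drop_last=False):
    """Number of parameter updates for a run.

    drop_last=False keeps the final, smaller batch of each epoch (round up);
    drop_last=True discards it (round down).
    """
    if drop_last:
        steps_per_epoch = dataset_size // batch_size
    else:
        steps_per_epoch = math.ceil(dataset_size / batch_size)
    return steps_per_epoch * epochs


print(total_steps(1_000_000, 256, 3))                  # 11721
print(total_steps(1_000_000, 256, 3, drop_last=True))  # 11718
```

Both results round to the 11,700 quoted above.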
In the toy example we had two parameters and wrote their gradients by hand. A real network has millions to billions, arranged in layers where each layer's output feeds the next. How do you get the gradient for a weight buried five layers deep? The honest answer is the chain rule from calculus — but the insight of backpropagation is the bookkeeping that makes it cheap.
A network's prediction is a composition of functions: loss(layer_N(...layer_2(layer_1(x)))). The chain rule says the derivative of a composition is the product of the derivatives of each step. Backprop runs in two passes:
1. Forward pass: push the input x through the layers to compute the prediction and the loss, caching each layer's intermediate output.
2. Backward pass: start from the loss and walk the layers in reverse, applying the chain rule at each one and reusing the cached outputs to get the gradient of every parameter along the way.

The key efficiency point, and the thing interviewers want, is that computing each parameter's gradient separately would re-derive the same shared sub-expressions over and over, once per parameter. Backprop computes the gradient of all parameters in one backward pass, at a cost within a small constant factor of one forward pass, by computing each intermediate result once and reusing it (this is dynamic programming on the computation graph). That O(network size) cost, rather than O(network size × parameters), is why deep learning is computationally feasible at all. Modern frameworks call this autograd: they record the graph of operations during the forward pass and replay it backward automatically, so you never write a derivative by hand.
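The two passes can be sketched on the smallest network that still has a hidden layer: one input, one tanh hidden unit, one linear output, squared-error loss. Everything here (the architecture, the parameter values, the function names) is a made-up illustration, not the companion code. The finite-difference check at the end is the slow per-parameter approach that backprop avoids.

```python
import math


def forward(params, x, y):
    """Forward pass: compute the loss and cache the intermediates backprop needs."""
    w1, b1, w2, b2 = params
    h = math.tanh(w1 * x + b1)      # hidden layer output
    y_hat = w2 * h + b2             # prediction
    loss = (y_hat - y) ** 2
    return loss, (x, h, y_hat)


def backward(params, cache, y):
    """Backward pass: one sweep from the loss back to the input via the chain rule."""
    _, _, w2, _ = params
    x, h, y_hat = cache
    d_yhat = 2.0 * (y_hat - y)      # dL/dy_hat
    d_w2 = d_yhat * h               # output layer gradients reuse d_yhat
    d_b2 = d_yhat
    d_h = d_yhat * w2               # pass the gradient down to the hidden output
    d_z = d_h * (1.0 - h * h)       # derivative of tanh is 1 - tanh^2
    d_w1 = d_z * x                  # hidden layer gradients reuse d_z
    d_b1 = d_z
    return [d_w1, d_b1, d_w2, d_b2]


def numerical_gradient(params, x, y, eps=1e-6):
    """Slow check: two extra forward passes per parameter (central differences)."""
    grads = []
    for i in range(len(params)):
        up, down = list(params), list(params)
        up[i] += eps
        down[i] -= eps
        grads.append((forward(up, x, y)[0] - forward(down, x, y)[0]) / (2 * eps))
    return grads


params = [0.5, -0.3, 1.2, 0.1]      # w1, b1, w2, b2 (arbitrary values)
x, y = 1.5, 2.0
loss, cache = forward(params, x, y)
analytic = backward(params, cache, y)
numeric = numerical_gradient(params, x, y)
for name, a, n in zip(("w1", "b1", "w2", "b2"), analytic, numeric):
    print(f"{name}: backprop = {a:+.6f}  finite difference = {n:+.6f}")
```

Note how `d_yhat` and `d_z` are each computed once and then reused by every gradient that depends on them. The numerical check needs two forward passes for each of the four parameters, while backprop needed one forward and one backward pass in total.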
Plain SGD takes a step proportional to the raw gradient. It works, but it's twitchy: it zig-zags across narrow valleys and stalls on flat plateaus. Optimizers are smarter recipes for turning gradients into updates.
| Optimizer | Idea (intuition) | When |
|---|---|---|
| SGD | Step directly down the (minibatch) gradient. | Simple, well-tuned vision models; strong baseline. |
| SGD + Momentum | Accumulate a running average of past gradients — a "velocity" — so consistent directions build speed and noise cancels out. Like a heavy ball rolling downhill. | Most CNN training. |
| Adam | Per-parameter adaptive steps: divide each parameter's step by a running estimate of its own gradient magnitude, so rarely-updated and wildly-scaled parameters all move sensibly. Momentum + auto-scaling in one. | The default for transformers/NLP. |
| AdamW | Adam with weight decay decoupled from the gradient (a cleaner way to pull weights toward zero for regularization). | The de-facto standard for training and fine-tuning LLMs. |
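
The one-line intuitions in the table can be written out as update rules for a single scalar parameter. The sketch below follows the commonly published forms of momentum and Adam, with the decoupled weight-decay term of AdamW exposed as an optional argument. The function names and default hyperparameters are illustrative assumptions rather than any framework's exact implementation.

```python
import math


def sgd_step(theta, grad, lr):
    """Plain SGD: step directly down the gradient."""
    return theta - lr * grad


def momentum_step(theta, grad, velocity, lr, mu=0.9):
    """SGD + momentum: the velocity accumulates past gradients."""
    velocity = mu * velocity + grad
    return theta - lr * velocity, velocity


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
              eps=1e-8, weight_decay=0.0):
    """Adam (weight_decay=0) or AdamW (weight_decay>0) for step number t >= 1."""
    m = beta1 * m + (1 - beta1) * grad          # running average of gradients
    v = beta2 * v + (1 - beta2) * grad * grad   # running average of squared gradients
    m_hat = m / (1 - beta1 ** t)                # bias correction for early steps
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * weight_decay * theta   # decoupled decay: not part of grad
    theta = theta - lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# Adam's auto-scaling: wildly different gradient sizes give nearly the same step.
for grad in (100.0, 0.001):
    theta, _, _ = adam_step(1.0, grad, m=0.0, v=0.0, t=1)
    print(f"grad = {grad:<7}  theta after one Adam step = {theta:.6f}")
```

Both runs move θ by almost exactly `lr`, because the update divides the gradient average by the square root of the squared-gradient average. Plain SGD with the same learning rate would have moved 100,000 times further in the first case than in the second.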
You don't need the update equations memorized for IC3. You need the story: momentum smooths the path using gradient history; Adam additionally gives each parameter its own adaptive step size; AdamW is the LLM default. When someone says "we used AdamW with a cosine schedule and linear warmup," you now know that means: the adaptive-momentum optimizer, learning rate ramped up then smoothly decayed.
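"Linear warmup then cosine decay" is itself only a few lines. The sketch below is one common way to write such a schedule as a function of the step number; the exact shape varies between codebases, and the function name, the peak rate of 3e-4 and the step counts are illustrative values, not recommendations.

```python
import math


def lr_at(step, total_steps, peak_lr, warmup_steps, min_lr=0.0):
    """Learning rate for a given step: linear ramp up, then cosine decay."""
    if step < warmup_steps:
        # Ramp from peak_lr / warmup_steps up to peak_lr.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return min_lr + (peak_lr - min_lr) * cosine


for step in (0, 50, 99, 100, 550, 1000):
    print(f"step {step:4d}  lr = {lr_at(step, 1000, 3e-4, 100):.2e}")
```

The rate starts tiny, reaches its peak at the end of warmup, is at half the peak midway through the decay phase, and lands on `min_lr` at the final step.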
Every framework, every fine-tune, every LLM pretrain is this five-line rhythm repeated until the loss flattens:
```
forward   →  loss    →  backward    →  step     →  repeat
(predict)    (score)    (get grads)    (update)
```

That's it. Everything else (architecture, data, schedules) is detail bolted onto this skeleton.
First, the whole loop in pure NumPy so nothing is hidden — linear regression by gradient descent, exactly the math from §3.2:
```python
import numpy as np

# tiny dataset; the true rule is y = 2x (the model has to discover it)
X = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([2.0, 4.0, 6.0, 8.0])

w, b = 0.0, 0.0   # parameters we will learn (start at zero)
lr = 0.01         # learning rate η
n = len(X)

for step in range(1000):
    y_hat = w * X + b                      # 1. FORWARD: predict
    error = y_hat - y                      #    residual (ŷ − y)
    loss = np.mean(error ** 2)             # 2. LOSS: mean squared error
    grad_w = (2 / n) * np.sum(error * X)   # 3. BACKWARD: ∂L/∂w
    grad_b = (2 / n) * np.sum(error)       #    ∂L/∂b
    w -= lr * grad_w                       # 4. STEP: descend the gradient
    b -= lr * grad_b
    if step % 200 == 0:
        print(f"step {step:4d} loss {loss:7.4f} w {w:.3f} b {b:.3f}")
```

Line by line: the forward pass computes predictions; error and loss score them; grad_w/grad_b are the hand-derived gradients (this is what backprop automates for big networks); the two -= lines are the step. Run it and the loss falls toward 0 while w → 2 and b → 0. Bump lr to 1.0 and watch it diverge to NaN: that's an oversized learning rate, live.
Setup: Create tiny X and y (true rule is y = 2x), initialize parameters w and b to zero (guessing is allowed), set learning rate to 0.01, and note n = 4 examples.

Loop body:

1. Forward: multiply X by w, add bias b, get predictions y_hat.
2. Loss: subtract y to get the error, then average the squared errors into one number.
3. Backward: ∂L/∂w sums error * X (how the weight should move); ∂L/∂b sums the errors (how the bias should move). These are exactly the chain rule applied to MSE by hand.
4. Step: subtract lr times each gradient, so w and b move downhill.

The whole loop is: measure wrongness, compute slopes, nudge the numbers, repeat. Thousands of tiny steps land you at the true rule.
Now the same loop in PyTorch, where autograd computes the gradients for you:
```python
import torch
import torch.nn as nn

X = torch.tensor([[1.], [2.], [3.], [4.]])
y = torch.tensor([[2.], [4.], [6.], [8.]])

model = nn.Linear(1, 1)                               # ŷ = wx + b
opt = torch.optim.SGD(model.parameters(), lr=0.01)    # optimizer holds the params
loss_fn = nn.MSELoss()

for step in range(1000):
    opt.zero_grad()               # clear last step's gradients (they accumulate!)
    loss = loss_fn(model(X), y)   # FORWARD + LOSS in one line
    loss.backward()               # BACKWARD: autograd fills every param's .grad
    opt.step()                    # STEP: optimizer updates params using .grad
```

The structure is identical (zero_grad → forward → loss → backward → step), but loss.backward() walks the computation graph and computes ∂loss/∂θ for every parameter automatically, and opt.step() applies the update rule. Swap SGD for torch.optim.AdamW and you've upgraded optimizers in one token. One subtlety worth internalizing: PyTorch accumulates gradients across backward() calls, so you must zero_grad() each step or your gradients pile up and training breaks, a classic first-day bug. The full runnable version, with both implementations and a divergence demo, lives in examples/ml/train_loop.py.
Setup: Create X and y as tensors, build a linear model (one weight and one bias), pick the SGD optimizer (which holds references to the model's parameters), and choose MSE loss.

Loop body:

- opt.zero_grad(): erase all accumulated gradients from the previous step. PyTorch adds new gradients to whatever was there; forgetting this causes them to pile up and destroy training.
- loss = loss_fn(model(X), y): run the forward pass and score the predictions in one line.
- loss.backward(): the magic. Autograd walks the computation graph backward from the loss and fills every parameter's .grad attribute with ∂loss/∂θ. This is backpropagation, automated. The gradients we derived by hand for NumPy are computed for you, and the same call works for millions of parameters.
- opt.step(): the optimizer reads .grad for each parameter and applies the update rule (for SGD: θ ← θ - η * grad).

The dance is identical to NumPy, but PyTorch auto-computes the slopes and applies the update, letting you swap optimizers (SGD → AdamW) in one word.
Training loss keeps falling; validation loss falls then rises as the model memorizes. The gap is overfitting — you stop at the validation minimum (early stopping).
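Early stopping is usually implemented with a patience counter: remember the best validation loss seen so far and stop once it has failed to improve for a set number of epochs. A minimal sketch follows, with a made-up list of validation losses standing in for a real training run; the function name and the patience value are illustrative.

```python
def early_stopping_epoch(val_losses, patience=2):
    """Return (best_epoch, best_loss), stopping once validation loss has not
    improved for `patience` consecutive epochs."""
    best_epoch, best_loss, bad_epochs = 0, float("inf"), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_epoch, best_loss, bad_epochs = epoch, loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break   # stop training; later epochs are never run
    return best_epoch, best_loss


# Made-up validation curve: falls, bottoms out at epoch 3, then rises.
val_losses = [0.90, 0.62, 0.48, 0.41, 0.43, 0.47, 0.55]
best_epoch, best_loss = early_stopping_epoch(val_losses)
print(f"stop and keep the checkpoint from epoch {best_epoch} (val loss {best_loss})")
```

In a real run you would save the model's weights whenever the best loss improves and restore that checkpoint after stopping. A larger patience tolerates noisier validation curves at the cost of extra epochs.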
- A learning rate that is too large does not just train faster: it overshoots the minimum and the loss diverges to NaN. Faster only until it's catastrophic.
- Cross-entropy is −log of the probability assigned to the true class. It keeps a strong gradient exactly when the model is confidently wrong, where MSE's gradient goes flat, so classifiers (including LLMs predicting the next token) learn faster and more stably.
- Gradient descent updates every parameter with θ ← θ − η·∂L/∂θ. The learning rate η is step size. Too small → painfully slow convergence; too large → overshoot the minimum and diverge (loss → NaN). In practice people warm up then decay it.
- A step is one parameter update on one minibatch; an epoch is one full pass over the training set (dataset_size / batch_size steps). Minibatches make each step cheap (so you take many) and inject useful gradient noise that regularizes and helps escape poor minima.
- Backprop computes ∂loss/∂θ for every parameter in a single backward pass by applying the chain rule with cached intermediates: O(network size), not O(size × params). SGD steps down the raw minibatch gradient; momentum adds a velocity (running average of gradients) to smooth and accelerate; Adam additionally scales each parameter's step by a running estimate of its own gradient magnitude (adaptive per-parameter rates); AdamW decouples weight decay and is the LLM default.

You're ready to move on when you can narrate the forward → loss → backward → step loop unprompted, say why cross-entropy beats MSE for classification, and explain what the learning rate trades off, without reaching for notes.
Next: Neural Networks. Stack these trainable layers into a network and see what the forward pass and backprop actually run over, then on to Fine-tuning, which is this loop applied to a pretrained model.