The half of RL hiring that people forget: how many model copies GRPO actually holds, why your GPUs idle on long-tail rollouts, when async beats sync (and what staleness costs you), VeRL vs TRL vs slime vs AReaL, Megatron vs FSDP, why RL training is nondeterministic, and the MoE train-inference router mismatch that silently collapses runs.
RLHF/RLVR is not a single model being trained; it is a distributed dataflow with two engines that disagree. A rollout engine (vLLM/SGLang) generates completions fast, and a training engine (FSDP/Megatron) computes gradients correctly. The algorithm (PPO, GRPO, GSPO; see /finetuning/ppo-grpo-and-variants) is maybe 200 lines. The other 95% of the difficulty is infrastructure: how many model copies fit, who sits idle while someone else generates, how stale the data is allowed to be, and whether the logprobs your trainer computes even match the logprobs your sampler used. Get the infra wrong and a mathematically correct loss diverges anyway.
This lesson makes you fluent in exactly the questions a "research" interview still asks under the hood.
The words first.
Step by step.
In short: roughly three models plus optimizer state sit in GPU memory at once, and generation dominates the clock because writing N tokens takes N sequential forward passes, while the training step runs one forward and one backward pass over the whole, already-written sequence in parallel.
Remember this: RL holds several models in memory because each gives a different signal, and the slow part is writing answers one token at a time, not learning from them.
The single most common whiteboard question. For PPO you hold four logical models: policy (trainable), reference (frozen, for KL), critic/value (trainable), and reward model (frozen, if learned). GRPO's core infra win is deleting the critic — it estimates the baseline from the group instead of a learned value head (the advantage is A_i = (r_i - mean(r)) / std(r) over G samples). That removes one trainable model and its entire training loop, cutting memory/compute roughly 30–50% versus PPO.
So in a synchronous GRPO step you typically have 2–3 copies: policy + reference, plus a reward model only if reward is learned (RLVR with a verifier function needs none — see /finetuning/rl-post-training).
The subtlety interviewers want: not all copies cost the same. Only the policy carries gradients and optimizer state. With Adam in mixed precision the policy's footprint is roughly: 2 bytes params + 2 bytes grad + (4+4+4) bytes of fp32 master/momentum/variance ≈ 16 bytes/param. Reference and reward models are inference-only (2 bytes/param, no grad, no optimizer). That's why "two extra models" is cheaper than it sounds, and why the first thing you cut under pressure is optimizer/activation memory on the policy, not the frozen copies.
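A back-of-the-envelope sketch of the weight-side accounting above. It assumes a 7B policy, a critic the same size as the policy (an assumption; critics are often smaller), and the 16 vs 2 bytes/param figures from this section. It deliberately ignores activations, the rollout engine's KV cache, and allocator overhead, which are often what actually runs you out of memory.

```python
# Bytes per parameter, from the accounting above (mixed precision + Adam).
TRAINABLE_BYTES = 2 + 2 + 4 + 4 + 4  # bf16 params + bf16 grads + fp32 master, momentum, variance
FROZEN_BYTES = 2                     # bf16 weights only: no grads, no optimizer state


def weight_state_gb(n_params: float, n_trainable: int, n_frozen: int) -> float:
    """GB for weights, grads and optimizer state of same-sized models.

    Excludes activations, KV cache and fragmentation. FSDP/ZeRO sharding
    spreads this total across GPUs; it does not shrink it.
    """
    total_bytes = n_params * (n_trainable * TRAINABLE_BYTES + n_frozen * FROZEN_BYTES)
    return total_bytes / 1e9


# GRPO with an RLVR verifier: trainable policy, frozen reference.
grpo_gb = weight_state_gb(7e9, n_trainable=1, n_frozen=1)
# PPO with a learned reward: trainable policy + critic, frozen reference + reward model.
ppo_gb = weight_state_gb(7e9, n_trainable=2, n_frozen=2)
print(f"GRPO ~{grpo_gb:.0f} GB, PPO ~{ppo_gb:.0f} GB")
```

Even in the PPO case the two frozen copies add only 28 GB, against 112 GB for the policy alone.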
Levers, cheapest-first:
- **Prefix sharing:** the G completions per prompt share a long common prefix; encode the prefix once and reuse it, cutting prefix FLOPs/memory to ~1/G of the unshared cost and letting you raise G.

Generation, not the gradient step, is usually the wall-clock bottleneck in RL. Both vLLM and SGLang use continuous batching: new requests join the running batch at the iteration (token) level, and finished ones leave, instead of the whole batch waiting on a fixed window. That removes head-of-line blocking.
For RL specifically, your G rollouts per prompt share the prompt prefix by construction — RadixAttention is a natural fit. See /inference for the engine internals.
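A toy model of the continuous batching described above, not vLLM or SGLang code. It assumes each request's output length is known up front, every active request emits exactly one token per decode iteration, and an iteration costs the same no matter how many slots are busy. It counts decode iterations for a static batcher versus an iteration-level one.

```python
from collections import deque


def static_batching_steps(lengths: list[int], batch_size: int) -> int:
    """Decode iterations when each fixed batch runs until its longest request ends."""
    return sum(max(lengths[i:i + batch_size]) for i in range(0, len(lengths), batch_size))


def continuous_batching_steps(lengths: list[int], batch_size: int) -> int:
    """Decode iterations when a finished request's slot is refilled at the next iteration."""
    queue = deque(lengths)
    running: list[int] = []  # tokens still to generate for each active request
    steps = 0
    while queue or running:
        while queue and len(running) < batch_size:  # admit new work at token granularity
            running.append(queue.popleft())
        steps += 1  # one decode iteration: every active request emits one token
        running = [r - 1 for r in running if r > 1]  # requests on their last token leave
    return steps


lengths = [8, 1, 1, 1, 1, 1, 1, 1]  # one long request followed by short ones
print(static_batching_steps(lengths, 2), continuous_batching_steps(lengths, 2))
```

Continuous batching lets the short requests stream through the free slot beside the long one, but neither count can drop below the longest request's 8 iterations. That floor is exactly the long-tail problem synchronous RL runs into.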
Synchronous RL's dirty secret: the batch finishes when the slowest generation finishes. Sample 32 prompts averaging ~100 tokens but with one 800-token outlier, and 31 GPUs idle while one decodes. With reasoning models (long, variable CoT), the tail is brutal and throughput craters.
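Putting a number on that long tail, under the simplifying assumptions that each sequence holds one decode slot until the whole batch finishes and every decode step costs the same. The short answers are set to 75 tokens, a made-up value chosen so the mean lands near the ~100 in the example.

```python
def sync_slot_utilization(lengths: list[int]) -> float:
    """Fraction of decode slot-steps doing useful work when a synchronous
    batch (one slot per sequence) runs until its longest sequence finishes."""
    return sum(lengths) / (len(lengths) * max(lengths))


# 31 short answers plus one 800-token outlier.
lengths = [75] * 31 + [800]
mean_len = sum(lengths) / len(lengths)
util = sync_slot_utilization(lengths)
print(f"mean {mean_len:.1f} tokens, slot utilization {util:.1%}")
```

Here roughly 88% of slot-steps are spent waiting on a single sequence; every fix that follows attacks that ratio.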
Fixes, in order of how invasive they are:
In asynchronous RL, rollout workers continuously generate into a buffer while training workers consume scored trajectories — generation is decoupled from optimization. Reported 1.53×–2.24× end-to-end speedups; AReaL reports up to 2.57×.
The price is staleness: trajectories were sampled by an older policy π_old, but you're updating π_θ. That makes the data off-policy, so you must importance-weight (and clip/truncate) — exactly the ratio w = π_θ/π_old that PPO/GRPO already use, now doing real off-policy work instead of near-identity correction.
- … a proxy policy interpolated between the behavior and target policies, π_proxy ≈ α·π_behavior + (1−α)·π_target, for ~1.8× with comparable quality.

Staff-level framing: async trades statistical efficiency (per sample) for hardware efficiency (per second). Past a staleness threshold, IS ratios become fat-tailed, variance explodes, and you give back the speedup as instability. The job is tuning the buffer so the policy never drifts faster than your weight-sync cadence.
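A minimal sketch of the buffer knob: each trajectory is tagged with the policy version that generated it, and the trainer refuses anything more than `max_staleness` versions old. `Trajectory` and `StalenessBoundedBuffer` are hypothetical names, not the AReaL or slime API. Real systems also throttle rollout workers so they cannot run too far ahead, rather than only discarding work.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class Trajectory:
    tokens: list[int]
    reward: float
    policy_version: int  # version of the weights that generated this sample


@dataclass
class StalenessBoundedBuffer:
    max_staleness: int  # max allowed (current_version - policy_version)
    items: deque = field(default_factory=deque)

    def put(self, traj: Trajectory) -> None:
        self.items.append(traj)  # rollout workers push continuously

    def take(self, batch_size: int, current_version: int) -> list[Trajectory]:
        """Pop up to batch_size trajectories, oldest first, discarding any that
        are more than max_staleness versions behind the trainer."""
        batch = []
        while self.items and len(batch) < batch_size:
            traj = self.items.popleft()
            if current_version - traj.policy_version <= self.max_staleness:
                batch.append(traj)
        return batch


buf = StalenessBoundedBuffer(max_staleness=1)
for version in range(4):  # samples produced under policy versions 0..3
    buf.put(Trajectory(tokens=[version], reward=1.0, policy_version=version))
batch = buf.take(batch_size=8, current_version=3)
print([t.policy_version for t in batch])
```

Dropping is the bluntest policy; the importance-sampling correction is what lets you keep mildly stale samples instead of throwing them away.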
For MoE (60%+ of 2025 open releases — DeepSeek-R1, Kimi K2, Mistral Large 3), add Expert Parallelism: shard experts across GPUs and route tokens to them, which requires all-to-all communication every layer. That all-to-all is the bottleneck and demands NVLink/InfiniBand; wide-EP reports ~1.8–1.9× per-GPU throughput. At long context, hide comms behind compute with sequence-level overlap (ISO, ~35% prefill reduction on 4090) or ring attention. See /transformers for MoE/attention internals.
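A sketch of what the expert-parallel all-to-all actually moves for one MoE layer, assuming top-k routing and experts split evenly and contiguously across EP ranks (real placements can differ). Each rank computes a row like this for its own tokens; stacking the rows gives the n_ranks × n_ranks send matrix. Copies bound for the sender's own rank stay local; the rest cross the interconnect, and the expert outputs come back the same way (dispatch, then combine).

```python
import torch


def ep_send_counts(router_logits: torch.Tensor, top_k: int, n_ranks: int) -> torch.Tensor:
    """Token copies one EP rank must send to each rank for one MoE layer.

    router_logits: [tokens, n_experts] for the tokens living on this rank.
    Assumes rank r owns experts [r * E/n_ranks, (r + 1) * E/n_ranks).
    """
    n_experts = router_logits.shape[-1]
    assert n_experts % n_ranks == 0, "sketch assumes an even expert split"
    experts_per_rank = n_experts // n_ranks
    expert_ids = router_logits.topk(top_k, dim=-1).indices  # [tokens, top_k]
    dest_rank = expert_ids // experts_per_rank               # owning rank of each chosen expert
    return torch.bincount(dest_rank.flatten(), minlength=n_ranks)


# 3 tokens, 4 experts split across 2 ranks, top-2 routing.
logits = torch.tensor([[4.0, 3.0, 0.0, 0.0],
                       [0.0, 0.0, 5.0, 1.0],
                       [3.0, 2.0, 0.0, 0.0]])
counts = ep_send_counts(logits, top_k=2, n_ranks=2)
print(counts.tolist())
```

Skewed routing makes some entries of the send matrix much larger than others, which is why load balancing and a fast interconnect both matter.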
RL training is nondeterministic by default, and people wrongly blame the GPU. The real cause: LLM kernels are batch-size dependent. Different batch sizes change reduction order in RMSNorm/matmul/attention, and floating-point addition is non-associative, so the same prompt yields different logits at batch size 1 vs 32. In RL this is poison: your rollout engine samples at one batch size, your trainer recomputes logprobs at another, and the "off-policy" ratio you measure is partly a kernel artifact, not real policy drift.
The Thinking Machines / SGLang fix reframes this as an engineering bug: use batch-invariant kernels that fix the reduction order independent of batch size (deterministic RMSNorm, fixed matmul block size, attention with fixed split sizes / num-splits=1, seeded sampling). Cost: roughly a 34% slowdown, even with CUDA graphs kept on (they give ~2.8× over running the deterministic kernels without graphs). Worth it when you need reproducible trajectories and trustworthy eval comparisons (see /evals).
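A CPU-side illustration (NumPy, not a GPU kernel) of why a fixed reduction order buys batch invariance. The `variant_block` heuristic is invented for the demo; it stands in for kernels that split a row's reduction differently at small batch and is not SGLang's or any library's actual logic. RMSNorm's learned weight is omitted.

```python
import numpy as np


def rms_blocked(row: np.ndarray, block: int, eps: float = 1e-6) -> np.float32:
    """RMS of a float32 row, reducing the sum of squares block by block
    (per-block partial sums first, then the partials), like a split reduction."""
    partials = [np.sum(row[i:i + block] ** 2, dtype=np.float32)
                for i in range(0, row.size, block)]
    total = np.float32(0.0)
    for p in partials:  # combine partials in a fixed left-to-right order
        total = np.float32(total + p)
    return np.sqrt(total / np.float32(row.size) + np.float32(eps))


def rmsnorm(x: np.ndarray, block_for_batch) -> np.ndarray:
    """Row-wise RMSNorm; block_for_batch maps the batch size to a reduction block size."""
    block = block_for_batch(x.shape[0])
    return np.stack([row / rms_blocked(row, block) for row in x])


def variant_block(batch_size: int) -> int:
    # Hypothetical heuristic: split rows more finely at small batch (batch-VARIANT).
    return 256 if batch_size >= 32 else 64


def invariant_block(batch_size: int) -> int:
    # Batch-invariant choice: the same block size whatever the batch.
    return 256


rng = np.random.default_rng(0)
x = rng.standard_normal((32, 4096)).astype(np.float32)

inv_single, inv_batched = rmsnorm(x[:1], invariant_block)[0], rmsnorm(x, invariant_block)[0]
var_single, var_batched = rmsnorm(x[:1], variant_block)[0], rmsnorm(x, variant_block)[0]
print("invariant bitwise equal:", np.array_equal(inv_single, inv_batched))
print("variant max abs diff:", float(np.max(np.abs(var_single - var_batched))))
print("float add associative here?", (0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```

The invariant check prints True because row 0 sees the identical sequence of additions at batch 1 and batch 32; the variant difference is usually tiny but nonzero, which is exactly the kind of gap that shows up in the IS ratio.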
Related footgun: the gradient-scaling bug. Raw torch.distributed.all_reduce has no backward wrapper, so under sequence parallelism gradients come out off by exactly the SP size. Fix: torch.distributed.nn.all_reduce, or scale by 1/num_processes explicitly.
In the GRPO snippet below, step 4 has the trainer recompute logprobs on exactly the tokens the sampler produced. You'd expect identical numbers, but they differ even at step 0. Why: LLM kernels pick their reduction strategy based on batch size. RMSNorm, for example, reduces a sum of squares across the hidden dimension; with a large batch each row can get its own block, while at batch size 1 a kernel may split each row across several blocks to keep the GPU busy, so the same additions run in a different order (matmuls do the same with split-K). Floating-point addition is non-associative, (a + b) + c ≠ a + (b + c), so different orderings give different sums, and a single token's logprob can end up ~0.02 nats off. On MoE this compounds: a tiny logit shift can flip a router's top-k choice, which is how ~94% of tokens end up with a different expert in at least one layer. The fix is batch-invariant kernels: fixed block sizes and reduction order plus deterministic seeding, so the same input gives the same output regardless of batch size. Cost: roughly 34% slower. Skip it if you're prototyping; do it if you're debugging why eval-to-eval parity breaks.
This is where dense intuition breaks. In an MoE, routing is non-deterministic across engines: even with identical router weights, vLLM (rollout) and FSDP/Megatron (training) disagree on ~10% of routers per forward pass, and ~94% of tokens differ in expert assignment in at least one MoE layer. After a policy update the activated expert set shifts further. Consequence: the importance ratio π_θ/π_old spikes on routing-misaligned tokens, PPO/GRPO clipping fires chaotically, gradients destabilize, and the run collapses — often with nothing else changed from a working dense recipe. On Qwen3-30B-A3B, baseline GRPO collapsed in 3/3 runs.
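A diagnostic sketch for measuring the mismatch yourself, assuming you can capture router logits for the same tokens from both engines (how to hook them out is engine-specific and not shown). It compares top-k expert *sets* per layer and reports both the per-layer disagreement rate and the fraction of tokens that disagree in at least one layer. The toy tensors show how modest per-layer disagreement compounds across depth.

```python
import torch


def routing_mismatch(logits_rollout: torch.Tensor, logits_train: torch.Tensor, top_k: int):
    """logits_*: [layers, tokens, n_experts] router logits from the two engines.

    Returns (per-layer fraction of tokens whose top-k expert set differs,
             fraction of tokens that differ in at least one layer).
    """
    def topk_sets(logits: torch.Tensor) -> torch.Tensor:
        idx = logits.topk(top_k, dim=-1).indices
        return idx.sort(dim=-1).values  # sorted so set equality == elementwise equality

    differs = (topk_sets(logits_rollout) != topk_sets(logits_train)).any(dim=-1)  # [layers, tokens]
    per_layer = differs.float().mean(dim=1)
    any_layer = differs.any(dim=0).float().mean()
    return per_layer, any_layer


# 2 layers x 2 tokens x 3 experts, top-1: each layer disagrees on a different token.
rollout = torch.tensor([[[2.0, 1.0, 0.0], [0.0, 2.0, 1.0]],
                        [[0.0, 0.0, 3.0], [1.0, 0.0, 0.0]]])
train = torch.tensor([[[1.0, 2.0, 0.0], [0.0, 2.0, 1.0]],
                      [[0.0, 0.0, 3.0], [0.0, 1.0, 0.0]]])
per_layer, any_layer = routing_mismatch(rollout, train, top_k=1)
print(per_layer.tolist(), any_layer.item())
```

Each layer disagrees on half the tokens, yet every token is misrouted somewhere; stack dozens of layers and a small per-layer rate turns into most tokens.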
Fixes, two flavors:
- **R3 (rollout routing replay):** replay the rollout engine's routing masks into the trainer so both engines activate the same experts; reported to halve routing KL and cut ratio spikes ~10×.
- **GSPO (sequence-level importance sampling):** IS_seq = exp((log π_θ(y) − log π_old(y)) / |y|), the length-normalized sequence likelihood ratio, with sequence-level clipping. If experts disagree in layer k but re-align by k+2, the sequence-level ratio averages the fluctuation out: no replay needed, and it works directly with rollout-engine likelihoods.

Token-level TIS (masking tokens whose w falls outside [τ_low, τ_high]) is the cheap, coarse cousin: it masks the blow-ups but discards signal. Qwen adopted GSPO in production; it beats GRPO on AIME/code at 30B despite clipping more tokens, which is itself the tell that GRPO's token-level gradients are noisy. Separately, ByteDance Seed's VAPO reintroduced a value critic and reports 60.4 on AIME 2024 with zero training crashes and ~40% fewer steps than DAPO, a sign value-based methods are maturing again at scale.

A stripped two-engine GRPO loop. Real frameworks (VeRL/slime) add sharded weight sync, distributed buffers, and reward services; this shows the infra-critical parts: prefix-shared rollout, recomputing trainer logprobs, and the IS correction that makes off-policy/MoE survivable.
```python
import torch


def grpo_infra_step(prompts, policy, ref, sampler, verify, G=8, eps=0.2,
                    kl_coef=0.04, is_tau=(0.2, 5.0)):
    # Schematic interfaces, not a real vLLM/SGLang/FSDP API: sampler.generate,
    # policy.logprobs, ref.logprobs, and verify(prompt, text) -> 0 or 1.
    # Assumes completions come back grouped: G consecutive rows per prompt.

    # 1) ROLLOUT (separate engine, e.g. vLLM/SGLang). Shares the prompt prefix
    #    across the G samples (Prefix Grouper). Returns the sampler's OWN logprobs.
    rollouts = sampler.generate(prompts, n=G, return_logprobs=True)  # off-policy source
    toks, mask, logp_old = rollouts.tokens, rollouts.mask, rollouts.logprobs  # [B*G, T]

    # 2) REWARD: verifiable (RLVR) -> no reward model in memory. Binary {0,1}.
    #    Repeat each prompt G times so it lines up with its G completions.
    expanded = [p for p in prompts for _ in range(G)]
    r = torch.tensor([verify(p, t) for p, t in zip(expanded, rollouts.text)],
                     dtype=torch.float32)

    # 3) GROUP-RELATIVE ADVANTAGE (no critic). The epsilon keeps all-right/all-wrong
    #    groups (std = 0) at advantage 0 instead of 0/0 = NaN.
    r = r.view(-1, G)
    adv = (r - r.mean(1, keepdim=True)) / (r.std(1, keepdim=True) + 1e-6)
    adv = adv.reshape(-1, 1)  # [B*G, 1], broadcasts over tokens

    # 4) RECOMPUTE logprobs in the TRAINING engine. These DIFFER from logp_old because
    #    of batch-size non-invariance (and MoE router disagreement), even before the
    #    policy has moved. That gap is why we need IS.
    logp_new = policy.logprobs(toks)  # requires grad
    with torch.no_grad():
        logp_ref = ref.logprobs(toks)  # frozen reference

    # 5) IMPORTANCE SAMPLING correction for off-policy / staleness / router shift.
    ratio = torch.exp(logp_new - logp_old)  # w = pi_theta / pi_old, per token
    lo, hi = is_tau
    keep = ((ratio > lo) & (ratio < hi)).float()  # TIS-style truncation (token mask)
    m = mask.float() * keep

    # 6) PPO-clipped surrogate on the kept tokens.
    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * adv
    pg = -torch.min(unclipped, clipped)

    # 7) KL-to-reference as an explicit loss penalty (GRPO style), not folded into reward.
    kl = torch.exp(logp_ref - logp_new) - (logp_ref - logp_new) - 1  # k3 estimator, >= 0
    loss = ((pg + kl_coef * kl) * m).sum() / m.sum().clamp_min(1)
    loss.backward()  # BF16 gradients; FP8 stays on the rollout side. Caller runs optimizer.step().
    return loss
```

Rollout (step 1): the sampler (vLLM/SGLang, running in a separate process) generates G=8 answers per prompt, reusing the shared prompt prefix via Prefix Grouper. It returns the exact logprobs it computed during sampling; these are the source of truth for old-policy likelihood. Reward (step 2): verifiable signals (binary pass/fail) need no learned reward model, so they cost no GPU memory. Advantage (step 3): the group-relative formula pools each prompt's G rewards, measures how far each is from the group mean, and scales by the group's spread; no critic needed. Recompute logprobs (step 4): the critical gap. The trainer (FSDP) recomputes logprobs on the same tokens, possibly at a different batch size, and on MoE the router may pick different experts. The recomputed logp_new differs from the sampler's logp_old even though the weights are identical; that's batch-variant kernels and routing disagreement, not policy drift. Importance sampling (step 5): the ratio w = exp(logp_new − logp_old) captures both real policy change and these fake shifts. Truncation (the `keep` mask) drops tokens whose ratio leaves [0.2, 5.0], which is where the fake spikes land; it's the cheap fix before GSPO or R3. Loss (steps 6–7): the PPO-clipped objective (the min of the clipped and unclipped ratio times advantage) plus an explicit KL penalty to the reference, kept as a separate loss term rather than folded into the reward. The whole snippet is one synchronous step: sample, score, compute advantages, backprop.

The advantage line, symbol by symbol: r holds G=8 rewards per row (one row per prompt); the numerator is each reward minus its group mean; the denominator is the group's standard deviation plus a 1e-6 safety term. A concrete example, with graded (non-binary) rewards for one prompt: [10, 12, 8, 14, 14, 12, 12, 14]. The mean is 96/8 = 12, so the deviations are [−2, 0, −4, 2, 2, 0, 0, 2]. Using the population standard deviation, sqrt((4+0+16+4+4+0+0+4)/8) = sqrt(32/8) = 2.0, and the advantages are [−1.0, 0.0, −2.0, 1.0, 1.0, 0.0, 0.0, 1.0]. (Torch's default `std` divides by G−1 instead, giving sqrt(32/7) ≈ 2.14 and slightly smaller advantages.) Below-average answers get negative advantage (discouraged), above-average ones positive (encouraged), scaled by how much they stand out from their siblings. When all 8 answers get the same reward (all right or all wrong), every deviation and the std are 0; the 1e-6 turns that 0/0 into an advantage of 0, so the group contributes no gradient instead of NaN.

The ratio line: w = π_θ / π_old per token. In log space, which is what both engines return, log w = logp_new − logp_old, so w = exp(logp_new − logp_old). Say the sampler assigned logprob −0.5 to the token "analyze" (probability ≈ 0.607) and the trainer recomputes −0.6 (≈ 0.549). The log-difference is −0.6 − (−0.5) = −0.1, so w = exp(−0.1) ≈ 0.905: the trainer's policy is slightly less confident. If instead the trainer got −0.4 (≈ 0.670), w = exp(0.1) ≈ 1.105: more confident. Under normal updates ratios stay near 1.0; MoE router disagreement or stale data pushes some past the truncation bounds (below 0.2 or above 5.0), and those tokens are masked out of the gradient.
The honesty: this is synchronous and single-process. The `keep` mask is the crude TIS fix; for MoE you'd replace step 5 with a **sequence-level** ratio (GSPO) or feed replayed routing masks (R3). And `logp_old != logp_new` even at step 0 unless your kernels are batch-invariant.
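A minimal sketch of that sequence-level replacement for step 5, following the GSPO ratio exp((log π_θ(y) − log π_old(y))/|y|) with one advantage per sequence. The function names are hypothetical, and `eps` is a placeholder: because the ratio is length-normalized it stays much closer to 1 than token-level ratios, so the clip range is set far narrower than the token-level 0.2. The toy sequence has a +0.5 then −0.5 token log-ratio swing, the kind a routing flip that re-aligns a few layers later produces.

```python
import torch


def gspo_sequence_ratio(logp_new, logp_old, mask):
    """Per-sequence ratio exp(mean over response tokens of (logp_new - logp_old)).

    logp_*: [batch, T] token logprobs; mask: [batch, T], 1 for response tokens.
    """
    mask = mask.to(logp_new.dtype)
    lengths = mask.sum(dim=1).clamp_min(1)
    log_ratio = ((logp_new - logp_old) * mask).sum(dim=1) / lengths
    return torch.exp(log_ratio)  # [batch]


def gspo_loss(logp_new, logp_old, mask, adv, eps=3e-4):
    """Sequence-level clipped surrogate; adv has one value per sequence, shape [batch]."""
    s = gspo_sequence_ratio(logp_new, logp_old, mask)
    unclipped = s * adv
    clipped = torch.clamp(s, 1 - eps, 1 + eps) * adv
    return -torch.min(unclipped, clipped).mean()


# Token-level ratios here would be exp(0.5) ~ 1.65 and exp(-0.5) ~ 0.61.
logp_old = torch.tensor([[-1.0, -1.0, -1.0, -1.0]])
logp_new = torch.tensor([[-0.5, -1.5, -1.0, -9.0]])  # last position is padding
mask = torch.tensor([[1, 1, 1, 0]])
ratio = gspo_sequence_ratio(logp_new, logp_old, mask)
loss = gspo_loss(logp_new, logp_old, mask, adv=torch.tensor([1.0]))
print(ratio.tolist(), loss.item())
```

The swing cancels inside the sequence average, so the sequence ratio is exactly 1 and nothing gets clipped, while token-level clipping at 0.2 would have fired on both tokens.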
## 5. Production tradeoffs
| Choice | Win | Cost / risk | Pick when |
|---|---|---|---|
| GRPO (critic-free) vs PPO | −30–50% memory, no critic loop | group-std blowup on all-right/all-wrong groups | verifiable rewards, scale |
| vLLM (PagedAttention) | <4% KV waste, ecosystem default | prefix reuse only via block-hash automatic prefix caching, no radix tree | general rollout |
| SGLang (RadixAttention) | +29% tput, 6.4× prefix-heavy | LRU cache memory; no gain w/o prefix overlap | shared prompt / `G` samples |
| Sync RL | staleness 0, simplest | long-tail idle GPUs | small/debug runs |
| Async RL (AReaL/slime) | 1.5–2.6× faster | staleness 1–20+ versions, needs IS | throughput-bound at scale |
| Partial rollouts (APRIL) | +22.5% tput | resume bookkeeping + IS | heavy long-tail CoT |
| FSDP | portable, simple | lower peak tput | mixed hardware |
| Megatron (TP/PP/EP) | top NVIDIA tput, MoE-ready | steep, NVLink-bound | large MoE on NVLink |
| Batch-invariant kernels | reproducible trajectories | ~34% slower, even with CUDA graphs on | eval parity, debugging drift |
| FP8 inference / BF16 train | 33% faster rollout | FP8 training unstable | rollout-heavy RL |
| GRPO on MoE (no fix) | simplest | collapses (router shift) | don't |
| R3 / GSPO on MoE | stable MoE RL | R3 plumbing; GSPO coarser credit | any MoE policy |
## 6. How it's asked
**Q (IC5): "How many models are in memory during GRPO, and which is most expensive?"**
Policy + reference, plus a reward model only if reward is learned (RLVR verifiers need none) — so 2–3 copies. The **policy** dominates: it alone carries gradients and Adam state (~16 bytes/param mixed precision), while reference/reward are inference-only (~2 bytes/param). Under pressure I cut policy activation memory first (checkpointing/offload), then LoRA-share the backbone, before touching the frozen copies.
**Q (IC5): "Long-tail rollouts are killing throughput. Fixes?"**
Profile to confirm GPUs idle on the slowest decode. Cheapest: length-aware batching (RollPacker) to isolate outliers in dedicated long rounds. Next: partial rollouts (APRIL) with IS correction, ~+22.5%. Most invasive but biggest: go async (AReaL/slime) for 1.5–2.6×, accepting a staleness/IS tax.
**Q (IC6): "Dense→MoE and GRPO collapses. Why and what do you do?"**
MoE routing differs across rollout and training engines (~10% of routers, ~94% of tokens differ in ≥1 layer); each update shifts experts more, so the IS ratio spikes and clipping fires chaotically. Two principled fixes: **R3** (replay rollout routing masks into the trainer — halves routing KL, kills spikes 10×) or **GSPO** (sequence-level importance ratio that averages out transient per-token routing noise). TIS is a coarse stopgap.
**Q (IC6): "Why is your RL run nondeterministic, and does it matter?"**
Kernels are batch-size dependent: reduction order changes with batch size and FP add isn't associative, so rollout-engine and trainer logprobs differ even at step 0. That fake "off-policy" gap pollutes the IS ratio and makes eval comparisons unreliable. Fix with batch-invariant kernels (~34% slower even with CUDA graphs enabled) when reproducibility matters.
**Q (IC6): "Async gave you 2× — what did you pay?"**
Statistical efficiency. Data is now 1–20+ policy versions stale; IS ratios fat-tail and variance rises. You manage it with truncated/masked IS, staleness-aware weighting (A-3PO), and a buffer sized so the policy doesn't drift faster than weight sync — otherwise you give the 2× back as instability.
## 7. Pitfalls & flashcards
- **Forgetting the policy is the only expensive copy.** "Three models" sounds like 3× memory; it's not — frozen copies are inference-only.
- **Symmetric IS truncation on MoE.** Token-level TIS masks blowups but throws away signal; sequence-level (GSPO) is usually better.
- **Blaming the GPU for nondeterminism.** It's batch-variant kernels, not hardware randomness.
- **Async without staleness handling.** Decoupling engines without IS correction = silent divergence.
- **Std-normalization with a tiny denominator.** All-right or all-wrong groups → near-zero std → exploding advantages; clamp it.
- **FP8 for the gradient step.** Use FP8 on rollout/inference; keep BF16 for training stability.
- **Raw `all_reduce` under sequence parallelism.** Gradients off by the SP factor; use `torch.distributed.nn.all_reduce`.
> Flashcard. **Why does GSPO stabilize MoE RL where GRPO collapses?** Because routing disagreement between the rollout and training engines makes *token-level* importance ratios spike chaotically; GSPO's *sequence-level* ratio `exp((log π_θ(y) − log π_old(y))/|y|)` averages transient per-token routing fluctuations out, with no routing replay needed.
> Flashcard. **Sync vs async RL in one line:** sync = zero staleness but GPUs idle on the long tail; async = 1.5–2.6× throughput bought with 1–20+ versions of staleness that you must pay back via importance sampling.
## 8. Further reading
- **AReaL** — fully decoupled async RL, 2.57× speedup, staleness handling: https://arxiv.org/html/2505.24298v4
- **slime (LMSYS)** — SGLang+Megatron post-training, production GLM: https://www.lmsys.org/blog/2025-07-09-slime/
- **SGLang determinism** — batch-invariant kernels for reproducible RL: https://www.lmsys.org/blog/2025-09-22-sglang-deterministic/
- **R3 (routing replay)** — aligning MoE train/inference routers: https://arxiv.org/pdf/2510.11370
- **GSPO** — sequence-level importance sampling: https://arxiv.org/pdf/2507.18071
- **VeRL** — production multi-algorithm RL framework: https://github.com/verl-project/verl
Next: [/finetuning/rl-interview-benchmark](/finetuning/rl-interview-benchmark) to drill these under timed conditions, then [/interview](/interview) for the full loop and [/library](/library) for the algorithm map.