Inference vs Training Compute: FLOPs per Token vs Total Training Cost
Training FLOPs ≈ 6·N·D for dense transformers (N parameters, D tokens); inference ≈ 2·N FLOPs per token; a 70B model requires ~1.4×10¹¹ FLOPs per token vs ~5.9×10²³ total training FLOPs (Chinchilla-optimal); training = ~4.2 trillion inference tokens equivalent.
| Measure | Value | Unit | Notes |
|---|---|---|---|
| Training FLOPs approximation | C_train ≈ 6·N·D | FLOPs | N = parameters, D = tokens; factor 6 = 2 (forward) + 4 (backward); the optimizer step is O(N) element-wise work, negligible next to the matrix multiplications |
| Inference FLOPs per token (with KV cache) | C_infer ≈ 2·N | FLOPs/token | One multiply-add (2 FLOPs) per parameter per token; no backward pass; past tokens' keys/values cached, not recomputed |
| Training-equivalent inference tokens | C_train / C_infer = 3·D | inference tokens | 6·N·D training FLOPs ÷ 2·N per inference token = 3D inference tokens equivalent |
| 70B Chinchilla training FLOPs | ~5.9×10²³ | FLOPs | 70×10⁹ params × 1.4×10¹² tokens × 6 ≈ 5.9×10²³; ~22,000 A100-days at peak 312 TFLOPS FP16 throughput |
| 70B inference FLOPs per token | ~1.4×10¹¹ | FLOPs/token | 2 × 70×10⁹ = 1.4×10¹¹; at 312 TFLOPS peak an A100 is compute-capped at ~2,200 tokens/s, though single-request decode is memory-bandwidth bound far below this |
Understanding the compute split between training and inference is fundamental to reasoning about large language model economics, energy costs, and deployment strategies. Training and inference have structurally different cost profiles: training is a one-time fixed cost, while inference cost scales linearly with deployment usage.
FLOPs Breakdown: Training
For a dense transformer with N parameters trained on D tokens:
C_train ≈ 6 · N · D
The factor 6 decomposes as:
- 2×: forward pass (one multiply-add per parameter per token to compute activations and loss)
- 4×: backward pass (one set of matrix multiplications for gradients w.r.t. activations, another for gradients w.r.t. weights, each roughly as costly as the forward pass)
The optimizer step (e.g. Adam's first- and second-moment updates) is element-wise O(N) work and is negligible next to the matrix multiplications.
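In code, the approximation is a one-liner; the function name and the 70B/1.4T figures below are illustrative:

```python
def training_flops(n_params: float, n_tokens: float) -> float:
    """Approximate total training compute: 2x forward + 4x backward
    FLOPs per parameter per token, i.e. C_train ~= 6*N*D."""
    return 6 * n_params * n_tokens

# 70B-parameter model trained on a Chinchilla-optimal ~1.4T tokens
c_train = training_flops(70e9, 1.4e12)
print(f"{c_train:.2e}")  # ~5.9e+23 FLOPs
```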
FLOPs Breakdown: Inference
For a single generated token (assuming full KV cache):
C_infer ≈ 2 · N
The factor 2 reflects one multiply and one add per parameter per generated token: every weight in the attention projections (Q, K, V, output) and the feed-forward layers participates in a single multiply-accumulate. No backward pass is run and no optimizer state is maintained.
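A quick sketch of the per-token cost and the training/inference equivalence from the table above (function names are illustrative):

```python
def inference_flops_per_token(n_params: float) -> float:
    """~2 FLOPs (one multiply-add) per parameter per generated token, given a KV cache."""
    return 2 * n_params

def training_equivalent_tokens(n_tokens_train: float) -> float:
    """6*N*D training FLOPs / (2*N FLOPs per inference token) = 3*D inference tokens."""
    return 3 * n_tokens_train

print(f"{inference_flops_per_token(70e9):.1e}")   # 1.4e+11 FLOPs/token
print(f"{training_equivalent_tokens(1.4e12):.1e}")  # 4.2e+12 tokens
```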
Compute at Scale
| Model Size (N) | Training Tokens (D) | Total Training (FLOPs) | Inference/Token (FLOPs) | Training ≡ N_inf Inference Tokens |
|---|---|---|---|---|
| 1B | 20B (Chinchilla) | 1.2×10²⁰ | 2×10⁹ | ~6×10¹⁰ |
| 7B | 140B (Chinchilla) | 5.9×10²¹ | 1.4×10¹⁰ | ~4.2×10¹¹ |
| 70B | 1.4T (Chinchilla) | 5.9×10²³ | 1.4×10¹¹ | ~4.2×10¹² |
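The table rows follow directly from the two approximations; a short script to reproduce them (the (N, D) pairs are the Chinchilla-optimal values from the table):

```python
# (name, parameters N, training tokens D)
models = [("1B", 1e9, 20e9), ("7B", 7e9, 140e9), ("70B", 70e9, 1.4e12)]

for name, n, d in models:
    train = 6 * n * d          # total training FLOPs
    per_token = 2 * n          # inference FLOPs per generated token
    equiv = train / per_token  # training-equivalent inference tokens (= 3*D)
    print(f"{name}: train={train:.1e}  infer/token={per_token:.1e}  equiv={equiv:.1e}")
```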
Batch Size and Inference Efficiency
| Batch Size | Arithmetic Intensity | Throughput | Bottleneck |
|---|---|---|---|
| 1 (latency-optimized) | ~1 FLOP/byte | Low tokens/s | Memory bandwidth |
| 32 | ~32 FLOP/byte | Moderate | Mixed |
| 512+ | High | Near-peak FLOP/s | Compute |
At batch size 1, inference throughput sits roughly two orders of magnitude below GPU peak compute: at ~1 FLOP/byte against an A100 compute-to-bandwidth ratio of ~150 FLOPs/byte (FP16), the GPU is ~150× underutilized. Batching requests amortizes the weight-loading cost across multiple tokens.
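A minimal roofline sketch makes the crossover concrete. The peak FLOP/s, memory bandwidth, and FP16 weight size are assumed A100-class figures, and the function name is illustrative:

```python
def decode_tokens_per_s(n_params: float, batch: int,
                        peak_flops: float = 312e12,    # assumed A100 FP16 tensor-core peak
                        mem_bw: float = 2.0e12,        # assumed ~2 TB/s HBM bandwidth
                        bytes_per_param: float = 2.0   # FP16 weights
                        ) -> float:
    """Roofline estimate of decode throughput: each decode step streams all
    weights from HBM once, and the batch shares that cost."""
    t_compute = 2 * n_params * batch / peak_flops       # 2*N FLOPs per token in the batch
    t_memory = n_params * bytes_per_param / mem_bw      # weight traffic dominates at small batch
    return batch / max(t_compute, t_memory)

print(f"batch 1:   {decode_tokens_per_s(70e9, 1):.0f} tok/s")    # bandwidth-bound, ~14 tok/s
print(f"batch 512: {decode_tokens_per_s(70e9, 512):.0f} tok/s")  # compute-bound, ~2,200 tok/s
```

The `max()` is the roofline: whichever of compute time or weight-streaming time is larger sets the step latency, which is why small batches are bandwidth-limited and large batches approach peak FLOP/s.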
Prefill vs Decode Phases
Inference has two distinct phases:
| Phase | Input | FLOPs | Parallelism |
|---|---|---|---|
| Prefill (prompt processing) | All input tokens at once | 2·N·T_prompt | Full parallelism |
| Decode (token generation) | One token per step | 2·N per token | Sequential |
The prefill phase computes at high arithmetic intensity and is GPU compute-bound. The decode phase generates sequentially and is memory-bandwidth bound unless batched.
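The FLOP split between the two phases can be sketched as follows (the 2048/256 token counts are illustrative):

```python
def generation_flops(n_params: float, prompt_len: int, gen_len: int):
    """Split compute between the two phases: prefill spends 2*N FLOPs on each
    prompt token in one parallel pass; decode spends 2*N FLOPs per generated token."""
    prefill = 2 * n_params * prompt_len
    decode = 2 * n_params * gen_len
    return prefill, decode

# 70B model, 2048-token prompt, 256 generated tokens
prefill, decode = generation_flops(70e9, prompt_len=2048, gen_len=256)
print(f"prefill: {prefill:.2e} FLOPs  decode: {decode:.2e} FLOPs")
```

Note that although prefill here costs 8× more total FLOPs than decode, it typically finishes faster: those FLOPs run in one compute-bound pass, while decode pays a bandwidth-bound step per token.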
Related Pages
See compute-flops for detailed FLOP-counting methodology, kv-cache for how caching eliminates recomputation of past tokens' keys and values during decode, and quantization for reducing the memory bottleneck in decode.
Sources
- Kaplan et al. (2020) — Scaling Laws for Neural Language Models. arXiv
- Hoffmann et al. (2022) — Training Compute-Optimal Large Language Models (Chinchilla). NeurIPS 2022
- Patterson et al. (2021) — Carbon Emissions and Large Neural Network Training. arXiv
Frequently Asked Questions
Why is the backward pass approximately 2× more expensive than the forward pass?
The forward pass computes activations and the loss in one traversal through the network, at ~2N FLOPs per token. Backpropagation must apply the chain rule in reverse with two sets of matrix multiplications: one computes gradients of the loss w.r.t. activations (∂L/∂a), the other computes gradients w.r.t. weights (∂L/∂W). Each set costs roughly one forward pass, so the backward pass is ≈ 2× the forward pass (~4N vs ~2N FLOPs per token). The Adam optimizer step adds only element-wise moment updates — O(N) work that is negligible next to the matrix multiplications. Together, forward + backward ≈ 6N FLOPs per token per step, i.e. 3× the forward pass alone.
Why is inference memory-bandwidth bound at small batch sizes?
For a single request (batch size 1), generating one token requires loading all N model parameters from GPU HBM to compute 2·N FLOPs. The arithmetic intensity is 2·N FLOPs / (2N bytes for FP16) = 1 FLOP/byte — far below the A100's compute-to-memory-bandwidth ratio (~150 FLOPs/byte at FP16 peak). The GPU spends most of its time waiting on memory transfers, not computing. Larger batch sizes improve arithmetic intensity by amortizing the weight loads over multiple simultaneous requests.
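The same comparison in a few lines; the peak FLOP/s and bandwidth used for the ratio are assumed A100-class figures:

```python
def arithmetic_intensity(batch: int, bytes_per_param: float = 2.0) -> float:
    """FLOPs per byte of weight traffic per decode step:
    2*N*batch FLOPs / (N*bytes_per_param) bytes = 2*batch / bytes_per_param."""
    return 2 * batch / bytes_per_param

gpu_ratio = 312e12 / 2.0e12  # ~156 FLOPs/byte for an assumed A100-class GPU (FP16)

for b in (1, 32, 512):
    bound = "memory-bound" if arithmetic_intensity(b) < gpu_ratio else "compute-bound"
    print(f"batch {b:>3}: {arithmetic_intensity(b):6.1f} FLOP/byte -> {bound}")
```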