KV Cache: Key-Value Caching for Efficient Autoregressive Inference
KV caching stores the key-value pairs from previous tokens, reducing the per-token projection cost from O(t·d²) to O(d²); the cache for a 32-layer, 32-head, d_head=128 model at a 4K-token context is ~2.1 GB at fp16.
| Measure | Value | Unit | Notes |
|---|---|---|---|
| KV cache memory formula | 2 × n_layers × n_heads × d_head × seq_len × dtype_bytes | bytes | Factor 2 for keys + values; scales linearly with layers, heads, and sequence length |
| Example: 32 layers, 32 heads, d_head=128, 4096 tokens, fp16 | 2.1 | GB | 2 × 32 × 32 × 128 × 4096 × 2 = 2,147,483,648 bytes ≈ 2.1 GB |
| Inference FLOPs without KV cache (token t) | O(t·d²) | FLOPs | Must recompute K, V for all t previous positions at each new token |
| Inference FLOPs with KV cache (token t) | O(d²) | FLOPs | Only compute K, V for the new token; attend over cached K, V for all prior tokens |
| Multi-Query Attention KV size reduction | 8–32× | reduction | MQA/GQA uses 1 or G < h KV heads shared across query heads; reduces KV cache proportionally |
The KV cache is the mechanism that makes autoregressive transformer inference computationally feasible. Without caching, generating a 1,000-token sequence would require 1,000 full forward passes over ever-growing prefixes. With caching, each new token requires only a single new key-value computation per layer, plus attention over the stored cache.
How Autoregressive Inference Works
Without KV cache — generating token t+1:
- Concatenate all tokens [x₁, x₂, …, x_t, x_{t+1}]
- Run full forward pass through all layers
- Read logits at position t+1
- Redundancy: K and V for positions 1..t were computed identically in the previous step
With KV cache — generating token t+1:
- Compute K and V only for new token x_{t+1}
- Append to cache: K_cache ← [K_cache; K_{t+1}], V_cache ← [V_cache; V_{t+1}]
- Compute attention using x_{t+1}'s query against the full K_cache and V_cache
- Total FLOPs per step: O(t·d) for attention + O(d²) for linear layers
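The cached decoding loop above can be sketched as follows. This is a minimal single-layer, single-head toy in numpy with hypothetical dimensions (d=16, random weights), not any specific model's implementation:

```python
import numpy as np

# Toy sketch of KV-cached decoding: one layer, one head, hypothetical d=16.
d = 16
rng = np.random.default_rng(0)
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))

K_cache = np.empty((0, d))  # grows by one row per generated token
V_cache = np.empty((0, d))

def decode_step(x):
    """Compute K, V only for the new token x, append to the cache, attend."""
    global K_cache, V_cache
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    K_cache = np.vstack([K_cache, k])       # K_cache <- [K_cache; K_{t+1}]
    V_cache = np.vstack([V_cache, v])       # V_cache <- [V_cache; V_{t+1}]
    scores = K_cache @ q / np.sqrt(d)       # O(t*d) attention over the cache
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ V_cache                # weighted sum of cached values

for _ in range(5):                          # generate 5 steps
    out = decode_step(rng.standard_normal(d))

print(K_cache.shape)  # (5, 16): one cached key per generated token
```

Per step, only the new token's projections (the O(d²) matmuls against Wq, Wk, Wv) are computed; the O(t·d) dot products against the cache are the only cost that grows with sequence length.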
KV Cache Memory Breakdown
For a model with n_layers=32, n_heads=32, d_head=128:
| Sequence Length | KV Cache Size (fp16) | KV Cache Size (int8) |
|---|---|---|
| 512 | 268 MB | 134 MB |
| 2,048 | 1.07 GB | 537 MB |
| 4,096 | 2.15 GB | 1.07 GB |
| 32,768 | 17.2 GB | 8.6 GB |
| 131,072 | 68.7 GB | 34.4 GB |
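The table values follow directly from the memory formula. A small sketch that reproduces the fp16 column (decimal GB, matching the table's units):

```python
def kv_cache_bytes(n_layers, n_heads, d_head, seq_len, dtype_bytes):
    # 2x for keys + values, stored per layer, per head, per position
    return 2 * n_layers * n_heads * d_head * seq_len * dtype_bytes

# fp16 column: 32 layers, 32 heads, d_head=128, 2 bytes per element
for seq_len in (512, 2048, 4096, 32768, 131072):
    gb = kv_cache_bytes(32, 32, 128, seq_len, 2) / 1e9
    print(f"{seq_len:>7} tokens: {gb:.2f} GB")
```

Halving dtype_bytes (fp16 → int8) halves every row, which is the int8 column.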
KV Cache Reduction Techniques
| Technique | KV Heads | Cache Reduction | Quality Impact |
|---|---|---|---|
| Multi-Head Attention (MHA) | h per layer | 1× (baseline) | Full quality |
| Multi-Query Attention (MQA) | 1 per layer | h× | Minor quality loss |
| Grouped Query Attention (GQA) | G per layer (G<h) | h/G × | Near-MHA quality |
| PagedAttention | — | No size change | Reduces fragmentation |
| KV cache quantization | — | 2–4× | <1% quality loss |
GQA (Ainslie et al., 2023) with G=8 groups reduces cache size 4× (for h=32) while maintaining nearly full MHA quality, making it the dominant approach in efficient inference.
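The reductions in the table fall out of one observation: the cache scales with the number of KV heads, not query heads. A quick check using the example model's dimensions (h=32 query heads, hypothetical G=8 for GQA):

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, dtype_bytes):
    # Cache size depends only on KV heads; query heads are irrelevant
    return 2 * n_layers * n_kv_heads * d_head * seq_len * dtype_bytes

h = 32                                          # query heads
mha = kv_cache_bytes(32, h, 128, 4096, 2)       # MHA: h KV heads (baseline)
gqa = kv_cache_bytes(32, 8, 128, 4096, 2)       # GQA: G=8 shared KV heads
mqa = kv_cache_bytes(32, 1, 128, 4096, 2)       # MQA: 1 shared KV head
print(mha // gqa, mha // mqa)  # -> 4 32
```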
Related Pages
See context-window for how sequence length interacts with memory, quantization for KV cache precision reduction, and inference-vs-training-compute for how KV caching affects overall inference compute budgets.
Sources
- Vaswani et al. (2017) — Attention Is All You Need. NeurIPS 2017
- Pope et al. (2023) — Efficiently Scaling Transformer Inference. MLSys 2023
- Ainslie et al. (2023) — GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023
Frequently Asked Questions
Why is a KV cache necessary for autoregressive inference?
During autoregressive generation, the model generates one token at a time. Without caching, to generate token t it would need to recompute the attention keys and values for all t−1 previous tokens in every layer — O(t·d²) projection FLOPs per new token. With the KV cache, keys and values are computed once and stored; each new token adds only O(1) new KV pairs per layer, so the expensive O(d²) projections become constant per token, leaving only the O(t·d) attention over the cache to grow with sequence length.
How large does the KV cache get in practice?
KV cache size = 2 × n_layers × n_heads × d_head × seq_len × bytes_per_element. For a medium-scale model (32 layers, 32 heads, d_head=128) running at fp16 with a 4K token context, the KV cache is ~2.1 GB. At 128K tokens, the same model requires ~69 GB of KV cache alone — often exceeding the memory needed for the model weights themselves. This is why long-context inference requires careful memory management.
What is Multi-Query Attention and how does it reduce KV cache?
Standard multi-head attention (MHA) maintains separate K and V projections for each of the h heads. Multi-Query Attention (MQA, Shazeer 2019) uses a single K and V projection shared across all query heads, reducing KV cache size by a factor of h. Grouped Query Attention (GQA, Ainslie et al. 2023) is a middle ground: G shared KV heads (1 < G < h), each serving a group of h/G query heads, reducing cache size by h/G while retaining most of MHA's quality.
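The head-sharing mechanism can be sketched in numpy. This is a toy with hypothetical shapes (h=8 query heads, G=2 KV heads, t=10 cached positions), not a real model; each cached KV head is broadcast to its group of h/G query heads at attention time:

```python
import numpy as np

# GQA toy: h=8 query heads share G=2 KV heads; each KV head serves h//G = 4.
h, G, d_head, t = 8, 2, 16, 10
rng = np.random.default_rng(1)
q = rng.standard_normal((h, d_head))        # current token's query, per head
K = rng.standard_normal((G, t, d_head))     # cached keys: only G copies stored
V = rng.standard_normal((G, t, d_head))     # cached values: only G copies stored

K_full = np.repeat(K, h // G, axis=0)       # expand each KV head to its group
V_full = np.repeat(V, h // G, axis=0)       # shapes become (h, t, d_head)

scores = np.einsum('hd,htd->ht', q, K_full) / np.sqrt(d_head)
w = np.exp(scores - scores.max(axis=1, keepdims=True))
w /= w.sum(axis=1, keepdims=True)           # softmax over cached positions
out = np.einsum('ht,htd->hd', w, V_full)    # (h, d_head) attended output
print(out.shape)  # (8, 16)
```

Only the (G, t, d_head) tensors live in the cache; the repeat happens on the fly, which is where the h/G memory saving comes from.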