Separation Architecture
Prefill and decode have fundamentally different performance characteristics. Prefill processes the entire prompt in parallel, using large matrix multiplications that saturate the compute units, so it is compute-bound. Decode generates one token per step and must re-read the model weights and the growing KV cache from HBM each time, which makes it memory-bandwidth-bound.
import torch

# TransformerModel, CacheConfig, CacheAllocator, DecodeScheduler and CacheState
# are assumed to be defined elsewhere in the serving codebase.


class InferencePipeline:
    def __init__(self, model: TransformerModel, cache_config: CacheConfig):
        self.model = model
        self.cache_allocator = CacheAllocator(cache_config)
        self.decode_scheduler = DecodeScheduler()

    def execute(self, prompt_ids: torch.Tensor, max_tokens: int) -> torch.Tensor:
        # Phase 1: Prefill (compute-bound)
        cache_state = self._run_prefill(prompt_ids)
        # Phase 2: Decode (memory-bandwidth-bound)
        generated = self._run_decode(cache_state, max_tokens)
        return generated

    def _run_prefill(self, prompt_ids: torch.Tensor) -> CacheState:
        batch_size, seq_len = prompt_ids.shape
        cache_state = self.cache_allocator.allocate(batch_size, seq_len)
        with torch.no_grad():
            hidden = self.model.embed(prompt_ids)
            for layer in self.model.layers:
                q, k, v = layer.project(hidden)
                cache_state.store_kv(layer.idx, k, v)
                hidden = layer.attention(q, cache_state.get_kv(layer.idx))
            # Seed the decode loop with the token predicted at the last prompt position
            logits = self.model.lm_head(hidden[:, -1, :])
            cache_state.last_token = torch.argmax(logits, dim=-1, keepdim=True)
        return cache_state
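To make the compute-bound versus bandwidth-bound distinction concrete, the sketch below estimates arithmetic intensity (FLOPs per byte of weight traffic) for a single linear layer. It is a back-of-the-envelope model, not a profiler: the function name, the 5120-wide layer, and the decision to count only weight reads as memory traffic are illustrative assumptions.

```python
def linear_arithmetic_intensity(tokens: int, d_in: int, d_out: int,
                                dtype_bytes: int = 2) -> float:
    """FLOPs per byte of weight traffic for y = x @ W with `tokens` rows in x.

    Simplification (assumption): only the weight matrix counts as memory
    traffic; activation reads and writes are ignored.
    """
    flops = 2 * tokens * d_in * d_out          # one multiply and one add per weight per token
    weight_bytes = d_in * d_out * dtype_bytes  # W is read once per forward pass
    return flops / weight_bytes


# Prefill: a 2048-token prompt reuses every weight 2048 times per read.
prefill_ai = linear_arithmetic_intensity(tokens=2048, d_in=5120, d_out=5120)
# Decode: one new token per sequence, so every weight is used once per read.
decode_ai = linear_arithmetic_intensity(tokens=1, d_in=5120, d_out=5120)

print(f"prefill: {prefill_ai:.0f} FLOPs/byte, decode: {decode_ai:.0f} FLOPs/byte")
# prefill: 2048 FLOPs/byte, decode: 1 FLOPs/byte
```

Under this simplified model, fp16 intensity equals the number of tokens processed per weight read. Batching several sequences in one decode step raises the intensity of weight reads, but each sequence still has to read its own KV cache, which is why decode tends to stay bandwidth-limited.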
Step 2: Dynamic Cache Management
Instead of pre-allocating one contiguous block per request, implement a page-based allocator. Each page holds a fixed number of token slots (e.g., 16 tokens). Logical sequences map to non-contiguous physical pages, which eliminates external fragmentation and limits waste to the unused slots in each sequence's last page.
import math
from typing import Dict, List

import torch


class PageAllocator:
    def __init__(self, page_size: int = 16, num_layers: int = 32):
        self.page_size = page_size
        self.num_layers = num_layers
        self.free_pages: List[int] = []  # IDs of unused physical pages
        self.page_table: Dict[int, List[int]] = {}  # request_id -> [page_ids]

    def allocate_for_request(self, request_id: int, seq_len: int) -> List[int]:
        required_pages = math.ceil(seq_len / self.page_size)
        allocated = []
        for _ in range(required_pages):
            if not self.free_pages:
                self._grow_pool()  # not shown: adds fresh page IDs to free_pages
            allocated.append(self.free_pages.pop())
        self.page_table[request_id] = allocated
        return allocated

    def get_kv_tensors(self, request_id: int, layer_idx: int) -> torch.Tensor:
        pages = self.page_table[request_id]
        # _extract_layer (not shown) returns one page's K/V slice for this layer
        layer_tensors = [self._extract_layer(page, layer_idx) for page in pages]
        # Concatenate along the token dimension to rebuild the logical sequence
        return torch.cat(layer_tensors, dim=1)
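A framework-free sketch of the same idea follows. It tracks page IDs only (no tensors) and adds a fixed pool size, a `release` method, and a `locate` helper. None of these appear in the class above; they are assumptions made so the example runs on its own.

```python
import math
from typing import Dict, List, Tuple


class ToyPagePool:
    """Hypothetical page-ID allocator; a real system attaches K/V storage to each page."""

    def __init__(self, num_pages: int, page_size: int = 16):
        self.page_size = page_size
        self.free_pages: List[int] = list(range(num_pages))
        self.page_table: Dict[int, List[int]] = {}  # request_id -> [page_ids]

    def allocate(self, request_id: int, seq_len: int) -> List[int]:
        needed = math.ceil(seq_len / self.page_size)
        if needed > len(self.free_pages):
            raise MemoryError(f"need {needed} pages, {len(self.free_pages)} free")
        pages = [self.free_pages.pop() for _ in range(needed)]
        self.page_table[request_id] = pages
        return pages

    def release(self, request_id: int) -> None:
        # Pages return to the pool one by one, so any later request can reuse them.
        self.free_pages.extend(self.page_table.pop(request_id))

    def locate(self, request_id: int, token_pos: int) -> Tuple[int, int]:
        """Translate a logical token position into (physical page, slot in page)."""
        page = self.page_table[request_id][token_pos // self.page_size]
        return page, token_pos % self.page_size


pool = ToyPagePool(num_pages=8)
pool.allocate(request_id=1, seq_len=40)          # takes 3 pages
pool.allocate(request_id=2, seq_len=40)          # takes 3 pages
pool.release(1)                                  # request 1 finishes
pages = pool.allocate(request_id=3, seq_len=70)  # 5 pages: 2 untouched + 3 recycled
print(pages, pool.locate(3, 69))
```

Request 3 succeeds even though its five pages are not adjacent: two were never used and three were recycled from request 1. A contiguous allocator facing the same free space, split into two separate regions, would have rejected the request. The only waste is internal: 70 tokens occupy 80 slots, and the loss is bounded by `page_size - 1` slots per sequence.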
Step 3: Decode Loop with Cache Appending
During decode, only the new token's Query, Key, and Value are computed. The new K/V are appended to the appropriate cache page, and attention is computed against the full cached history.
# Method of InferencePipeline
def _run_decode(self, cache_state: CacheState, max_tokens: int) -> torch.Tensor:
    # cache_state.last_token is the most recent generated token, whose K/V are
    # not yet in the cache (the first one comes from the logits at the end of prefill).
    output_tokens = [cache_state.last_token]
    with torch.no_grad():
        for _ in range(max_tokens - 1):
            # Compute only for the new position
            new_hidden = self.model.embed(cache_state.last_token)
            for layer in self.model.layers:
                q_new, k_new, v_new = layer.project(new_hidden)
                cache_state.append_kv(layer.idx, k_new, v_new)
                new_hidden = layer.attention(q_new, cache_state.get_kv(layer.idx))
            logits = self.model.lm_head(new_hidden[:, -1, :])
            next_token = torch.argmax(logits, dim=-1, keepdim=True)  # greedy decoding
            output_tokens.append(next_token)
            cache_state.last_token = next_token
    return torch.cat(output_tokens, dim=-1)
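The claim that cached attention gives the same result as recomputing everything can be checked on a toy example. The sketch below uses plain Python lists and a single attention head; the `project` function is a made-up stand-in for learned projection matrices, and all names are illustrative.

```python
import math
from typing import List, Tuple

Vector = List[float]


def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Single-head scaled dot-product attention for one query vector."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


def project(x: Vector) -> Tuple[Vector, Vector, Vector]:
    """Toy stand-in for the learned Q/K/V projections (assumption)."""
    return [2.0 * xi for xi in x], [xi + 1.0 for xi in x], [-xi for xi in x]


tokens = [[0.1, 0.2], [0.4, -0.3], [0.5, 0.9]]  # hidden states for 3 positions

# Full recomputation: project every position, then attend from the last one.
all_qkv = [project(x) for x in tokens]
full = attend(all_qkv[-1][0], [k for _, k, _ in all_qkv], [v for _, _, v in all_qkv])

# Cached path: K/V for the first two positions were stored during earlier steps.
k_cache = [project(x)[1] for x in tokens[:2]]
v_cache = [project(x)[2] for x in tokens[:2]]
q_new, k_new, v_new = project(tokens[-1])  # only the new position is projected
k_cache.append(k_new)
v_cache.append(v_new)
cached = attend(q_new, k_cache, v_cache)

assert all(math.isclose(a, b) for a, b in zip(full, cached))
```

With causal masking, the K/V of earlier positions never depend on later tokens, which is why caching is exact rather than approximate. The per-step cost is one projection for the new token plus a read of the whole cache.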
Architecture Rationale
- Phase Separation: Prefill uses batched matrix ops for maximum compute utilization. Decode computes only the new token's Q/K/V and reuses the cached K/V, so no step recomputes the prompt. This split matches the hardware characteristics of modern GPUs.
- Page-Based Allocation: Borrowed from OS virtual memory, paging decouples logical sequence length from physical memory layout. It enables continuous batching without wasting HBM on alignment padding.
- Layer-Parallel KV Storage: Storing K/V per layer rather than per token reduces indexing overhead during attention computation. It aligns with how transformer kernels access memory.
- Deterministic Memory Growth: Cache size scales predictably as 2 × batch_size × seq_len × num_layers × num_kv_heads × head_dim × dtype_bytes, where the factor of 2 accounts for storing both keys and values. This enables accurate capacity planning before deployment.
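As a worked example of the sizing formula, the sketch below counts both keys and values (hence the leading factor of 2) and evaluates it for a shape meant to resemble Llama-2-13B, the model named in the configuration template. The function name is hypothetical and the numbers should be treated as illustrative.

```python
def kv_cache_bytes(batch_size: int, seq_len: int, num_layers: int,
                   num_kv_heads: int, head_dim: int, dtype_bytes: int = 2) -> int:
    """Bytes needed to hold keys and values for every layer and token."""
    return 2 * batch_size * seq_len * num_layers * num_kv_heads * head_dim * dtype_bytes


# Assumed shape: 40 layers, 40 KV heads, head_dim 128, fp16 (2 bytes).
per_token = kv_cache_bytes(1, 1, 40, 40, 128)
one_seq = kv_cache_bytes(1, 4096, 40, 40, 128)
batch_32 = kv_cache_bytes(32, 4096, 40, 40, 128)

print(f"per token: {per_token / 2**10:.0f} KiB")               # per token: 800 KiB
print(f"one 4096-token sequence: {one_seq / 2**30:.3f} GiB")   # 3.125 GiB
print(f"32 such sequences: {batch_32 / 2**30:.0f} GiB")        # 100 GiB
```

The last figure is a worst case in which every sequence in the batch reaches 4096 tokens. It is larger than the memory of a single 80 GB accelerator before weights are even loaded, which is the kind of result this formula is meant to surface during planning.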
Pitfall Guide
1. Treating Prefill and Decode as Identical Workloads
Explanation: Prefill is compute-bound; decode is memory-bandwidth-bound. Applying the same batching strategy or kernel configuration to both phases causes severe underutilization.
Fix: Implement separate execution paths. Use large batch sizes and fused attention kernels for prefill. Use smaller, dynamic batches and optimized memory access patterns for decode.
2. Ignoring Memory Fragmentation in Dynamic Batching
Explanation: Static contiguous allocation fails when requests have variable lengths. As sequences finish, freed memory blocks become unusable for longer incoming requests, causing artificial OOM errors.
Fix: Adopt virtualized memory management (PagedAttention). Map logical token positions to physical pages. Implement a page pool with LRU eviction for long-running services.
3. Unbounded Cache Growth (Context Window Leaks)
Explanation: Failing to enforce context limits causes cache allocation to exceed GPU memory. Long-running sessions or malformed prompts can trigger cascading OOM failures.
Fix: Implement hard context limits at the API gateway. Use sliding window attention or cache eviction policies (e.g., keep recent tokens + attention sinks) for extended sessions.
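The "recent tokens + attention sinks" policy can be sketched as a function that decides which cache positions survive. The function name and the default of four sink tokens are assumptions chosen for illustration.

```python
from typing import List


def positions_to_keep(cache_len: int, window_size: int, num_sinks: int = 4) -> List[int]:
    """Token positions retained by a 'sinks + recent window' eviction policy.

    Keeps the first `num_sinks` positions plus the most recent `window_size`
    positions; everything in between is evicted.
    """
    if cache_len <= num_sinks + window_size:
        return list(range(cache_len))  # nothing to evict yet
    sinks = list(range(num_sinks))
    recent = list(range(cache_len - window_size, cache_len))
    return sinks + recent


kept = positions_to_keep(cache_len=20, window_size=8, num_sinks=2)
print(kept)  # [0, 1, 12, 13, 14, 15, 16, 17, 18, 19]
```

The cache never holds more than `num_sinks + window_size` entries per sequence, which gives the hard upper bound the fix calls for. Real implementations also have to decide how positional encodings are assigned to the retained entries; this sketch does not address that.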
4. Misaligned Batch Dimensions During Decoding
Explanation: When requests finish at different steps, batch dimensions shrink. Naive implementations reallocate tensors or pad with zeros, introducing kernel launch overhead and memory churn.
Fix: Use continuous batching with dynamic request scheduling. Maintain a fixed maximum batch size and swap completed requests with waiting ones without reallocating cache structures.
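A minimal sketch of the slot-swapping idea is shown below. `SlotScheduler` is a hypothetical class that tracks request IDs only; a real scheduler would also hand each slot its cache pages and handle prefill for newly admitted requests.

```python
from collections import deque
from typing import Deque, List, Optional


class SlotScheduler:
    """Hypothetical fixed-slot scheduler: finished requests are replaced in place."""

    def __init__(self, max_batch_size: int):
        self.slots: List[Optional[str]] = [None] * max_batch_size
        self.waiting: Deque[str] = deque()

    def submit(self, request_id: str) -> None:
        self.waiting.append(request_id)

    def step(self, finished: List[str]) -> List[Optional[str]]:
        """Run between decode steps: free finished slots, then refill from the queue."""
        for i, req in enumerate(self.slots):
            if req is not None and req in finished:
                self.slots[i] = None
        for i, req in enumerate(self.slots):
            if req is None and self.waiting:
                self.slots[i] = self.waiting.popleft()
        return list(self.slots)


sched = SlotScheduler(max_batch_size=2)
for rid in ["a", "b", "c"]:
    sched.submit(rid)
print(sched.step(finished=[]))     # ['a', 'b']
print(sched.step(finished=["a"]))  # ['c', 'b']
print(sched.step(finished=["c"]))  # [None, 'b']
```

The batch tensor keeps its shape throughout: request `c` takes over slot 0 the moment `a` finishes, and `b` is never moved or copied. An empty slot (`None`) is masked out rather than triggering a reallocation.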
5. Overlooking Memory Bandwidth Saturation
Explanation: Engineers optimize for FLOPS but ignore that decode steps are limited by HBM read speed. Adding more compute layers or increasing head dimension without bandwidth planning yields diminishing returns.
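Fix: Profile memory bandwidth utilization with Nsight Systems (nsys) or Nsight Compute (ncu); the older nvprof is deprecated and does not support recent GPU architectures. Consider KV cache quantization (int8/fp8) or grouped-query attention to reduce memory reads per decode step.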
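To see how much these two levers move the needle, the sketch below estimates the KV-cache bytes one sequence reads per generated token and the resulting lower bound on step time. The model shape and the 2000 GB/s bandwidth figure are assumptions, and weight reads are deliberately left out.

```python
import math


def kv_bytes_read_per_step(context_len: int, num_layers: int, num_kv_heads: int,
                           head_dim: int, dtype_bytes: float) -> float:
    """KV-cache bytes one sequence must read to generate one token."""
    return 2 * context_len * num_layers * num_kv_heads * head_dim * dtype_bytes


def min_step_time_ms(bytes_read: float, bandwidth_gb_s: float) -> float:
    """Lower bound on step time if memory bandwidth were the only limit."""
    return bytes_read / (bandwidth_gb_s * 1e9) * 1e3


# Hypothetical 40-layer model, head_dim 128, 4096 tokens of context.
baseline = kv_bytes_read_per_step(4096, 40, num_kv_heads=40, head_dim=128, dtype_bytes=2)
gqa = kv_bytes_read_per_step(4096, 40, num_kv_heads=8, head_dim=128, dtype_bytes=2)
gqa_int8 = kv_bytes_read_per_step(4096, 40, num_kv_heads=8, head_dim=128, dtype_bytes=1)

for name, nbytes in [("fp16, 40 KV heads", baseline),
                     ("fp16, 8 KV heads", gqa),
                     ("int8, 8 KV heads", gqa_int8)]:
    # 2000 GB/s is an assumed HBM bandwidth, not a measured figure.
    print(f"{name}: {nbytes / 2**20:.0f} MiB, at least "
          f"{min_step_time_ms(nbytes, 2000):.2f} ms per step")
```

Cutting KV heads from 40 to 8 reduces read traffic fivefold, and int8 halves it again. Neither change touches FLOPs in any meaningful way, which is the point of the pitfall: the bottleneck is bytes moved, so that is what has to shrink.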
6. Hardcoding Cache Allocation Without Runtime Profiling
Explanation: A cache size fixed at deploy time is wrong in both directions: sized for the theoretical maximum, it wastes memory during low-load periods; sized for typical load, it runs out of memory during spikes.
Fix: Implement adaptive cache sizing based on real-time request queue depth and average sequence length. Use memory pressure metrics to trigger graceful degradation or request queuing.
7. Neglecting Quantization Trade-offs
Explanation: Quantizing the KV cache from fp16 to int8 halves its memory footprint but introduces precision loss that can degrade generation quality, especially for long contexts.
Fix: Use per-token or per-channel quantization with dynamic scaling factors. Validate perplexity degradation on domain-specific benchmarks before production rollout. Consider mixed-precision caching (fp16 for recent tokens, int8 for older).
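Per-token quantization with a dynamic scale can be illustrated in a few lines. The sketch below uses symmetric int8 quantization on a plain Python list; the function names and the sample vector are made up for the example, and a production kernel would operate on tensors.

```python
from typing import List, Tuple


def quantize_per_token(vec: List[float]) -> Tuple[List[int], float]:
    """Symmetric int8 quantization with one dynamic scale per token vector."""
    max_abs = max(abs(x) for x in vec)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(x / scale))) for x in vec]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    return [qi * scale for qi in q]


key_vector = [0.02, -1.27, 0.5, 0.0031]  # one token's key values (illustrative)
q, scale = quantize_per_token(key_vector)
restored = dequantize(q, scale)
max_err = max(abs(a - b) for a, b in zip(key_vector, restored))

print(q)        # [2, -127, 50, 0]
print(max_err)  # bounded by scale / 2
```

The smallest entry, 0.0031, rounds to zero because the scale is set by the largest entry in the vector. That is the precision loss the explanation refers to, and it is why outlier-heavy channels or long contexts deserve a perplexity check before rollout.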
Production Bundle
Action Checklist
Decision Matrix
| Scenario | Recommended Approach | Why | Cost Impact |
|---|---|---|---|
| High-throughput batch processing (offline) | Static batching + Standard KV Cache | Predictable workloads allow contiguous allocation; simpler implementation | Low infrastructure cost, moderate latency |
| Real-time chat API with variable lengths | PagedAttention + Continuous Batching | Eliminates fragmentation; maintains high GPU utilization across mixed request sizes | Higher engineering cost, optimal TCO |
| Long-context research/workflows (32K+) | Sliding Window + KV Quantization (int8) | Prevents OOM; reduces memory bandwidth pressure; maintains acceptable quality | Moderate quality trade-off, significant memory savings |
| Low-latency edge deployment | Speculative Decoding + Small KV Cache | Reduces decode steps via draft model verification; minimizes memory footprint | Higher compute cost, lower latency |
Configuration Template
inference_engine:
  model: "meta-llama/Llama-2-13b-hf"
  dtype: "fp16"
  tensor_parallel_size: 2
  cache:
    backend: "paged"
    page_size: 16
    max_context_length: 8192
    eviction_policy: "sliding_window"
    window_size: 4096
    quantization:
      enabled: false
      dtype: "int8"
      per_token_scaling: true
  scheduling:
    strategy: "continuous_batching"
    max_batch_size: 32
    prefill_chunk_size: 512
    decode_micro_batch: 8
  monitoring:
    metrics:
      - "cache_hit_rate"
      - "hbm_bandwidth_utilization"
      - "ttft_p99"
      - "tokens_per_second"
    alerting:
      cache_fragmentation_threshold: 0.15
      hbm_saturation_threshold: 0.85
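Before a template like this reaches production, it can help to sanity-check the cache section at startup. The sketch below works on an already-parsed dictionary; the function name and the three rules it enforces are assumptions for illustration, not requirements of any particular engine.

```python
from typing import Dict, List


def check_cache_config(cache: Dict) -> List[str]:
    """Return human-readable problems found in a `cache` config section."""
    problems = []
    # Illustrative rule: a whole number of pages should cover the maximum context.
    if cache["max_context_length"] % cache["page_size"] != 0:
        problems.append("max_context_length is not a multiple of page_size")
    if cache.get("eviction_policy") == "sliding_window":
        if cache["window_size"] > cache["max_context_length"]:
            problems.append("window_size exceeds max_context_length")
    quant = cache.get("quantization", {})
    if quant.get("enabled") and quant.get("dtype") not in {"int8", "fp8"}:
        problems.append("unsupported quantization dtype")
    return problems


cache_section = {
    "backend": "paged", "page_size": 16, "max_context_length": 8192,
    "eviction_policy": "sliding_window", "window_size": 4096,
    "quantization": {"enabled": False, "dtype": "int8", "per_token_scaling": True},
}
print(check_cache_config(cache_section))  # []
```

Returning a list of problems rather than raising on the first one lets the service report every misconfiguration in a single startup log line.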
Quick Start Guide
- Initialize the Cache Allocator: Deploy a page-based memory manager with a page size matching your attention kernel's optimal tile size (typically 16 or 32 tokens). As a starting point, give the page pool about 70% of GPU memory and reserve about 30% for weights and activations, then adjust the split to the model's actual weight footprint.
- Configure Phase-Specific Kernels: Enable fused attention for prefill to maximize compute throughput. Switch to incremental attention kernels for decode that read only cached K/V and compute Q for the new token. Set prefill_chunk_size to 512 to balance memory pressure and TTFT.
- Enable Continuous Batching: Replace static request queues with a dynamic scheduler that swaps completed requests with waiting ones at every decode step. Monitor batch utilization; aim for >80% active slots to amortize memory read costs.
- Validate Memory Boundaries: Run load tests with mixed sequence lengths (256 to 8192 tokens). Track cache fragmentation ratio and HBM bandwidth utilization. Adjust max_batch_size and page_size until fragmentation stays below 5% and bandwidth utilization peaks at 80-85%.
- Deploy Monitoring & Fallbacks: Instrument cache hit rates, TTFT, and tokens-per-second. Configure automatic request queuing when cache pressure exceeds thresholds. Implement graceful degradation (e.g., sliding window eviction) before OOM triggers occur.
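The fragmentation ratio tracked during load tests needs a concrete definition. One reasonable choice, assumed here, is internal fragmentation: the fraction of allocated token slots that hold no token. The function name and the sample sequence lengths are illustrative.

```python
import math
from typing import List


def internal_fragmentation(seq_lens: List[int], page_size: int) -> float:
    """Fraction of allocated token slots that hold no token."""
    allocated = sum(math.ceil(n / page_size) * page_size for n in seq_lens)
    used = sum(seq_lens)
    return (allocated - used) / allocated if allocated else 0.0


lens = [256, 1000, 4097, 8192]  # mixed sequence lengths from a load test
for page_size in (16, 32, 128):
    print(page_size, f"{internal_fragmentation(lens, page_size):.2%}")
```

Larger pages waste more slots in each sequence's last page, so this ratio rises with `page_size`. That gives a direct way to weigh kernel-friendly page sizes against the 5% target when tuning.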