torch.compile recompiled our SDXL UNet 38 times in production
Stabilizing Diffusion Inference: Managing Graph Compilation Guards in Variable-Resolution Pipelines
Current Situation Analysis
Modern AI inference pipelines face a structural tension: static graph optimization promises massive throughput gains, but real-world workloads rarely conform to fixed tensor dimensions. When engineering teams integrate torch.compile into production diffusion systems, they frequently encounter a deceptive performance profile. Benchmarks executed with uniform input shapes report substantial latency reductions, yet production deployments experience severe tail-latency degradation.
The root cause lies in how PyTorch's Dynamo compiler handles input variance. Dynamo traces execution graphs and wraps them in runtime guards that validate tensor shapes, data types, device placement, and scalar constants. When an incoming request violates a guard condition, the cached graph cannot be reused and Dynamo triggers a fresh compilation cycle for the new input. For large architectures like the SDXL UNet, this compilation phase consumes 40 to 90 seconds on an A10G GPU. Crucially, compilation occurs lazily: the first request that breaches a guard pays the full compilation cost inline, directly inflating p99 latency.
In production environments handling user-generated imagery, resolution variance is the norm rather than the exception. Mobile photography, e-commerce exports, and manual cropping produce a continuous distribution of aspect ratios and pixel dimensions. Without explicit shape normalization, each novel resolution triggers a guard failure. Telemetry from live deployments consistently shows dozens of recompilation events within the first hundred requests, with latency spikes exceeding 70 seconds per violation. The problem is rarely the compiler itself; it is the mismatch between static optimization assumptions and dynamic input distributions.
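The effect of resolution variance on a shape-specialized cache can be illustrated with a toy model. The sketch below is not Dynamo's guard machinery: it assumes a compiled graph is keyed only by the input's (height, width), and the names `count_compiles` and `to_bucket`, along with the sample traffic, are hypothetical.

```python
from typing import Callable, Iterable, Tuple

Shape = Tuple[int, int]


def count_compiles(
    requests: Iterable[Shape],
    normalize: Callable[[Shape], Shape] = lambda s: s,
) -> int:
    """Count cache misses when each distinct shape needs its own graph."""
    compiled = set()  # shapes that already have a "compiled graph"
    misses = 0
    for shape in requests:
        key = normalize(shape)
        if key not in compiled:  # stands in for a guard failure
            compiled.add(key)
            misses += 1
    return misses


def to_bucket(shape: Shape, buckets=(768, 1024, 1280)) -> Shape:
    """Map a shape to the smallest square bucket that fits its long edge."""
    long_edge = max(shape)
    # Oversized inputs fall back to the largest bucket (they would need downscaling).
    side = next((b for b in buckets if b >= long_edge), buckets[-1])
    return (side, side)


# Hypothetical traffic: every request has a slightly different resolution.
traffic = [(600 + 7 * i, 800 + 11 * i) for i in range(40)]

raw_compiles = count_compiles(traffic)                   # one miss per novel shape
bucketed_compiles = count_compiles(traffic, to_bucket)   # bounded by len(buckets)
```

In this model the number of compile events without normalization grows with the number of distinct shapes, while the bucketed count can never exceed the number of buckets.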
WOW Moment: Key Findings
The critical insight emerges when comparing compilation strategies against real traffic distributions. Raw throughput metrics obscure tail-latency behavior, which is what actually determines user experience and SLA compliance. By constraining the input space and pre-warming compiled graphs, teams can preserve optimization gains while eliminating unpredictable compilation overhead.
| Approach | Recompilation Events | Sustained Speedup | Compute Overhead | Latency Predictability |
|---|---|---|---|---|
| dynamic=True | 0 | ~1.6x | Higher per-step kernel cost | High |
| Resolution Bucketing | 3 (boot-time only) | ~2.1x | VAE padding overhead | Very High |
| Fixed Canonical Resolution | 0 | 2.3x | Quality degradation on extreme ratios | Maximum |
Resolution bucketing emerges as the optimal production strategy because it decouples compilation cost from request latency. By mapping arbitrary inputs to a small set of predefined dimensions, the system triggers exactly three compilation cycles during container initialization. All subsequent requests hit cached graphs, eliminating guard failures entirely. The trade-off is padding overhead in the VAE decode stage, but this cost is deterministic and easily quantified, unlike the stochastic latency spikes caused by lazy compilation.
Core Solution
Implementing a stable compilation pipeline requires three coordinated phases: input normalization, graph pre-warming, and orchestration lifecycle alignment. Each phase addresses a specific failure mode in the Dynamo compilation workflow.
Phase 1: Input Normalization via Resolution Bucketing
Instead of allowing arbitrary tensor dimensions, the pipeline maps incoming images to a constrained set of target resolutions. For SDXL, each latent dimension is the pixel dimension divided by 8, so a 1024-pixel edge becomes 128 latent units. Selecting buckets that align with common aspect ratios minimizes padding waste while covering the majority of production traffic.
import torch
import torch.nn.functional as F
from typing import List, Tuple


class ShapeNormalizer:
    """Maps arbitrary image sizes onto a small set of square buckets."""

    def __init__(self, target_resolutions: List[int], latent_scale: int = 8):
        self.resolutions = sorted(target_resolutions)
        self.latent_scale = latent_scale

    def resolve_bucket(self, height: int, width: int) -> Tuple[int, int]:
        long_edge = max(height, width)
        # Pick the smallest bucket that fits the image. Choosing the nearest
        # bucket could select one smaller than the image, which would make the
        # padding negative and silently crop the input.
        for res in self.resolutions:
            if res >= long_edge:
                return res, res
        raise ValueError(
            f"{height}x{width} exceeds the largest bucket; downscale it first"
        )

    def prepare_tensor(
        self, image_tensor: torch.Tensor
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        h, w = image_tensor.shape[-2], image_tensor.shape[-1]
        target_h, target_w = self.resolve_bucket(h, w)
        # F.pad orders padding as (left, right, top, bottom) for the last two dims.
        padded = F.pad(
            image_tensor,
            pad=(0, target_w - w, 0, target_h - h),
            mode="constant",
            value=0.0,
        )
        latent_h = target_h // self.latent_scale
        latent_w = target_w // self.latent_scale
        return padded, (latent_h, latent_w)
The normalizer calculates padding requirements and returns both the padded tensor and the corresponding latent dimensions. This abstraction keeps shape resolution logic isolated from model inference code, making it trivial to adjust bucket boundaries without touching the compilation pipeline.
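As a quick sanity check of the mapping, the following torch-free sketch walks one image through the bucket, padding, and latent-size arithmetic. It assumes the smallest bucket that fits the long edge is chosen, which is one possible policy, and the names `PaddingPlan` and `plan_padding` are hypothetical.

```python
from typing import NamedTuple, Sequence


class PaddingPlan(NamedTuple):
    bucket: int      # square target edge in pixels
    pad_bottom: int  # rows of padding added below the image
    pad_right: int   # columns of padding added to the right
    latent: int      # latent edge length (bucket // latent_scale)


def plan_padding(
    height: int,
    width: int,
    buckets: Sequence[int] = (768, 1024, 1280),
    latent_scale: int = 8,
) -> PaddingPlan:
    """Pick the smallest bucket that fits and derive the padding amounts."""
    long_edge = max(height, width)
    fitting = [b for b in sorted(buckets) if b >= long_edge]
    if not fitting:
        raise ValueError(f"{height}x{width} exceeds the largest bucket")
    bucket = fitting[0]
    return PaddingPlan(bucket, bucket - height, bucket - width, bucket // latent_scale)


# An 800x600 (width x height) upload lands in the 1024 bucket.
plan = plan_padding(height=600, width=800)
```

Because every padding amount is non-negative by construction, the padded tensor always has exactly the bucket's shape, which is what keeps the compiled graph's shape guards satisfied.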
Phase 2: Graph Pre-Warming and Compilation
Compiled graphs and their guards are held in process memory, so every new process starts cold. To prevent lazy compilation during production traffic, all target buckets must be compiled before the service accepts requests. This requires executing dummy forward passes with matching shapes and timestep values.
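class GraphCompiler:
    """Compiles the model once and pre-warms one graph per bucket."""

    def __init__(self, model: torch.nn.Module, normalizer: ShapeNormalizer):
        self.model = model
        self.normalizer = normalizer
        # dynamic=False keeps each bucket on its own shape-specialized graph.
        # With the default setting, Dynamo may switch to dynamic shapes after
        # the first shape change instead of compiling one static graph per bucket.
        self.compiled_model = torch.compile(model, fullgraph=False, dynamic=False)

    def warmup(self, device: str = "cuda", batch_size: int = 1):
        for res in self.normalizer.resolutions:
            latent_dim = res // self.normalizer.latent_scale
            dummy_latent = torch.zeros(
                batch_size, 4, latent_dim, latent_dim, device=device
            )
            dummy_timestep = torch.tensor([999.0], device=device)
            # A real SDXL UNet also expects text-conditioning inputs; pass
            # dummy tensors for those with production shapes and dtypes.
            with torch.no_grad():
                _ = self.compiled_model(dummy_latent, dummy_timestep)
        print(f"Pre-warmed {len(self.normalizer.resolutions)} resolution buckets")

The warmup routine iterates through each target resolution, constructs a zero-initialized latent tensor matching SDXL's four-channel latent layout, and executes a forward pass. Passing dynamic=False forces a separate shape-specialized graph per bucket; with the default automatic dynamic shapes, Dynamo may respond to the second bucket by compiling a generalized graph instead. The fullgraph=False flag allows graph breaks where an operation cannot be traced, preventing compilation failures while still optimizing the majority of the UNet graph. Warmup inputs should match production dtype, device, and batch size, since guards cover those properties as well as shape.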
Phase 3: Orchestration Lifecycle Alignment
Container orchestration platforms route traffic to pods based on readiness signals. If the compilation warmup completes after the readiness probe passes, the first production requests will trigger lazy compilation. The readiness probe must account for the full warmup duration.
# Kubernetes readiness probe configuration
readinessProbe:
  httpGet:
    path: /health
    port: 8080
  initialDelaySeconds: 240
  periodSeconds: 10
  failureThreshold: 3
Setting initialDelaySeconds to 240 delays the first probe until the compiler should have finished pre-warming, so the pod stays Running but not Ready and receives no traffic in the meantime. This adds approximately 4 minutes to startup time. The delay only guarantees resident graphs if warmup actually finishes within that window, so the /health endpoint should also keep failing until warmup completes. The trade-off is slower autoscaling response during traffic spikes, which is acceptable for workloads where tail latency stability outweighs rapid scale-out requirements.
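A fixed delay assumes warmup always finishes on time. One way to make the probe reflect the actual warmup state is to have /health report 503 until warmup completes. The following is a minimal standard-library sketch of that idea, not the serving stack used in this article; `warmup_done`, `HealthHandler`, and `start_health_server` are hypothetical names.

```python
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer

warmup_done = threading.Event()  # set once every bucket has been compiled


class HealthHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path != "/health":
            status = 404
        elif warmup_done.is_set():
            status = 200
        else:
            status = 503  # still compiling: keep the pod out of rotation
        self.send_response(status)
        self.send_header("Content-Length", "0")
        self.end_headers()

    def log_message(self, *args):
        pass  # keep probe traffic out of the logs


def start_health_server(host: str = "0.0.0.0", port: int = 8080) -> HTTPServer:
    """Serve /health on a background thread and return the server."""
    server = HTTPServer((host, port), HealthHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server
```

In this arrangement the process would start the health server first, run the warmup loop, and call `warmup_done.set()` only after the last bucket has compiled, so the readiness probe flips to passing at the moment the graphs are actually available.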
Pitfall Guide
1. Assuming dynamic=True Eliminates Compilation Overhead
Explanation: Enabling dynamic shapes tells Dynamo to generate generalized kernels that handle variable dimensions. While this prevents recompilation, the resulting kernels sacrifice optimization opportunities like loop unrolling and memory coalescing. The speedup over eager mode typically settles around ~1.6x, below the ~2.1x of bucketed static graphs, leaving significant performance on the table.
Fix: Use dynamic compilation only for workloads with extreme shape variance where padding overhead exceeds kernel generalization penalties. Otherwise, constrain inputs to fixed buckets.
2. Skipping Pre-Warmup Before Readiness
Explanation: Lazy compilation defers graph generation until the first request that violates a guard. In production, this injects 40-90 seconds of compilation time directly into user-facing latency. The first request after a deployment or scaling event will always suffer.
Fix: Execute dummy forward passes for all target shapes during container initialization. Block readiness signals until warmup completes.
3. Trusting Shared Inductor Cache Across Nodes
Explanation: Setting TORCHINDUCTOR_CACHE_DIR to a shared volume appears to solve per-process recompilation, but the cache key includes environment metadata, CUDA driver versions, and PyTorch build hashes. Mismatched versions across nodes produce cache misses, so affected pods silently recompile from scratch and the shared cache delivers no benefit without raising errors.
Fix: Validate cache hits using TORCH_LOGS="output_code" or monitor compilation frequency. Prefer per-pod warmup with optimized startup scripts over shared caching in heterogeneous clusters.
4. Ignoring VAE Decode Padding Overhead
Explanation: Padding images to fixed buckets increases pixel count downstream. An 800x600 image padded to 1024x1024 grows from 480,000 to 1,048,576 pixels, so roughly 2.2x the original pixel count (about 118% more) flows through the VAE decoder. This overhead compounds if multiple pipeline stages process padded tensors.
Fix: Profile the full pipeline before optimizing. If VAE decode dominates wall-clock time, narrow the spacing between buckets or implement crop-aware normalization that preserves aspect ratio within bucket constraints.
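The overhead for any image and bucket pair follows directly from pixel counts. A small helper makes the calculation explicit; it assumes square buckets, and the function name `padding_overhead` is hypothetical.

```python
def padding_overhead(height: int, width: int, bucket: int) -> float:
    """Return the extra pixels as a fraction of the original pixel count."""
    if max(height, width) > bucket:
        raise ValueError("image does not fit inside the bucket")
    original = height * width
    padded = bucket * bucket
    return (padded - original) / original


overhead = padding_overhead(height=600, width=800, bucket=1024)
```

With square buckets the overhead depends strongly on aspect ratio: an 800x800 input padded to 1024x1024 gains about 64% more pixels, while an 800x600 input gains about 118%.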
5. Misaligning Kubernetes Readiness Probes
Explanation: Default readiness probes often check HTTP endpoints within 10-30 seconds. If compilation warmup requires 4 minutes, the pod accepts traffic before graphs are ready, triggering lazy compilation and latency spikes.
Fix: Configure initialDelaySeconds to exceed the maximum warmup duration. Use a dedicated /warmup endpoint that returns 200 only after all buckets are compiled, then transition to standard health checks.
6. Over-Segmenting Resolution Buckets
Explanation: Creating dozens of buckets to minimize padding waste increases memory consumption and extends warmup time. Each bucket requires a separate compiled graph, consuming VRAM and prolonging initialization.
Fix: Limit buckets to 3-5 resolutions that cover >90% of production traffic. Analyze historical image dimension distributions to select optimal boundaries.
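Candidate bucket sets can be compared against logged traffic before anything is compiled. The sketch below assumes the logs have been reduced to a histogram of long-edge lengths; the function name `bucket_report` and the sample histogram are hypothetical, and the padding figure is a linear (per-edge) measure rather than a pixel-count measure.

```python
from typing import Dict, Sequence


def bucket_report(long_edges: Dict[int, int], buckets: Sequence[int]) -> Dict[str, float]:
    """Summarize coverage and mean linear padding for a long-edge histogram.

    long_edges maps a long-edge length in pixels to its request count.
    """
    ordered = sorted(buckets)
    total = sum(long_edges.values())
    covered = 0
    padded_px = 0
    for edge, count in long_edges.items():
        fit = next((b for b in ordered if b >= edge), None)
        if fit is None:
            continue  # larger than every bucket: needs downscaling instead
        covered += count
        padded_px += (fit - edge) * count
    return {
        "coverage": covered / total,
        "mean_pad_px": padded_px / covered if covered else 0.0,
    }


# Hypothetical histogram: long edge in pixels -> number of requests.
history = {640: 10, 768: 30, 1000: 40, 1280: 15, 2048: 5}
report = bucket_report(history, buckets=[768, 1024, 1280])
```

Running the same report for several candidate bucket lists shows whether an additional bucket buys enough padding reduction to justify another compiled graph and a longer warmup.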
7. Profiling After Compilation
Explanation: Teams frequently apply torch.compile to the entire pipeline without identifying the actual bottleneck. If the UNet represents only 40% of total latency, compilation gains will be marginal regardless of guard configuration.
Fix: Run torch.profiler or Nsight Systems (nsys) on the uncompiled pipeline first; the older nvprof does not support Ampere-class GPUs such as the A10G. Apply compilation exclusively to the dominant compute stage. Validate that the targeted component actually benefits from graph optimization.
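The ceiling on end-to-end gains can be estimated with Amdahl's law before any compilation work begins. This is a minimal sketch; the function name `end_to_end_speedup` is hypothetical, and the inputs are the stage's share of latency and the speedup expected for that stage alone.

```python
def end_to_end_speedup(stage_fraction: float, stage_speedup: float) -> float:
    """Amdahl's law: overall speedup when only one stage gets faster."""
    if not 0.0 <= stage_fraction <= 1.0:
        raise ValueError("stage_fraction must be between 0 and 1")
    if stage_speedup <= 0.0:
        raise ValueError("stage_speedup must be positive")
    return 1.0 / ((1.0 - stage_fraction) + stage_fraction / stage_speedup)


# UNet at 40% of total latency, sped up ~2.1x by bucketed compilation.
overall = end_to_end_speedup(stage_fraction=0.40, stage_speedup=2.1)
```

With the UNet at 40% of latency and a ~2.1x stage speedup, the whole pipeline improves by only about 1.27x, which is why profiling needs to come first.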
Production Bundle
Action Checklist
- Analyze production traffic distribution to identify dominant image resolutions and aspect ratios
- Define 3-5 resolution buckets that cover >90% of incoming requests while minimizing padding waste
- Implement a shape normalizer that maps arbitrary inputs to target buckets with deterministic padding
- Configure torch.compile with fullgraph=False and attach it to the primary compute stage
- Execute dummy forward passes for all buckets during container initialization before accepting traffic
- Adjust Kubernetes initialDelaySeconds to exceed total warmup duration (typically 180-240s)
- Enable TORCH_LOGS="recompiles" in staging to verify guard stability before production rollout
- Monitor p99 latency and compilation frequency post-deployment to detect regression
Decision Matrix
| Scenario | Recommended Approach | Why | Cost Impact |
|---|---|---|---|
| High variance user uploads, strict p99 SLA | Resolution bucketing + pre-warmup | Eliminates lazy compilation, deterministic padding overhead | +4 min startup, ~15-20% VAE compute increase |
| Fixed resolution API, controlled inputs | Fixed canonical resolution | Maximum speedup (2.3x), zero recompilation risk | Quality loss on non-standard aspect ratios |
| Extreme shape diversity, memory constrained | dynamic=True compilation | Prevents recompilation, avoids padding overhead | ~30% throughput reduction vs static graphs |
| Heterogeneous GPU fleet, shared storage | Per-pod warmup with validation | Avoids silent cache misses from cache key mismatches | Higher CPU usage during startup, no shared cache benefit |
| VAE decode dominates latency | Crop-aware normalization + smaller buckets | Reduces padding multiplier in downstream stages | Requires custom preprocessing pipeline |
Configuration Template
import os

# Environment configuration. Set these before importing torch, because the
# logging configuration is read when torch is imported.
os.environ["TORCH_LOGS"] = "recompiles"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/inductor_cache"

import torch
from typing import List

# ShapeNormalizer and GraphCompiler are the classes defined in Phases 1 and 2.


class InferencePipeline:
    def __init__(self, model: torch.nn.Module, buckets: List[int]):
        self.normalizer = ShapeNormalizer(buckets)
        self.compiler = GraphCompiler(model, self.normalizer)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def initialize(self):
        """Must be called before serving traffic"""
        self.compiler.warmup(device=self.device)

    def process(self, image_tensor: torch.Tensor) -> torch.Tensor:
        padded, latent_shape = self.normalizer.prepare_tensor(image_tensor)
        padded = padded.to(self.device)
        # NOTE: the UNet consumes 4-channel latents of size latent_shape. A full
        # pipeline would VAE-encode `padded` first; that step is omitted here.
        timestep = torch.tensor([999.0], device=self.device)
        with torch.no_grad():
            output = self.compiler.compiled_model(padded, timestep)
        return output


# Deployment entrypoint
if __name__ == "__main__":
    TARGET_BUCKETS = [768, 1024, 1280]
    # load_sdxl_unet and start_http_server are placeholders for the
    # application's own model loader and HTTP server.
    pipeline = InferencePipeline(model=load_sdxl_unet(), buckets=TARGET_BUCKETS)
    pipeline.initialize()  # blocks until every bucket is compiled
    start_http_server(port=8080, health_endpoint="/health")
Quick Start Guide
- Define Resolution Buckets: Analyze your production image dimension logs. Select 3-5 target resolutions that cover the majority of traffic. Common SDXL buckets: [768, 1024, 1280].
- Integrate Normalizer: Replace arbitrary tensor resizing with the ShapeNormalizer class. Ensure padding logic aligns with your VAE's expected input format.
- Add Warmup Routine: Call the pre-warmup method during container startup. Block HTTP server initialization until all buckets are compiled.
- Configure Orchestration: Set initialDelaySeconds to 240 in your readiness probe. Deploy to a staging environment and verify TORCH_LOGS="recompiles" shows zero guard failures after warmup.
- Monitor & Iterate: Track p99 latency and compilation frequency in production. Adjust bucket boundaries if padding overhead exceeds acceptable thresholds or if new resolution clusters emerge.
