roducibility).
Step 1: Teacher Freezing & Dataset Synthesis
Generate training data offline. Run the teacher model on a curated instruction dataset and cache its logits, hidden states, and (if a relation-based loss needs them) attention patterns. This removes teacher inference from the training loop and keeps the distillation targets identical across epochs and runs.
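```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer


class TeacherSynthesizer:
    def __init__(self, teacher_path: str, device: str = "cuda"):
        self.tokenizer = AutoTokenizer.from_pretrained(teacher_path)
        # Decoder-only models need left padding for batched generation, and many
        # causal-LM tokenizers ship without a pad token, so reuse EOS.
        self.tokenizer.padding_side = "left"
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            teacher_path, torch_dtype=torch.float16, device_map=device
        ).eval()
        # Freeze the teacher: no parameter will ever receive a gradient.
        self.model.requires_grad_(False)

    @torch.no_grad()
    def synthesize(self, prompts: list[str], max_new_tokens: int = 512) -> dict:
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(self.model.device)
        # do_sample=True is required for temperature to have any effect.
        sequences = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.8,
            pad_token_id=self.tokenizer.pad_token_id,
        )
        # One extra forward pass over prompt + completion yields logits and hidden
        # states aligned token-by-token with the student's forward pass.
        # Note: with pad == EOS, this mask also hides genuine EOS tokens.
        attention_mask = (sequences != self.tokenizer.pad_token_id).long()
        out = self.model(sequences, attention_mask=attention_mask, output_hidden_states=True)
        return {
            "input_ids": inputs.input_ids.cpu(),
            "generated_ids": sequences.cpu(),
            "attention_mask": attention_mask.cpu(),
            "logits": out.logits.cpu(),
            "hidden_states": out.hidden_states[-1].cpu(),  # final layer only
        }
```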
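Caching only helps if the training loop can read the cached tensors back. Below is a minimal sketch of a dataset that serves them next to the student inputs, so a training loop can take the teacher targets straight from each batch. It assumes the following:

- Each `.pt` file was written with `torch.save()`.
- Each file holds a dict with `generated_ids`, `attention_mask`, `logits`, and `hidden_states`, batched along dim 0.
- All sequences are padded to one length, so default collation can stack them.

The class name and key layout are hypothetical.

```python
import os

import torch
from torch.utils.data import Dataset


class CachedDistillationDataset(Dataset):
    """Hypothetical loader for teacher shards saved with torch.save().

    Assumes every shard is a dict of tensors batched along dim 0 with keys
    "generated_ids", "attention_mask", "logits", "hidden_states".
    Loads everything into memory for simplicity.
    """

    def __init__(self, cache_dir: str):
        self.examples = []
        for name in sorted(os.listdir(cache_dir)):
            if not name.endswith(".pt"):
                continue
            shard = torch.load(os.path.join(cache_dir, name))
            for i in range(shard["generated_ids"].shape[0]):
                # Rename keys so student inputs and teacher targets are distinguishable.
                self.examples.append({
                    "input_ids": shard["generated_ids"][i],
                    "attention_mask": shard["attention_mask"][i],
                    "teacher_logits": shard["logits"][i],
                    "teacher_hidden": shard["hidden_states"][i],
                })

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]
```

Wrapping it in a standard `torch.utils.data.DataLoader` then yields batches whose `teacher_logits` and `teacher_hidden` entries sit next to `input_ids` and `attention_mask`.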
Step 2: Student Initialization & Projection Alignment
Students rarely match the teacher's layer count or hidden width. Insert projection heads that map teacher hidden states into the student's dimension before feature matching. Freeze the student's base weights initially, then unfreeze them progressively.
```python
class ProjectionHead(nn.Module):
    """Maps teacher hidden states (teacher_dim) into the student's width (student_dim)."""

    def __init__(self, teacher_dim: int, student_dim: int):
        super().__init__()
        self.linear = nn.Linear(teacher_dim, student_dim)
        self.norm = nn.LayerNorm(student_dim)

    def forward(self, x):
        return self.norm(self.linear(x))
```
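Progressive unfreezing can be sketched as two small helpers:

- `unfreeze_top_layers` freezes every decoder layer, then re-enables gradients for the top `k`.
- `trainable_layers_at` is a hypothetical schedule that adds one more top layer every `unfreeze_every` steps.

Both names and the schedule are assumptions. For Llama-style models in `transformers`, the decoder layers are the `ModuleList` at `student_model.model.layers`. The helpers leave embeddings, the LM head, and the projection head as they are.

```python
import torch.nn as nn


def unfreeze_top_layers(layers: nn.ModuleList, num_trainable: int) -> None:
    """Freeze every layer, then re-enable gradients for the top `num_trainable`."""
    for layer in layers:
        layer.requires_grad_(False)
    if num_trainable > 0:  # guard: layers[-0:] would select every layer
        for layer in layers[-num_trainable:]:
            layer.requires_grad_(True)


def trainable_layers_at(step: int, total_layers: int, unfreeze_every: int) -> int:
    """Hypothetical schedule: start with one layer, add one every `unfreeze_every` steps."""
    return min(total_layers, step // unfreeze_every + 1)
```

In a training loop, one would call `unfreeze_top_layers(layers, trainable_layers_at(step, len(layers), unfreeze_every))` whenever the count changes.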
Step 3: Multi-Objective Loss Design
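Combine KL divergence for output-distribution alignment with MSE for hidden-state matching. Dividing both models' logits by a temperature T > 1 softens the distributions, so the student learns how the teacher ranks plausible alternatives instead of copying only its top choice. The KL term is multiplied by T² because the soft-target gradients shrink as 1/T²; the factor keeps their magnitude roughly independent of T.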
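Written out, the objective computed by `DistillationLoss` is the formula below. The symbols are:

- $z_s, z_t$: student and teacher logits at each token position.
- $h_s$: the student hidden state.
- $P(h_t)$: the teacher hidden state after the Step 2 projection head.

`nn.KLDivLoss(input, target)` computes $\mathrm{KL}(\text{target} \,\|\, \text{input})$. Passing the student's log-probabilities as `input` and the teacher's probabilities as `target` therefore gives the teacher-to-student direction:

$$
\mathcal{L} = \alpha\, T^{2}\, \mathrm{KL}\big(\operatorname{softmax}(z_t/T) \,\|\, \operatorname{softmax}(z_s/T)\big) + (1-\alpha)\, \mathrm{MSE}\big(h_s,\ P(h_t)\big)
$$

With the configuration template's T = 2.5 and α = 0.75, the raw KL term is weighted by 0.75 × 2.5² = 4.6875 and the MSE term by 0.25.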
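```python
class DistillationLoss(nn.Module):
    def __init__(self, temperature: float = 2.0, alpha: float = 0.7):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        # "batchmean" divides by the first dimension, so logits are flattened to
        # (tokens, vocab) below to get a per-token average, not a per-sequence sum.
        self.kl = nn.KLDivLoss(reduction="batchmean")
        self.mse = nn.MSELoss()

    def forward(self, student_logits, teacher_logits, student_hidden, teacher_hidden):
        T = self.temperature
        vocab = student_logits.size(-1)
        # Logit distillation, computed in fp32 for numerical stability under AMP.
        student_soft = torch.log_softmax(student_logits.float().reshape(-1, vocab) / T, dim=-1)
        teacher_soft = torch.softmax(teacher_logits.float().reshape(-1, vocab) / T, dim=-1)
        loss_logits = self.kl(student_soft, teacher_soft) * (T ** 2)
        # Feature distillation: teacher_hidden must already match the student's width.
        loss_features = self.mse(student_hidden.float(), teacher_hidden.float())
        return self.alpha * loss_logits + (1 - self.alpha) * loss_features
```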
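To see what the temperature does, the sketch below softens one made-up set of teacher logits at T = 1, 2, and 4. It prints the top-token probability and the entropy at each temperature. It uses pure Python, and the logits are illustrative, not taken from a real model.

```python
import math


def softmax_with_temperature(logits: list[float], T: float) -> list[float]:
    """Temperature-scaled softmax, shifted by the max logit for numerical stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p: list[float]) -> float:
    """Shannon entropy in nats."""
    return -sum(x * math.log(x) for x in p if x > 0)


teacher_logits = [6.0, 3.0, 1.0, 0.5]  # made-up logits for one token position
for T in (1.0, 2.0, 4.0):
    p = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: top prob={p[0]:.3f}, entropy={entropy(p):.3f}")
```

Raising T moves probability mass from the top token to the runner-ups while preserving their order. That ordering is the extra signal the KL term transfers beyond a hard label.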
Step 4: Training Loop with Gradient Accumulation
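Local hardware requires memory optimization: use gradient accumulation, mixed precision, and activation checkpointing. Read teacher targets from the Step 1 cache instead of running the teacher inside the loop. The teacher and student then never share VRAM, and teacher inference cannot trigger OOM errors.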
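```python
def train_distillation(
    student_model, projection, dataloader, optimizer,
    loss_fn, device, accum_steps=4, max_steps=1000
):
    # Assumes each batch carries the cached teacher targets from Step 1 under
    # "teacher_logits" and "teacher_hidden", so the teacher never runs here.
    # `projection` is a ProjectionHead; the optimizer must include its parameters.
    student_model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )  # activation checkpointing
    student_model.train()
    scaler = torch.amp.GradScaler("cuda")
    optimizer.zero_grad()
    step = 0
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        teacher_logits = batch.pop("teacher_logits")
        teacher_hidden = batch.pop("teacher_hidden")
        with torch.amp.autocast("cuda"):
            student_out = student_model(**batch, output_hidden_states=True, use_cache=False)
            loss = loss_fn(
                student_out.logits, teacher_logits,
                # Final-layer state at the last position; teacher mapped to student width.
                student_out.hidden_states[-1][:, -1], projection(teacher_hidden[:, -1]),
            )
        # Divide so the accumulated gradient is the mean over micro-batches.
        scaler.scale(loss / accum_steps).backward()
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        step += 1
        if step >= max_steps:
            break
```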
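The claim that accumulation simulates a larger batch can be checked on a toy model. With a mean-reduced loss and equal-size micro-batches, dividing each micro-batch loss by `accum_steps` before `backward()` reproduces the full-batch gradient. The linear model and random data are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = nn.MSELoss()  # mean reduction

# Reference: one full batch of 8 samples.
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: 4 micro-batches of 2, each loss scaled by 1 / accum_steps.
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    (loss_fn(model(xb), yb) / accum_steps).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # prints True
```

The equivalence is exact only under those conditions. With per-token averaging and different amounts of padding per micro-batch, it becomes an approximation.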
Architecture Decisions & Rationale
- Decoupled synthesis: Fixes the teacher targets once, so sampled generations do not change between epochs, and keeps the teacher out of VRAM during training. Cached tensors are stored in NVMe-backed memory pools.
- Temperature scaling: Softens teacher probability distributions, enabling the student to learn decision boundaries rather than memorizing hard labels.
- Projection heads: Resolve dimension mismatch without forcing architectural symmetry. LayerNorm stabilizes gradient flow.
- Gradient accumulation + AMP: Simulates larger batch sizes within VRAM constraints. FP16/FP8 reduces memory bandwidth bottlenecks.
- Progressive unfreezing: Freezes lower layers initially to preserve foundational knowledge, then unfreezes top layers for task adaptation.
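A back-of-envelope estimate shows why the decoupled-synthesis cache lives on NVMe rather than in RAM or VRAM. The sketch assumes:

- Full-vocabulary fp16 logits plus one fp16 hidden-state layer per token.
- The Llama 3 vocabulary size (128,256).
- The 70B teacher width from the configuration template (8,192).
- 512-token sequences.

The 10,000-sequence count is arbitrary.

```python
def cache_bytes(num_sequences: int, seq_len: int, vocab_size: int,
                hidden_dim: int, bytes_per_value: int = 2) -> dict:
    """Rough cache size for full-vocab logits plus one hidden-state layer (fp16 by default)."""
    logits = num_sequences * seq_len * vocab_size * bytes_per_value
    hidden = num_sequences * seq_len * hidden_dim * bytes_per_value
    return {"logits": logits, "hidden": hidden, "total": logits + hidden}


per_seq = cache_bytes(num_sequences=1, seq_len=512, vocab_size=128_256, hidden_dim=8192)
print(per_seq)  # {'logits': 131334144, 'hidden': 8388608, 'total': 139722752}
total_tb = cache_bytes(10_000, 512, 128_256, 8192)["total"] / 1e12
print(f"{total_tb:.2f} TB for 10k sequences")  # 1.40 TB for 10k sequences
```

Logits dominate: they are about 94% of the total, because the vocabulary is roughly 16 times wider than the hidden state.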
Pitfall Guide
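- Architectural mismatch without projection layers: Hidden states of different widths cannot be compared directly; nn.MSELoss raises a shape error for, say, 8,192 versus 3,072 features. Unnormalized features of very different scale also produce unstable feature losses. Always insert a linear projection plus normalization before feature matching.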
- Ignoring temperature scaling: At T=1 the teacher's distribution is usually sharply peaked, so the student gets little signal beyond the top token, and much of the benefit of knowledge transfer is lost. T=2.0–4.0 is a common range for LLM distillation.
- Over-compressing the student: Shrinking depth or width too aggressively relative to the teacher leaves too little capacity to absorb what the teacher transfers and can cause representational collapse. Maintain sufficient FFN width and attention head counts.
- Teacher data contamination: Using the same dataset for teacher pretraining and distillation creates distribution shift when deployed. Inject out-of-domain samples during synthesis to improve generalization.
- Skipping out-of-distribution validation: Distilled models excel on in-distribution tasks but fail on novel reasoning patterns. Validate on GSM8K, Big-Bench Hard, and domain-specific benchmarks before deployment.
- Not freezing teacher gradients: Accidental teacher parameter updates corrupt cached logits and break reproducibility. Always call `.requires_grad_(False)` on the teacher and verify that `sum(p.grad is not None for p in teacher.parameters())` is 0.
- Treating distillation as one-shot: Single-stage distillation leaves 10–15% capability on the table. Progressive distillation (stage 1: logit, stage 2: feature, stage 3: self-distillation) closes the gap.
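The teacher-freezing check from the list above can be wrapped in a helper that runs after synthesis or at each checkpoint. In this sketch, `assert_frozen` is a hypothetical name, and an `nn.Linear` stands in for the real teacher.

```python
import torch.nn as nn


def assert_frozen(teacher: nn.Module) -> None:
    """Fail loudly if any teacher parameter can be, or has been, updated."""
    trainable = [n for n, p in teacher.named_parameters() if p.requires_grad]
    with_grad = [n for n, p in teacher.named_parameters() if p.grad is not None]
    assert not trainable, f"teacher parameters still require grad: {trainable[:5]}"
    assert not with_grad, f"teacher parameters have gradients: {with_grad[:5]}"


teacher = nn.Linear(8, 8)  # stand-in for the real teacher model
teacher.requires_grad_(False).eval()
assert_frozen(teacher)  # passes silently
```

Checking `requires_grad` as well as `.grad` catches a teacher that is trainable but has not yet received a backward pass.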
Production Best Practices:
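- Align tokenizers exactly between teacher and student. The logit KL term compares distributions index by index, so mismatched vocabularies either fail on shape or, if the sizes happen to match, silently compare unrelated tokens.
- Monitor KL divergence stability. Spikes can indicate temperature misconfiguration or batch corruption.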
- Use LoRA adapters on the student for task-specific fine-tuning post-distillation. This preserves distilled knowledge while adapting to vertical domains.
- Implement tensor parallelism only when student size exceeds single-GPU limits. Distilled 1B–7B models run efficiently on single nodes.
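Tokenizer alignment can be verified before any training. The sketch below compares token-to-id mappings. It works with any object exposing `get_vocab()` that returns a token-to-id dict, which Hugging Face tokenizers provide. `vocab_mismatches` is a hypothetical helper. It is also worth comparing the two models' `config.vocab_size`, since a model's logit dimension can be padded beyond its tokenizer's size.

```python
def vocab_mismatches(teacher_tok, student_tok, limit: int = 10) -> list:
    """Return up to `limit` (token, teacher_id, student_id) tuples whose ids differ.

    A token missing from one side shows up with None as its id.
    An empty list means the token-to-id mappings are identical.
    """
    t_vocab, s_vocab = teacher_tok.get_vocab(), student_tok.get_vocab()
    diffs = []
    for token in sorted(t_vocab.keys() | s_vocab.keys()):
        if t_vocab.get(token) != s_vocab.get(token):
            diffs.append((token, t_vocab.get(token), s_vocab.get(token)))
            if len(diffs) >= limit:
                break
    return diffs
```

With real models, one would pass the two tokenizers returned by `AutoTokenizer.from_pretrained` and abort the run if the list is non-empty.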
Production Bundle
Action Checklist
Decision Matrix
| Scenario | Recommended Approach | Why | Cost Impact |
|---|---|---|---|
| Edge deployment (<8GB VRAM) | Logit-only distillation (1B–2B student) | Minimizes compute and memory; sufficient for classification/generation tasks | 90% reduction in cloud inference costs |
| Low-latency API (<20ms/token) | Hybrid logit + feature distillation (3B–7B student) | Balances reasoning capability with throughput; projection heads are used only during training, so they add no inference latency | 60% reduction in GPU fleet size |
| Multi-modal fine-tuning | Relation-based distillation + adapter merging | Captures cross-modal dependencies; relation loss preserves structural alignment | 40% increase in training time, 70% reduction in data labeling costs |
| Budget-constrained training | Progressive distillation with LoRA | Stages reduce peak VRAM; LoRA avoids full weight updates | 85% reduction in training compute vs full fine-tuning |
Configuration Template
```yaml
distillation:
  teacher:
    model_id: "meta-llama/Meta-Llama-3.1-70B-Instruct"
    cache_dir: "/data/teacher_cache"
    synthesis:
      max_new_tokens: 512
      temperature: 0.8
      batch_size: 16
  student:
    model_id: "meta-llama/Llama-3.2-3B-Instruct"
    projection:
      teacher_dim: 8192
      student_dim: 3072
      use_layernorm: true
  training:
    loss:
      type: "hybrid"
      temperature: 2.5
      alpha: 0.75
    optimizer:
      type: "adamw"
      lr: 2.0e-5  # keep the decimal point: PyYAML (YAML 1.1) parses "2e-5" as a string
      weight_decay: 0.01
    schedule:
      warmup_steps: 500
      max_steps: 8000
    memory:
      gradient_accumulation: 4
      mixed_precision: "fp16"
      activation_checkpointing: true
  validation:
    benchmarks: ["mmlu", "gsm8k", "alpaca_eval"]
    eval_frequency: 1000
    early_stopping:
      patience: 3
      metric: "loss"
```
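One way to turn the optimizer and schedule sections of the template above into PyTorch objects is sketched below. The values are copied into a dict rather than parsed; reading them from the file with PyYAML would work the same way. `build_optimizer` is hypothetical. The linear decay after warmup is an assumption, because the template specifies only `warmup_steps` and `max_steps`.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LambdaLR

# Mirrors training.optimizer and training.schedule from the YAML template.
cfg = {"lr": 2e-5, "weight_decay": 0.01, "warmup_steps": 500, "max_steps": 8000}


def build_optimizer(student: nn.Module, projection: nn.Module, cfg: dict):
    # Pass every parameter, including the projection head's. AdamW skips params
    # whose .grad is None, so layers unfrozen later start updating automatically.
    params = list(student.parameters()) + list(projection.parameters())
    optimizer = torch.optim.AdamW(params, lr=cfg["lr"], weight_decay=cfg["weight_decay"])

    def lr_lambda(step: int) -> float:
        # Linear warmup to the base LR, then (assumed) linear decay to zero at max_steps.
        if step < cfg["warmup_steps"]:
            return (step + 1) / cfg["warmup_steps"]
        remaining = cfg["max_steps"] - step
        return max(0.0, remaining / (cfg["max_steps"] - cfg["warmup_steps"]))

    return optimizer, LambdaLR(optimizer, lr_lambda)
```

Call `scheduler.step()` right after each `optimizer.step()`, not after every micro-batch. This keeps the schedule counted in optimizer updates when gradient accumulation is on.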
Quick Start Guide
- Install dependencies: `pip install transformers torch accelerate peft datasets`
- Generate teacher cache: Run `TeacherSynthesizer` on your instruction dataset and save the returned tensors to disk using `torch.save()`.
- Initialize student with projection heads: Load the student model, insert `ProjectionHead` modules at matching layer indices, and freeze base weights.
- Launch training: Execute the distillation loop with AMP, gradient accumulation, and the hybrid loss function. Monitor KL divergence and validation loss.
- Export & deploy: Save the distilled checkpoint, convert to GGUF/ONNX if needed, and serve locally using vLLM or Ollama. Verify throughput and accuracy against baseline metrics.