# A Mechanistic Investigation of Supervised Fine-Tuning
## Diagnosing Representational Shift in Fine-Tuned Models Using Sparse Autoencoders
### Current Situation Analysis
Engineering teams evaluating the impact of Supervised Fine-Tuning (SFT) frequently rely on dense vector metrics to assess representational stability. The standard practice involves computing the cosine similarity between hidden activations of the base model and the fine-tuned model across a validation dataset. A high cosine similarity score is typically interpreted as evidence that the fine-tuning process preserved the model's underlying knowledge structure, suggesting minimal catastrophic forgetting or geometric distortion.
This reliance on dense similarity metrics is a critical blind spot. Cosine similarity measures the angular alignment of high-dimensional vectors but fails to capture changes in the sparse, semantic composition of those vectors. Research indicates that while the cosine similarity between base and SFT model activations remains exceptionally high, the underlying feature representations diverge significantly when analyzed through a mechanistic lens.
The misconception arises because dense embeddings can maintain directional similarity even when the constituent features driving those directions have been fundamentally altered. SFT may rewire the model to prioritize different semantic concepts while keeping the aggregate vector orientation similar. Without a high-resolution diagnostic tool, engineers cannot distinguish between a model that has retained its reasoning capabilities and one that has shifted to superficial pattern matching or safety-compliant heuristics.
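A hand-constructed example makes this concrete. In the TypeScript sketch below (all values are illustrative, not measurements from any real model), two dictionary atoms point in nearly the same direction but represent distinct features; activations built from each atom are almost parallel in dense space while sharing no active features in sparse space.

```typescript
// Toy demonstration: near-parallel dense vectors with disjoint sparse codes.
// All numbers are illustrative; no real model is involved.

function cosine(a: number[], b: number[]): number {
  const dot = a.reduce((sum, v, i) => sum + v * b[i], 0);
  const norm = (v: number[]) => Math.sqrt(v.reduce((s, x) => s + x * x, 0));
  return dot / (norm(a) * norm(b));
}

function jaccard(a: Set<number>, b: Set<number>): number {
  const inter = [...a].filter(x => b.has(x)).length;
  return inter / (a.size + b.size - inter);
}

// Two dictionary atoms that are geometrically close but semantically distinct.
const atomA = [1.0, 0.2, 0.0]; // hypothetical "step-by-step reasoning" feature
const atomB = [1.0, 0.0, 0.2]; // hypothetical "template compliance" feature

// The base model activation fires only feature 0; the SFT model fires only feature 1.
const baseDense = atomA.map(v => 3 * v);
const sftDense = atomB.map(v => 3 * v);

console.log(cosine(baseDense, sftDense).toFixed(3)); // ~0.962: geometry looks stable
console.log(jaccard(new Set([0]), new Set([1])));    // 0: active features are disjoint
```

A dense similarity above 0.96 would pass most regression checks, yet the feature-level comparison shows a complete substitution of the active concept.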
Sparse Autoencoders (SAEs) pretrained on the base model provide the necessary resolution to detect these shifts. By projecting activations through a fixed, interpretable dictionary of features, SAEs reveal that SFT induces substantial divergence in sparse latents, even when dense metrics suggest stability. This divergence is not random; it exhibits task-specific and layer-specific distributions, indicating that fine-tuning systematically targets and alters precise semantic features rather than uniformly perturbing the model.
### WOW Moment: Key Findings
The discrepancy between dense and sparse metrics reveals a hidden layer of model behavior that standard evaluation pipelines miss. The following comparison illustrates the divergence detected when moving from aggregate vector similarity to feature-level analysis.
| Metric Category | Evaluation Method | Base vs. SFT Model Result | Interpretation |
|---|---|---|---|
| Dense Geometry | Cosine Similarity | > 0.96 | Suggests minimal change; model geometry appears preserved. |
| Sparse Latents | SAE Feature Overlap | < 0.45 | Reveals significant divergence; underlying features have shifted. |
| Feature Magnitude | L1 Divergence | High | Indicates substantial redistribution of activation energy across features. |
| Layer Profile | Layer-wise Analysis | Non-uniform | Safety alignment shows distinct update patterns compared to task-specific layers. |
**Why This Matters:** The finding that sparse latents diverge significantly despite high cosine similarity enables engineers to detect representational drift that dense metrics mask. This capability is essential for:
- **Safety Verification:** Identifying if safety fine-tuning has inadvertently suppressed critical reasoning features or introduced brittle heuristics.
- **Task Specificity:** Pinpointing exactly which semantic features are modified during domain adaptation, allowing for targeted interventions.
- **Model Diagnostics:** Moving beyond black-box evaluation to mechanistic understanding of how fine-tuning alters internal computation.
### Core Solution
To mechanistically investigate SFT impact, implement a diagnostic pipeline that projects model activations through a Sparse Autoencoder pretrained on the base model. This approach treats the SAE as a fixed dictionary, ensuring that changes in latent activations reflect genuine shifts in the model's feature usage rather than changes in the dictionary itself.
#### Architecture Decisions
- **Fixed SAE Dictionary:** Always use an SAE pretrained on the base model. Retraining the SAE on the fine-tuned model would adapt the dictionary to the new representations, obscuring the divergence. The fixed dictionary acts as a stable reference frame.
- **Layer-Specific Hooking:** SFT effects are not uniform across the network. Implement hooks at multiple transformer layers to capture layer-wise update profiles. This allows detection of task-specific distributions and safety alignment signatures.
- **Sparse Latent Comparison:** Compute divergence metrics on the sparse latent vectors rather than the dense activations. Metrics should include feature overlap, L1 divergence, and activation magnitude shifts; formal definitions follow this list.
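As a sketch, the sparse metrics in the last bullet can be formalized as follows, writing z^base and z^sft for the SAE latent vectors of the same input and TopK(·) for the index set of the k largest activations; both quantities are averaged over the evaluation batch. This matches the Jaccard-style overlap and L1 sum used in the implementation below.

```latex
\mathrm{overlap}\left(z^{\mathrm{base}}, z^{\mathrm{sft}}\right)
  = \frac{\bigl|\mathrm{TopK}(z^{\mathrm{base}}) \cap \mathrm{TopK}(z^{\mathrm{sft}})\bigr|}
         {\bigl|\mathrm{TopK}(z^{\mathrm{base}}) \cup \mathrm{TopK}(z^{\mathrm{sft}})\bigr|},
\qquad
\mathrm{div}_{L_1}\left(z^{\mathrm{base}}, z^{\mathrm{sft}}\right)
  = \sum_i \bigl|z_i^{\mathrm{base}} - z_i^{\mathrm{sft}}\bigr|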
#### Implementation
The following TypeScript implementation sketches a diagnostic pipeline for analyzing representational shifts. It assumes a hypothetical ML runtime environment (`./ml-runtime`) where model activations and SAE projections are accessible via an API; the metric helpers are minimal reference implementations rather than optimized production code.
```typescript
import { ModelRunner, SaeProjection, ActivationTrace, LayerConfig } from './ml-runtime';

interface SaeDiagnosticConfig {
  baseModelId: string;
  sftModelId: string;
  saePath: string;
  targetLayers: number[];
  dataset: string[];
  divergenceThreshold: number;
}

interface FeatureDivergenceReport {
  layerIndex: number;
  cosineSimilarity: number;
  saeFeatureOverlap: number;
  l1Divergence: number;
  shiftedFeatures: string[];
}

class MechanisticSaeAnalyzer {
  private config: SaeDiagnosticConfig;
  private saeProjection: SaeProjection;

  constructor(config: SaeDiagnosticConfig) {
    this.config = config;
    // Load the SAE pretrained on the base model; this is the fixed dictionary.
    this.saeProjection = SaeProjection.load(this.config.saePath);
  }

  async runDiagnostic(): Promise<FeatureDivergenceReport[]> {
    const reports: FeatureDivergenceReport[] = [];
    for (const layerIdx of this.config.targetLayers) {
      const baseTrace = await this.captureActivations(this.config.baseModelId, layerIdx);
      const sftTrace = await this.captureActivations(this.config.sftModelId, layerIdx);
      const baseLatents = this.saeProjection.encode(baseTrace.hiddenStates);
      const sftLatents = this.saeProjection.encode(sftTrace.hiddenStates);
      const report = this.computeDivergence(layerIdx, baseLatents, sftLatents);
      reports.push(report);
    }
    return reports;
  }

  private async captureActivations(
    modelId: string,
    layerIdx: number
  ): Promise<ActivationTrace> {
    const runner = new ModelRunner(modelId);
    // Hook the specific layer to extract hidden states during the batch run.
    runner.hookLayer(layerIdx);
    return runner.runBatch(this.config.dataset);
  }
  private computeDivergence(
    layerIdx: number,
    baseLatents: number[][],
    sftLatents: number[][]
  ): FeatureDivergenceReport {
    // Compute dense cosine similarity for the baseline comparison.
    const baseDense = this.reconstructDense(baseLatents);
    const sftDense = this.reconstructDense(sftLatents);
    const cosineSim = this.calculateCosineSimilarity(baseDense, sftDense);
    // Compute sparse metrics directly on the latent vectors.
    const featureOverlap = this.calculateFeatureOverlap(baseLatents, sftLatents);
    const l1Div = this.calculateL1Divergence(baseLatents, sftLatents);
    // Identify specific features that shifted significantly.
    const shiftedFeatures = this.identifyShiftedFeatures(baseLatents, sftLatents);
    return {
      layerIndex: layerIdx,
      cosineSimilarity: cosineSim,
      saeFeatureOverlap: featureOverlap,
      l1Divergence: l1Div,
      shiftedFeatures: shiftedFeatures
    };
  }

  private reconstructDense(latents: number[][]): number[][] {
    return latents.map(l => this.saeProjection.decode(l));
  }

  private calculateCosineSimilarity(a: number[][], b: number[][]): number {
    // Average cosine similarity between paired vectors across the batch.
    let total = 0;
    for (let i = 0; i < a.length; i++) {
      let dot = 0, normA = 0, normB = 0;
      for (let j = 0; j < a[i].length; j++) {
        dot += a[i][j] * b[i][j];
        normA += a[i][j] * a[i][j];
        normB += b[i][j] * b[i][j];
      }
      total += dot / (Math.sqrt(normA) * Math.sqrt(normB) || 1);
    }
    return total / a.length;
  }

  private topKIndices(v: number[], k: number): Set<number> {
    // Index set of the k largest activations in a latent vector.
    return new Set(
      v.map((val, idx): [number, number] => [val, idx])
        .sort((x, y) => y[0] - x[0])
        .slice(0, k)
        .map(([, idx]) => idx)
    );
  }

  private calculateFeatureOverlap(a: number[][], b: number[][], k = 128): number {
    // Jaccard index of top-k active features, averaged across the batch.
    let total = 0;
    for (let i = 0; i < a.length; i++) {
      const sa = this.topKIndices(a[i], k);
      const sb = this.topKIndices(b[i], k);
      const inter = [...sa].filter(idx => sb.has(idx)).length;
      total += inter / (sa.size + sb.size - inter || 1);
    }
    return total / a.length;
  }

  private calculateL1Divergence(a: number[][], b: number[][]): number {
    // Mean per-example sum of absolute differences in latent activations.
    let total = 0;
    for (let i = 0; i < a.length; i++) {
      for (let j = 0; j < a[i].length; j++) {
        total += Math.abs(a[i][j] - b[i][j]);
      }
    }
    return total / a.length;
  }

  private identifyShiftedFeatures(a: number[][], b: number[][]): string[] {
    // Feature IDs whose mean activation changed beyond the configured threshold.
    const shifted: string[] = [];
    const dims = a[0]?.length ?? 0;
    for (let j = 0; j < dims; j++) {
      let meanA = 0, meanB = 0;
      for (let i = 0; i < a.length; i++) {
        meanA += a[i][j];
        meanB += b[i][j];
      }
      if (Math.abs(meanA / a.length - meanB / b.length) > this.config.divergenceThreshold) {
        shifted.push(`feature_${j}`);
      }
    }
    return shifted;
  }
}
// Usage example
const config: SaeDiagnosticConfig = {
  baseModelId: 'llama-3-8b-base',
  sftModelId: 'llama-3-8b-instruct',
  // Note: in practice each analyzed layer needs its own SAE dictionary;
  // a single layer-12 path is shown here for brevity.
  saePath: '/models/sae/llama-3-8b-base-layer-12.sae',
  targetLayers: [12, 16, 20, 24],
  dataset: ['eval_prompts.jsonl'],
  divergenceThreshold: 0.5
};

const analyzer = new MechanisticSaeAnalyzer(config);
analyzer.runDiagnostic().then(reports => {
  console.log('Diagnostic complete. Analyzing layer profiles...');
  // Process reports to detect task-specific and safety alignment patterns.
});
```
#### Rationale for Design Choices
* **Layer-Specific Analysis:** The pipeline iterates over `targetLayers` to generate layer-wise reports. This is critical because SFT often affects intermediate layers differently than early or late layers. Safety alignment, for instance, may exhibit a distinct update profile concentrated in specific depth ranges. A filtering sketch after this list shows how to surface the layers that matter.
* **Feature Identification:** The `identifyShiftedFeatures` method extracts the names or IDs of features that changed significantly. This moves beyond aggregate metrics to pinpoint exactly which semantic concepts are being altered, enabling task-specific analysis.
* **Fixed SAE Usage:** The SAE is loaded once and reused for both base and SFT projections. This ensures that the latent space is consistent, making the divergence metrics meaningful. Any change in latents must be due to the model's activation shift, not dictionary adaptation.
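As referenced above, a short usage sketch can automate the layer-profile review: the signature of hidden drift is a layer where dense cosine similarity stays above the expected floor while sparse overlap falls below the expected ceiling. The default thresholds below are illustrative and mirror the configuration template later in this section.

```typescript
// Flag layers exhibiting the signature divergence pattern: dense geometry
// looks stable while the sparse feature composition has shifted.
// Threshold defaults are illustrative, not calibrated constants.
function flagHiddenShifts(
  reports: FeatureDivergenceReport[],
  cosineFloor = 0.95,
  overlapCeiling = 0.5
): FeatureDivergenceReport[] {
  return reports.filter(
    r => r.cosineSimilarity > cosineFloor && r.saeFeatureOverlap < overlapCeiling
  );
}

// Example: report which layers mask representational drift behind high cosine.
// flagHiddenShifts(reports).forEach(r =>
//   console.log(`Layer ${r.layerIndex}: cos=${r.cosineSimilarity.toFixed(2)}, ` +
//               `overlap=${r.saeFeatureOverlap.toFixed(2)}, ` +
//               `${r.shiftedFeatures.length} shifted features`));
```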
### Pitfall Guide
| Pitfall Name | Explanation | Fix |
| :--- | :--- | :--- |
| **The Retraining Trap** | Retraining the SAE on the fine-tuned model before analysis. This adapts the dictionary to the new representations, hiding the divergence you're trying to measure. | Always use an SAE pretrained on the base model. The SAE must serve as a fixed reference frame to detect shifts. |
| **Dense Space Illusion** | Relying solely on cosine similarity or Euclidean distance of hidden states. These metrics can remain high even when the underlying feature composition has changed completely. | Supplement dense metrics with sparse latent analysis. Compute feature overlap and L1 divergence on SAE projections. |
| **Layer Aggregation Error** | Averaging metrics across all layers, which masks layer-specific effects. SFT may alter features in intermediate layers while leaving early layers stable. | Analyze layer-wise profiles separately. Look for non-uniform update patterns and identify which layers drive the divergence. |
| **Feature Proliferation vs. Shift** | Confusing the emergence of new features with the shifting of existing ones. SFT may activate previously dormant features rather than modifying active ones. | Track feature activation frequencies. Distinguish between features that changed magnitude and features that appeared/disappeared. |
| **Safety Alignment Blindness** | Assuming safety fine-tuning behaves like general instruction tuning. Safety alignment often has a unique layer-wise update profile that requires specific detection. | Compare safety-aligned models against task-specific SFT models. Look for distinct divergence patterns in layers associated with refusal or policy enforcement. |
| **Computational Overhead** | Hooking every layer on large models during inference, causing memory bottlenecks and slow throughput. | Select a representative subset of layers based on preliminary analysis. Use gradient checkpointing or activation offloading if memory is constrained. |
| **Dictionary Sparsity Mismatch** | Using an SAE with a sparsity level that doesn't match the model's representation density, leading to poor feature resolution. | Validate SAE reconstruction loss on the base model (see the sketch after this table). Choose an SAE with appropriate sparsity to capture the relevant feature granularity. |
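For the sparsity-mismatch pitfall, the validation can be as simple as a round-trip reconstruction check. The sketch below assumes the same hypothetical `SaeProjection` API used in the implementation above (batch `encode`, per-vector `decode`); a high mean squared error signals that the dictionary resolves the base model's features poorly.

```typescript
// Round-trip check of SAE dictionary quality on base-model activations.
// High reconstruction error means the latents are unreliable for divergence analysis.
function reconstructionMse(sae: SaeProjection, hiddenStates: number[][]): number {
  const latents = sae.encode(hiddenStates); // batch encode through the dictionary
  let total = 0;
  let count = 0;
  for (let i = 0; i < hiddenStates.length; i++) {
    const recon = sae.decode(latents[i]);   // decode back to dense space
    for (let j = 0; j < hiddenStates[i].length; j++) {
      total += (hiddenStates[i][j] - recon[j]) ** 2;
      count++;
    }
  }
  return total / count;
}
```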
### Production Bundle
#### Action Checklist
- [ ] **Select Base SAE:** Identify and load a Sparse Autoencoder pretrained on the base model for the target architecture and layer depth.
- [ ] **Define Layer Targets:** Choose specific layers for analysis based on model depth and expected SFT impact zones; avoid hooking all layers initially.
- [ ] **Prepare Evaluation Dataset:** Curate a dataset that covers the task domain and safety scenarios to test representational shifts comprehensively.
- [ ] **Implement Hooking Strategy:** Set up activation hooks to capture hidden states at the selected layers for both base and SFT models.
- [ ] **Compute Divergence Metrics:** Calculate cosine similarity, SAE feature overlap, and L1 divergence for each layer to quantify shifts.
- [ ] **Analyze Layer Profiles:** Review layer-wise reports to identify non-uniform update patterns and task-specific feature distributions.
- [ ] **Identify Shifted Features:** Extract the list of features with significant activation changes to understand semantic alterations.
- [ ] **Validate Safety Patterns:** If analyzing safety alignment, compare the layer-wise profile against known safety update signatures.
#### Decision Matrix
| Scenario | Recommended Approach | Why | Cost Impact |
| :--- | :--- | :--- | :--- |
| **Quick SFT Validation** | Cosine Similarity + Sparse Overlap on 2-3 key layers | Fast assessment of major shifts without full mechanistic analysis (example configuration after this table). | Low compute cost; minimal engineering effort. |
| **Safety Alignment Audit** | Full layer-wise SAE analysis with safety-specific dataset | Safety fine-tuning has distinct layer profiles; requires detailed feature tracking. | Moderate compute cost; requires safety dataset curation. |
| **Domain Adaptation Debugging** | SAE analysis on intermediate layers with task-specific prompts | Task-specific features often concentrate in middle layers; helps identify feature recycling. | Moderate compute cost; high diagnostic value. |
| **Model Comparison** | Base SAE projection for all candidate models | Fixed dictionary ensures fair comparison across different fine-tuning runs. | Low incremental cost; requires consistent SAE loading. |
| **Feature Engineering** | Identify shifted features and retrain SAE on SFT model | If features have shifted permanently, a new SAE may be needed for future analysis. | High compute cost; SAE training is resource-intensive. |
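For the quick-validation row, the corresponding configuration is just a trimmed version of the usage example above (hypothetical model IDs and paths; the point is the reduced layer list and small prompt set):

```typescript
// Minimal configuration for a fast SFT sanity check: 2-3 layers, small sample.
const quickCheckConfig: SaeDiagnosticConfig = {
  baseModelId: 'llama-3-8b-base',
  sftModelId: 'llama-3-8b-instruct',
  saePath: '/models/sae/llama-3-8b-base-layer-12.sae',
  targetLayers: [12, 24],          // one intermediate and one late layer
  dataset: ['eval_prompts.jsonl'], // small, representative prompt set
  divergenceThreshold: 0.5
};
```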
#### Configuration Template
Use this YAML configuration to define the diagnostic pipeline parameters. Adjust layer indices and thresholds based on model architecture and analysis goals; a loader sketch follows the template.
```yaml
diagnostic_pipeline:
  models:
    base: "meta-llama/Llama-3-8B"
    sft: "meta-llama/Llama-3-8B-Instruct"
  sae:
    path: "/checkpoints/sae/llama-3-8b-base/layer_12.sae"
    k: 128  # Top-k features to consider for overlap
  layers:
    target_indices: [12, 16, 20, 24]
    hook_mode: "forward"
  dataset:
    path: "/data/eval/instruction_tuning_v2.jsonl"
    max_samples: 1000
  metrics:
    cosine_threshold: 0.95
    overlap_threshold: 0.50
    l1_divergence_threshold: 0.30
  output:
    format: "json"
    save_path: "/results/diagnostic_report.json"
```
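The loader sketch below maps the template onto the pipeline's config type. It assumes the `js-yaml` package; mapping `divergenceThreshold` to `l1_divergence_threshold` is an illustrative choice, since the TypeScript config exposes a single threshold while the template defines several.

```typescript
import * as fs from 'fs';
import * as yaml from 'js-yaml';

// Map the YAML template onto the pipeline's SaeDiagnosticConfig type.
function loadDiagnosticConfig(path: string): SaeDiagnosticConfig {
  const raw = yaml.load(fs.readFileSync(path, 'utf8')) as any;
  const p = raw.diagnostic_pipeline;
  return {
    baseModelId: p.models.base,
    sftModelId: p.models.sft,
    saePath: p.sae.path,
    targetLayers: p.layers.target_indices,
    dataset: [p.dataset.path],
    // Illustrative mapping: the template defines several thresholds,
    // but SaeDiagnosticConfig carries only one.
    divergenceThreshold: p.metrics.l1_divergence_threshold,
  };
}
```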
#### Quick Start Guide
1. **Install Dependencies:** Ensure your environment has the ML runtime, SAE library, and data processing tools installed.
   ```bash
   pip install sae-lens ml-runtime pandas numpy
   ```
2. **Load Base SAE:** Download or load the Sparse Autoencoder pretrained on the base model. Verify reconstruction loss to ensure dictionary quality.
   ```python
   sae = Sae.load_from_pretrained("base_model_sae_layer_12")
   ```
3. **Run Diagnostic Script:** Execute the analysis pipeline with the configuration template. This captures activations, projects them through the SAE, and computes divergence metrics.
   ```bash
   python run_sae_diagnostic.py --config config.yaml
   ```
4. **Review Layer Profiles:** Examine the output report for layer-wise divergence. Look for layers where cosine similarity is high but SAE overlap is low, indicating hidden representational shifts.
5. **Iterate on Layers:** If initial results show significant divergence in specific layers, refine the target layer list and rerun the analysis with a larger dataset for higher confidence.
