g redundant computation in the attention mechanism and ensuring the adapter does not interfere with the GQA grouping. Crucially, context window utilization remains near the architectural limit, so the model can process long documents or codebases with fewer of the hallucination and truncation errors common in generic fine-tunes.
Core Solution
Implementing Mistral fine-tuning requires a workflow that respects the model's architectural constraints. The recommended stack uses Axolotl for configuration management and training orchestration, as it provides native support for SWA-aware packing and GQA tensor handling.
Step-by-Step Implementation
1. Environment Setup
Ensure dependencies support Flash Attention 2 and PEFT.
pip install "axolotl[deepspeed]" transformers peft bitsandbytes accelerate
2. Dataset Preparation
Mistral models expect a specific instruction format. The chat template uses [INST] tags. Deviating from this format breaks instruction following.
Format:
<s>[INST] {instruction} [/INST] {response}</s>
Prepare your dataset as a JSONL file where each line is a dictionary with instruction and output keys. Axolotl will apply the template automatically if configured correctly.
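As a minimal sketch of the dataset layout described above, the following writes a JSONL file and checks that every line is a JSON object with non-empty instruction and output strings. The helper names write_jsonl and validate_jsonl are hypothetical and chosen for illustration; they are not part of Axolotl.

```python
import json
import os
import tempfile

REQUIRED_KEYS = ("instruction", "output")  # keys named in the text above


def write_jsonl(path, records):
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def validate_jsonl(path):
    """Return a list of (line_number, problem) tuples for malformed records."""
    problems = []
    with open(path, encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                problems.append((line_number, "empty line"))
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                problems.append((line_number, "invalid JSON"))
                continue
            if not isinstance(record, dict):
                problems.append((line_number, "not a JSON object"))
                continue
            for key in REQUIRED_KEYS:
                value = record.get(key)
                if not isinstance(value, str) or not value.strip():
                    problems.append((line_number, f"missing or empty '{key}'"))
    return problems


# Small demonstration in a temporary directory (hypothetical file name).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "train.jsonl")
    write_jsonl(demo_path, [{"instruction": "Explain quantum computing.",
                             "output": "Quantum computing uses qubits..."}])
    print(validate_jsonl(demo_path))  # an empty list means no problems were found
```

Running a check like this before training catches malformed lines early, which is cheaper than discovering them partway through tokenization.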
3. Configuration (Axolotl YAML)
The configuration must explicitly enable Flash Attention 2, set SWA-aware packing, and define LoRA targets that include GQA projections.
config.yaml:
base_model: mistralai/Mistral-7B-v0.3
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_4bit: true
strict: false
datasets:
  - path: ./data/train.jsonl
    type: completion
    field: text
    format: "{instruction} {output}"
    train_on_split: train
dataset_prepared_path: last_run_prepared_ds
output_dir: ./mistral-finetuned
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
wandb_project: mistral-finetune
wandb_entity:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
flash_attention: true
logging_steps: 10
save_steps: 100
eval_steps: 100
save_total_limit: 3
Key Configuration Rationale:
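sequence_len: 4096: Matches the 4096-token sliding window of the original Mistral-7B release. Confirm the value against sliding_window in the checkpoint's config.json, since not every Mistral release sets it. With packing, each packed sample needs attention masking so that tokens do not attend across sample boundaries.
sample_packing: true: Concatenates short examples into full-length sequences to reduce padding waste. Keep sequence_len aligned to the SWA window when using it.
lora_r: 32: Mistral's hidden size is 4096. A rank of 32 gives the adapter enough capacity for most instruction datasets while keeping the risk of overfitting low.
learning_rate: 2e-5: Lower than common Llama LoRA recipes, because Mistral fine-tunes tend to become unstable at higher rates.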
lora_target_modules: Includes all linear projections. Mistral's GQA means k_proj and v_proj have shapes (num_key_value_heads * head_dim, hidden_size), which PEFT handles automatically when specified.
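To make the k_proj and v_proj shapes above concrete, here is a small sketch that derives the projection shapes and the matching LoRA matrix shapes. The hidden size of 4096 comes from the text; the head counts (32 query heads, 8 key/value heads) are taken from the published Mistral-7B configuration and are an assumption here, as is the AttentionDims class itself.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AttentionDims:
    """Attention dimensions; defaults are assumed Mistral-7B values."""
    hidden_size: int = 4096
    num_attention_heads: int = 32   # assumption, not stated in the guide
    num_key_value_heads: int = 8    # assumption, not stated in the guide

    @property
    def head_dim(self):
        return self.hidden_size // self.num_attention_heads

    def projection_shapes(self):
        """Weight shapes as (out_features, in_features)."""
        kv_out = self.num_key_value_heads * self.head_dim
        return {
            "q_proj": (self.hidden_size, self.hidden_size),
            "k_proj": (kv_out, self.hidden_size),
            "v_proj": (kv_out, self.hidden_size),
            "o_proj": (self.hidden_size, self.hidden_size),
        }

    def lora_shapes(self, module, rank):
        """LoRA factors: A maps in_features to rank, B maps rank to out_features."""
        out_features, in_features = self.projection_shapes()[module]
        return {"lora_A": (rank, in_features), "lora_B": (out_features, rank)}


dims = AttentionDims()
print(dims.projection_shapes()["k_proj"])
print(dims.lora_shapes("k_proj", rank=32))
```

Under these assumptions the key and value projections are four times narrower than the query projection, so their LoRA B matrices are correspondingly smaller.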
4. Training Execution
Run the training command:
accelerate launch -m axolotl.cli.train config.yaml
5. Merging and Export
After training, merge the adapter and export to GGUF or AWQ for deployment.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model in half precision (not 4-bit) so the adapter merges into full weights.
base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.3",
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.3")

# Attach the trained LoRA adapter, then fold its weights into the base model.
model = PeftModel.from_pretrained(base_model, "./mistral-finetuned/checkpoint-final")
model = model.merge_and_unload()

# Save the merged model and tokenizer together for downstream export.
model.save_pretrained("./merged-mistral")
tokenizer.save_pretrained("./merged-mistral")
Pitfall Guide
1. Sliding Window Boundary Violation
Mistake: Packing sequences across the 4096-token boundary without masking.
Impact: Attention scores leak between unrelated sequences, causing loss spikes and corrupted gradients.
Fix: Use sequence_len: 4096 and enable SWA-aware packing in your trainer. Ensure no single packed sequence exceeds the window size unless the trainer implements sliding window masking.
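The masking described in this fix can be illustrated with a small pure-Python sketch. It builds a boolean attention mask for several samples packed into one sequence: a token may attend only to earlier tokens in the same sample and within the window. The function name is hypothetical, and the window convention (a token sees itself plus the previous window - 1 tokens) is an assumption; implementations differ by one position.

```python
def packed_sliding_window_mask(segment_lengths, window):
    """Return mask[i][j] == True when query position i may attend to key position j.

    segment_lengths: lengths of the samples packed back to back.
    window: sliding window size (assumed to include the current token).
    """
    # Tag every position with the index of the sample it belongs to.
    segment_ids = []
    for segment, length in enumerate(segment_lengths):
        segment_ids.extend([segment] * length)

    n = len(segment_ids)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):  # causal: only positions j <= i
            same_sample = segment_ids[i] == segment_ids[j]
            inside_window = (i - j) < window
            mask[i][j] = same_sample and inside_window
    return mask


# Two samples of length 3 and 2 packed together, window of 2 tokens.
for row in packed_sliding_window_mask([3, 2], window=2):
    print("".join("x" if allowed else "." for allowed in row))
```

The printed grid is block-diagonal: no position in the second sample attends to the first, which is the property that prevents the cross-sample leakage described above. A real trainer would build the equivalent mask as a tensor, or pass sequence boundaries to the attention kernel.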
2. Chat Template Mismatch
Mistake: Using Alpaca or generic instruction templates.
Impact: The model fails to recognize instruction boundaries, resulting in verbose or non-compliant outputs.
Fix: Strictly adhere to <s>[INST] ... [/INST] ... </s> format. Verify the tokenizer's apply_chat_template output matches this structure.
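A quick way to catch template drift is to render a sample and check it against the expected pattern. The sketch below does this with plain string formatting and a regular expression; the function names are hypothetical, and it checks only the single-turn format shown in this guide.

```python
import re

BOS, EOS = "<s>", "</s>"

# Single-turn pattern from the guide: <s>[INST] {instruction} [/INST] {response}</s>
_TEMPLATE = re.compile(
    r"^<s>\[INST\] (?P<instruction>.+?) \[/INST\] (?P<response>.+)</s>$",
    re.DOTALL,
)


def render_mistral_example(instruction, response):
    """Render one training example in the single-turn Mistral format."""
    return f"{BOS}[INST] {instruction} [/INST] {response}{EOS}"


def matches_mistral_format(text):
    """Return True when text follows the single-turn pattern above."""
    return _TEMPLATE.match(text) is not None


sample = render_mistral_example("Explain quantum computing.",
                                "Quantum computing uses qubits...")
print(sample)
print(matches_mistral_format(sample))
print(matches_mistral_format("### Instruction:\nExplain.\n### Response:\nSure."))
```

Whether the literal <s> belongs in the text depends on whether the tokenizer already prepends a BOS token; comparing against the tokenizer's own apply_chat_template output, as suggested above, remains the authoritative check.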
3. Learning Rate Too High
Mistake: Using 1e-4 or higher learning rates common for other models.
Impact: Instability during early training steps; loss divergence.
Fix: Mistral requires lower learning rates. Start with 2e-5 to 5e-5 for LoRA. Use cosine scheduling with warmup.
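For intuition about what cosine scheduling with warmup does to the learning rate, the following sketch computes the rate at a given step. The peak of 2e-5 comes from the text; warmup_steps=100 and the linear warmup shape are assumptions for illustration, and this is not the exact schedule of any particular trainer.

```python
import math


def lr_at_step(step, total_steps, peak_lr=2e-5, warmup_steps=100, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        # Ramp up linearly so the first updates are small.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, progress)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (peak_lr - min_lr) * cosine


for step in (0, 50, 100, 550, 1000):
    print(step, f"{lr_at_step(step, total_steps=1000):.2e}")
```

The warmup phase keeps the earliest updates small, which is where the instability described above tends to appear.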
4. Context Window Collapse
Mistake: Fine-tuning exclusively on short sequences (e.g., <2k tokens).
Impact: The model loses the ability to attend to tokens beyond the fine-tuning distribution length, effectively reducing the context window.
Fix: Include a mix of sequence lengths in the dataset, up to the maximum context window. Use group_by_length to batch similar lengths efficiently.
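One way to check the length mix before training is to bucket examples by token count. The sketch below assumes the token counts have already been computed with the model's tokenizer; the bucket edges and the 2048-token threshold are illustrative choices, and the function names are hypothetical.

```python
from collections import Counter


def length_histogram(token_counts, edges=(512, 1024, 2048, 4096)):
    """Count examples per length bucket; the last bucket is open-ended."""
    histogram = Counter()
    for count in token_counts:
        label = next((f"<={edge}" for edge in edges if count <= edge),
                     f">{edges[-1]}")
        histogram[label] += 1
    return dict(histogram)


def short_fraction(token_counts, threshold=2048):
    """Fraction of examples shorter than threshold tokens."""
    if not token_counts:
        return 0.0
    return sum(1 for count in token_counts if count < threshold) / len(token_counts)


counts = [100, 600, 3000, 5000]  # made-up token counts for demonstration
print(length_histogram(counts))
print(short_fraction(counts))
```

A short_fraction close to 1.0 would indicate the short-sequence-only distribution this pitfall warns about.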
5. GQA Quantization Errors
Mistake: Quantizing the model without preserving GQA structure.
Impact: Increased perplexity and degraded generation quality due to misaligned key/value heads.
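Fix: When quantizing to GGUF or AWQ, make sure the conversion tool reads the grouped-query attention settings (num_attention_heads and num_key_value_heads) from the model's config.json instead of assuming one key/value head per query head. Use a llama.cpp or autoawq release that supports Mistral's GQA layout.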
6. LoRA Rank Mismatch
Mistake: Using excessively high ranks (e.g., 128+) on 7B models.
Impact: Overfitting on small datasets; increased inference latency without accuracy gains.
Fix: For Mistral-7B, ranks between 16 and 64 are optimal. Use rank 32 as a baseline and adjust based on dataset size.
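To see how rank drives adapter size, the sketch below counts trainable LoRA parameters for the seven target modules used in this guide. The layer dimensions (32 layers, intermediate size 14336, key/value width 1024) are assumed from the published Mistral-7B configuration and are not stated in the text.

```python
HIDDEN = 4096          # hidden size, as stated in the guide
KV_OUT = 1024          # assumed: 8 key/value heads * 128 head_dim
INTERMEDIATE = 14336   # assumed MLP intermediate size
NUM_LAYERS = 32        # assumed number of decoder layers

# (out_features, in_features) for each targeted linear layer in one decoder layer.
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (KV_OUT, HIDDEN),
    "v_proj": (KV_OUT, HIDDEN),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (INTERMEDIATE, HIDDEN),
    "up_proj": (INTERMEDIATE, HIDDEN),
    "down_proj": (HIDDEN, INTERMEDIATE),
}


def lora_trainable_params(rank, shapes=TARGET_SHAPES, num_layers=NUM_LAYERS):
    """Each adapted layer adds rank * (in_features + out_features) parameters."""
    per_layer = sum(rank * (out_f + in_f) for out_f, in_f in shapes.values())
    return per_layer * num_layers


for rank in (16, 32, 64, 128):
    print(rank, f"{lorA if False else lora_trainable_params(rank):,}")
```

Parameter count grows linearly with rank, so rank 128 trains four times as many adapter weights as rank 32 under these assumptions.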
7. Training on Input Tokens
Mistake: Setting train_on_inputs: true for instruction tuning.
Impact: The model learns to predict the instruction tokens, wasting capacity and degrading instruction following.
Fix: Set train_on_inputs: false so the model only learns to predict the response tokens.
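The effect of this setting can be sketched as label masking: prompt positions get the label -100, which PyTorch's cross-entropy loss ignores by default, so only response tokens contribute to the loss. The function below is an illustrative stand-in, not Axolotl's implementation, and the token IDs are made up.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch cross-entropy by default


def mask_prompt_labels(input_ids, prompt_length, train_on_inputs=False):
    """Return labels with the first prompt_length positions masked out.

    When train_on_inputs is True, every token keeps its label.
    """
    labels = list(input_ids)
    if not train_on_inputs:
        for i in range(min(prompt_length, len(labels))):
            labels[i] = IGNORE_INDEX
    return labels


token_ids = [1, 501, 502, 503, 900, 2]  # hypothetical: 4 prompt tokens, 2 response tokens
print(mask_prompt_labels(token_ids, prompt_length=4))
print(mask_prompt_labels(token_ids, prompt_length=4, train_on_inputs=True))
```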
Production Bundle
Action Checklist
Decision Matrix
| Scenario | Recommended Approach | Why | Cost Impact |
|---|---|---|---|
| Low Budget / Single GPU | QLoRA (4-bit) with Rank 32 | Reduces VRAM by ~70% while maintaining accuracy; enables fine-tuning on consumer hardware. | Low |
| High Quality / Full Retraining | Full Fine-tuning BF16 | Best convergence for large datasets; avoids adapter merge artifacts; maximizes capability transfer. | High |
| Long Context Critical | SFT with SWA-Aware Packing | Preserves 32k context window; prevents context collapse; essential for RAG or document processing. | Medium |
| Latency Sensitive Inference | Merge + AWQ Quantization | AWQ provides better latency/perplexity trade-off than GGUF for Mistral; reduces memory bandwidth pressure. | Medium |
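As a rough back-of-the-envelope check on the memory claims in the matrix, the sketch below estimates weight-only memory at different precisions. The parameter count of about 7.25 billion is an approximation, and the estimate deliberately ignores activations, optimizer state, LoRA weights, and quantization metadata, so real VRAM use will be higher.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Weight-only memory in gigabytes (10**9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


def reduction(from_bits, to_bits):
    """Fractional weight-memory saving when moving between precisions."""
    return 1.0 - to_bits / from_bits


NUM_PARAMS = 7.25e9  # approximate parameter count, an assumption

print(f"16-bit weights: {weight_memory_gb(NUM_PARAMS, 16):.2f} GB")
print(f"4-bit weights:  {weight_memory_gb(NUM_PARAMS, 4):.2f} GB")
print(f"weight-only reduction: {reduction(16, 4):.0%}")
```

The weight-only saving is an upper bound; overheads that do not shrink with quantization pull the end-to-end figure down toward the approximate value quoted in the table.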
Configuration Template
axolotl_mistral_production.yaml:
base_model: mistralai/Mistral-7B-v0.3
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
load_in_4bit: true
strict: false
datasets:
  - path: ./data/train.jsonl
    type: completion
    field: text
    format: "{instruction} {output}"
output_dir: ./output/mistral-prod
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 3
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
bf16: true
fp16: false
tf32: true
gradient_checkpointing: true
flash_attention: true
logging_steps: 10
save_steps: 200
eval_steps: 200
save_total_limit: 2
deepspeed: null
Quick Start Guide
- Install Dependencies:
  pip install "axolotl[deepspeed]" transformers peft bitsandbytes accelerate
- Prepare Data:
  Create train.jsonl with one JSON object per line, using instruction and output keys:
  {"instruction": "Explain quantum computing.", "output": "Quantum computing uses qubits..."}
- Create Config:
  Save the axolotl_mistral_production.yaml template above.
- Run Training:
  accelerate launch -m axolotl.cli.train axolotl_mistral_production.yaml
- Validate:
  Load the merged model and test with a prompt wrapped in [INST] tags. Verify response format and latency metrics.
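The validation step can be sketched as a small timing harness. Here generate_fn is a hypothetical stand-in for whatever callable wraps the merged model's generation (it takes a prompt string and returns text); a stub is used below so the example runs without loading a model.

```python
import time


def wrap_prompt(user_message):
    """Wrap a user message in [INST] tags for inference."""
    return f"[INST] {user_message} [/INST]"


def time_generation(generate_fn, prompt, runs=3):
    """Call generate_fn(prompt) several times; return the last output and mean latency in seconds."""
    latencies = []
    output = None
    for _ in range(runs):
        start = time.perf_counter()
        output = generate_fn(prompt)
        latencies.append(time.perf_counter() - start)
    return output, sum(latencies) / len(latencies)


def stub_generate(prompt):
    """Placeholder for a real model call; echoes the prompt with a fixed reply."""
    return prompt + " Quantum computing uses qubits..."


text, mean_latency = time_generation(stub_generate, wrap_prompt("Explain quantum computing."))
print(text)
print(f"mean latency: {mean_latency:.6f} s")
```

With a real model, the first call often includes warmup cost, so discarding it before averaging may give a more representative latency.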