uences and separates sentence pairs. token_type_ids distinguish between the first and second sentence in pair inputs, while attention_mask prevents the model from attending to padding tokens.
import torch
from transformers import AutoTokenizer, BatchEncoding


class ContextTokenizer:
    """Tokenizes single texts or sentence pairs into fixed-length PyTorch tensors."""

    def __init__(self, model_name: str = "bert-base-uncased", max_seq_len: int = 128):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.max_seq_len = max_seq_len

    def encode_batch(self, texts: list, is_pair: bool = False) -> BatchEncoding:
        # Single inputs: list[str]. Pair inputs: list of tuples [(sent_a, sent_b), ...]
        if is_pair:
            first = [pair[0] for pair in texts]
            second = [pair[1] for pair in texts]
        else:
            first, second = texts, None
        # Pad or truncate every sequence to max_seq_len and return PyTorch tensors
        return self.tokenizer(
            first,
            second,
            truncation=True,
            padding="max_length",
            max_length=self.max_seq_len,
            return_tensors="pt",
        )
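To make the roles of token_type_ids and attention_mask concrete, here is a toy sketch of how a padded sentence pair is laid out. It is not the tokenizer's actual implementation: build_pair_layout is a hypothetical helper, and it works on word strings rather than vocabulary IDs.

```python
# Toy illustration only: a real tokenizer maps word pieces to vocabulary IDs.
def build_pair_layout(tokens_a, tokens_b, max_len):
    """Lay out [CLS] A [SEP] B [SEP] and pad to max_len (hypothetical helper)."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A, and the first [SEP]; segment 1 covers the rest
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    # 1 marks real tokens, 0 marks padding the model should not attend to
    attention_mask = [1] * len(tokens)
    pad = max_len - len(tokens)
    if pad < 0:
        raise ValueError("pair is longer than max_len; truncate first")
    tokens += ["[PAD]"] * pad
    token_type_ids += [0] * pad
    attention_mask += [0] * pad
    return tokens, token_type_ids, attention_mask


tokens, token_type_ids, attention_mask = build_pair_layout(
    ["how", "are", "you"], ["fine", "thanks"], max_len=10
)
```

Here the first five positions belong to segment 0, the next three to segment 1, and the last two are padding with attention_mask set to 0.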
Step 2: Architecture and Head Attachment
Instead of relying on high-level wrappers, attaching custom heads to the base encoder provides explicit control over parameter freezing and gradient flow. For classification, we extract the [CLS] embedding and pass it through a dropout layer and linear projection. For token-level tasks, we project every token's hidden state independently.
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModel


class SequenceClassifier(nn.Module):
    def __init__(self, backbone_name: str, num_classes: int, dropout_rate: float = 0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone_name)
        self.dropout = nn.Dropout(dropout_rate)
        self.head = nn.Linear(self.encoder.config.hidden_size, num_classes)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        # Segment IDs matter for sentence pairs; forward them only when provided,
        # since some backbones (e.g. DistilBERT) do not accept this argument
        if token_type_ids is not None:
            inputs["token_type_ids"] = token_type_ids
        encoder_outputs = self.encoder(**inputs)
        # Extract [CLS] token representation (first token in sequence)
        cls_embedding = encoder_outputs.last_hidden_state[:, 0, :]
        cls_embedding = self.dropout(cls_embedding)
        return self.head(cls_embedding)  # shape: (batch, num_classes)


class TokenTagger(nn.Module):
    def __init__(self, backbone_name: str, num_labels: int):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(backbone_name)
        self.tag_head = nn.Linear(self.encoder.config.hidden_size, num_labels)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Project every token position independently
        token_embeddings = encoder_outputs.last_hidden_state
        return self.tag_head(token_embeddings)  # shape: (batch, seq_len, num_labels)
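Parameter freezing, named earlier as one benefit of attaching custom heads, comes down to toggling requires_grad on the encoder's parameters. The sketch below is illustrative only: TinyClassifier is a hypothetical stand-in whose "encoder" is a single linear layer so the example runs without downloading weights, and freeze_module and count_trainable are helper names invented here.

```python
import torch.nn as nn


def freeze_module(module: nn.Module) -> None:
    """Disable gradient computation for every parameter in `module`."""
    for param in module.parameters():
        param.requires_grad = False


def count_trainable(model: nn.Module) -> int:
    """Number of scalar parameters that will still receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class TinyClassifier(nn.Module):
    """Stand-in with the same encoder/head split as a BERT classifier (assumption)."""

    def __init__(self, hidden_size: int = 768, num_classes: int = 2):
        super().__init__()
        self.encoder = nn.Linear(hidden_size, hidden_size)  # placeholder for the real encoder
        self.head = nn.Linear(hidden_size, num_classes)


model = TinyClassifier()
freeze_module(model.encoder)
trainable = count_trainable(model)  # only the head: 768 * 2 weights + 2 biases
```

With a real backbone the same two helpers apply unchanged to model.encoder. When building the optimizer, it is reasonable to pass only the parameters whose requires_grad is still True.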
Step 3: Optimization and Training Dynamics
Transformer fine-tuning typically pairs AdamW (Adam with decoupled weight decay) with a linear learning rate schedule that warms up and then decays. The original BERT paper recommends fine-tuning rates between 2e-5 and 5e-5; much larger rates risk catastrophic forgetting of the pretrained weights, while this range preserves semantic knowledge and still allows task-specific adaptation. Gradient clipping guards against exploding gradients during backpropagation through the deep attention stack.
import math

import torch
import torch.nn as nn
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup


class TrainingEngine:
    def __init__(self, model: nn.Module, num_epochs: int, batch_size: int, dataset_size: int):
        self.model = model
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        # Round up so the final partial batch of each epoch counts as a step
        # (assumes the data loader does not drop the last batch)
        self.total_steps = math.ceil(dataset_size / batch_size) * num_epochs
        self.optimizer = AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=int(0.1 * self.total_steps),
            num_training_steps=self.total_steps,
        )
        # Labels equal to -100 (the default ignore_index) are excluded from the loss
        self.loss_fn = nn.CrossEntropyLoss()

    def train_step(self, batch: dict, device: torch.device) -> float:
        self.model.train()
        self.optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        logits = self.model(input_ids, attention_mask)
        # Flatten so one loss handles (batch, classes) and (batch, seq_len, labels) logits
        loss = self.loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.scheduler.step()
        return loss.item()
Architecture Rationale:
- Why bert-base-uncased? Lowercasing reduces vocabulary size and improves generalization across casing variations. The 12-layer depth provides sufficient capacity for most classification tasks without incurring the memory overhead of bert-large.
- Why linear scheduler with warmup? Pretrained weights are already optimized. Warmup prevents early gradient spikes from destroying learned representations. Linear decay ensures stable convergence as the model adapts to the target distribution.
- Why gradient clipping? Attention mechanisms can produce large activation values during fine-tuning. Clipping at 1.0 maintains training stability without distorting the optimization landscape.
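The warmup-then-decay shape can be written as a multiplier applied to the base learning rate. The function below is a plain-Python sketch of that multiplier, which should correspond to the shape of a linear schedule with warmup; lr_multiplier is a hypothetical name and the step counts are illustrative, not values from a real run.

```python
def lr_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Fraction of the base learning rate to use at `step` (illustrative sketch)."""
    if step < warmup_steps:
        # Linear ramp from 0 up to 1 over the warmup phase
        return step / max(1, warmup_steps)
    # Linear decay from 1 down to 0 over the remaining steps
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 3e-5
# Effective learning rate at a few points of a 1000-step run with 100 warmup steps
schedule = {s: base_lr * lr_multiplier(s, 100, 1000) for s in (0, 50, 100, 550, 1000)}
```

The rate starts at zero, reaches the full 3e-5 at the end of warmup, and falls back to zero at the final step.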
Pitfall Guide
| Pitfall | Explanation | Fix |
|---|---|---|
| Ignoring attention_mask during loss computation | In token-level tasks, padding positions carry no label. If the loss averages over every sequence position, gradients from meaningless tokens dilute the signal. | Mask the per-token loss with attention_mask before reduction, or set the labels at padding positions to the ignore_index of CrossEntropyLoss (-100 by default). |
| Using [CLS] embeddings for token-level tasks | The [CLS] token aggregates global sentence semantics. Applying it to NER or POS tagging forces every token to share the same representation, destroying positional specificity. | Use last_hidden_state (shape: batch, seq_len, hidden) and project each position independently. Apply sequence masking to ignore padding during loss calculation. |
| Setting max_length without truncation strategy | Sequences exceeding 512 tokens cause silent truncation or runtime errors. Default padding behavior may misalign labels with tokens. | Explicitly set truncation=True and padding="max_length". Validate dataset length distribution beforehand and adjust max_length to cover 95% of samples. |
| Freezing the backbone prematurely | Freezing all encoder parameters leaves only the head trainable (~1.5K weights for a two-class head on bert-base), but prevents the model from adapting contextual representations to domain-specific syntax. | Freeze only when dataset size is <1K samples or compute is severely constrained. Otherwise, fine-tune all layers with a low learning rate (2e-5 to 3e-5). |
| Misinterpreting token_type_ids for single inputs | For single-sentence inputs token_type_ids are all zeros, and Hugging Face BERT models fall back to zeros when the argument is omitted. For sentence pairs, dropping them erases the distinction between the two segments. | Forward token_type_ids to the model whenever the tokenizer returns them. BERT tokenizers include them by default, independent of return_tensors. |
| Using standard SGD or Adam without weight decay | Transformers are highly prone to overfitting on small datasets. Standard optimizers lack the regularization needed to preserve pretraining knowledge. | Use AdamW with weight_decay=0.01. This decouples weight decay from gradient updates, matching the original BERT fine-tuning protocol. |
| Overlooking the 512-token hard limit | BERT's positional embeddings are trained only up to 512 positions. Feeding longer sequences typically fails with a runtime error in the embedding layer instead of degrading gracefully. | Implement chunking strategies for long documents. Process overlapping windows and aggregate predictions using voting or averaging. |
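The first pitfall is easiest to see in code. The sketch below shows one way to exclude padding from a token-level loss by overwriting the labels at padded positions with the ignore index; masked_token_loss is a hypothetical helper, and the tensors are random toy data rather than real model output.

```python
import torch
import torch.nn as nn

IGNORE_INDEX = -100  # default ignore_index of nn.CrossEntropyLoss


def masked_token_loss(logits, labels, attention_mask):
    """Cross-entropy averaged over real tokens only; padding positions are skipped."""
    # Overwrite labels at padding positions so the loss ignores them
    labels = labels.masked_fill(attention_mask == 0, IGNORE_INDEX)
    loss_fn = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    # Flatten (batch, seq_len, num_labels) -> (batch * seq_len, num_labels)
    return loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))


torch.manual_seed(0)
logits = torch.randn(2, 4, 3)  # (batch=2, seq_len=4, num_labels=3)
labels = torch.tensor([[0, 2, 1, 0], [1, 1, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
loss = masked_token_loss(logits, labels, attention_mask)
```

Because padded positions are ignored, changing the logits at those positions leaves the loss unchanged, which makes for a cheap sanity check in a test suite.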
Production Bundle
Action Checklist
Decision Matrix
| Scenario | Recommended Approach | Why | Cost Impact |
|---|---|---|---|
| < 2K labeled samples, strict latency SLA | Freeze encoder, train classification head only | Prevents overfitting on small data; reduces compute by ~98% | Low (CPU/T4 sufficient) |
| 10K+ domain-specific samples, high accuracy required | Full fine-tuning with bert-base-uncased | Allows contextual adaptation to domain syntax and terminology | Medium (Single GPU recommended) |
| Long documents (>512 tokens), extraction task | Sliding window chunking + token-level head | BERT positional embeddings cap at 512; chunking preserves context boundaries | Medium-High (Increased inference calls) |
| Real-time API with <50ms p99 latency | distilbert-base-uncased + ONNX runtime | 60% faster inference with 97% accuracy retention; optimized for serving | Low (High throughput, low latency) |
| Multi-class classification with class imbalance | Weighted CrossEntropyLoss or a focal loss variant | Prevents majority class dominance; improves recall on rare categories | Low (No infrastructure change) |
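For the long-document row, the chunking step might look like the sketch below. It is a minimal plain-Python version under stated assumptions: sliding_windows is a hypothetical helper, window=510 is chosen to leave room for [CLS] and [SEP] within 512 positions, and the stride value is arbitrary.

```python
def sliding_windows(token_ids, window=510, stride=128):
    """Split token_ids into overlapping windows (hypothetical helper).

    `stride` is the number of tokens shared by consecutive windows.
    """
    if not 0 <= stride < window:
        raise ValueError("stride must be in [0, window)")
    step = window - stride  # how far the window start advances each time
    chunks = []
    start = 0
    while True:
        chunks.append(token_ids[start:start + window])
        if start + window >= len(token_ids):
            break  # this window already reaches the end of the document
        start += step
    return chunks


chunks = sliding_windows(list(range(10)), window=4, stride=2)
```

Each chunk is then wrapped with special tokens and scored separately, and predictions in the overlapping regions are merged by voting or averaging. Hugging Face tokenizers offer comparable behaviour through the return_overflowing_tokens and stride arguments.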
Configuration Template
# config.py
from dataclasses import dataclass


@dataclass
class BertPipelineConfig:
    # Model selection
    model_name: str = "bert-base-uncased"
    task_type: str = "sequence_classification"  # or "token_classification"
    num_labels: int = 2

    # Tokenization
    max_sequence_length: int = 128
    truncation_strategy: str = "longest_first"

    # Training hyperparameters
    learning_rate: float = 3e-5
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    batch_size: int = 32
    num_epochs: int = 3
    gradient_clip_norm: float = 1.0
    dropout_rate: float = 0.1

    # Optimization
    use_mixed_precision: bool = True
    freeze_backbone: bool = False
    device: str = "auto"  # auto, cuda, cpu

    # Validation
    eval_strategy: str = "epoch"
    save_best_only: bool = True
    metric_for_best_model: str = "f1"
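The schedule needs concrete step counts, which follow from batch_size, num_epochs, and warmup_ratio. The sketch below shows that arithmetic; ScheduleSettings is a hypothetical trimmed-down copy of the relevant template fields, derive_steps is an invented helper, and the dataset size of 10,000 is only an example.

```python
import math
from dataclasses import dataclass


@dataclass
class ScheduleSettings:
    """Subset of the config fields that determine the schedule (assumption)."""
    batch_size: int = 32
    num_epochs: int = 3
    warmup_ratio: float = 0.1


def derive_steps(cfg: ScheduleSettings, dataset_size: int):
    """Return (total_steps, warmup_steps) for a run over `dataset_size` samples."""
    # Round up: assumes the last partial batch of each epoch is still used
    steps_per_epoch = math.ceil(dataset_size / cfg.batch_size)
    total_steps = steps_per_epoch * cfg.num_epochs
    warmup_steps = int(cfg.warmup_ratio * total_steps)
    return total_steps, warmup_steps


total_steps, warmup_steps = derive_steps(ScheduleSettings(), dataset_size=10_000)
```

With the template defaults, 10,000 samples give 313 steps per epoch, 939 steps in total, and 93 warmup steps.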
Quick Start Guide
- Install dependencies: pip install transformers torch datasets scikit-learn
- Initialize tokenizer and model: Load bert-base-uncased, configure max_length=128, and attach a linear head matching your label count.
- Prepare dataset: Tokenize inputs with truncation=True and padding="max_length". Map labels to integers. Split into train/validation sets.
- Configure optimizer: Instantiate AdamW with lr=3e-5, weight_decay=0.01. Attach a linear scheduler with 10% warmup steps.
- Run training loop: Iterate over batches, compute masked loss, apply gradient clipping, step optimizer and scheduler. Validate after each epoch and checkpoint the best weights.
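The "checkpoint the best weights" step only needs a small amount of bookkeeping. The class below is one possible sketch; BestCheckpointTracker is a hypothetical helper, and the F1 values are made-up numbers used purely to show the control flow.

```python
class BestCheckpointTracker:
    """Remembers the best validation score seen so far (hypothetical helper)."""

    def __init__(self, higher_is_better: bool = True):
        self.higher_is_better = higher_is_better
        self.best_score = None
        self.best_epoch = None

    def update(self, epoch: int, score: float) -> bool:
        """Return True when `score` beats the best so far, i.e. when to save weights."""
        if self.best_score is None:
            improved = True
        elif self.higher_is_better:
            improved = score > self.best_score
        else:
            improved = score < self.best_score
        if improved:
            self.best_score, self.best_epoch = score, epoch
        return improved


tracker = BestCheckpointTracker()
# Illustrative per-epoch validation F1 values, not real results
decisions = [tracker.update(epoch, f1) for epoch, f1 in enumerate([0.71, 0.78, 0.76])]
```

In a real loop, a True return value would be followed by saving the model's state_dict; for a loss-based metric, higher_is_better would be set to False.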
Encoder-only transformers remain the most efficient architecture for comprehension-heavy workloads. By aligning pretraining objectives with task requirements, managing special tokens explicitly, and respecting optimization constraints, engineering teams can deploy BERT-based pipelines that deliver production-grade accuracy at a fraction of the compute cost.