Fine-tuning pretrained language models has become the standard approach for NLP tasks, but the gap between "load a model" and "production-ready classifier" involves many practical decisions. In this guide, I walk through fine-tuning DistilBERT for text classification, covering the full pipeline: dataset preparation, handling class imbalance with weighted loss functions, implementing custom data collators for efficient batching, and building training loops with proper evaluation. We'll use PyTorch and HuggingFace Transformers throughout, with all code available on GitHub.
Understanding DistilBERT
DistilBERT is a distilled version of BERT-base, designed to be smaller and faster while retaining 97% of BERT's performance. The key specifications:
- Embedding dimension: 768
- Vocabulary size: 30,522 tokens
- Layers: 6 (compared to BERT's 12)
- Parameters: ~66 million (compared to BERT's 110 million)
This makes DistilBERT an excellent choice for classification tasks where you need a balance between performance and computational efficiency. The reduced size means faster training and inference without significant accuracy loss.
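You can verify these numbers directly from the checkpoint's configuration; a quick sketch:

from transformers import AutoConfig

# Inspect the pretrained checkpoint's configuration
config = AutoConfig.from_pretrained('distilbert-base-uncased')

print(config.dim)         # 768   (hidden/embedding dimension)
print(config.n_layers)    # 6     (transformer layers)
print(config.vocab_size)  # 30522 (WordPiece vocabulary size)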
Project Overview
Our implementation covers the following components:
- Loading and preparing a text classification dataset
- Tokenizing text with proper padding and truncation
- Handling class imbalance with weighted loss functions
- Creating custom DataLoaders with efficient batching
- Fine-tuning the model with a proper training loop
- Evaluating performance with appropriate metrics
1. Dataset Preparation
We'll use the HuggingFace datasets library to load our data. For this example, I'm using a sentiment classification dataset, but the approach generalizes to any text classification task.
from datasets import load_dataset
from transformers import AutoTokenizer
# Load dataset
dataset = load_dataset('your-dataset-name')
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
# Tokenization function
def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=128
    )
# Apply tokenization
tokenized_dataset = dataset.map(tokenize_function, batched=True)
We use padding='max_length' here for simplicity, but in production, dynamic padding through a data collator is more efficient. We'll implement that next.
2. Custom Data Collator for Dynamic Padding
Instead of padding all sequences to the maximum length upfront, we can pad dynamically to the longest sequence in each batch. This is more memory-efficient and faster.
from transformers import DataCollatorWithPadding
# Create data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# The collator will:
# 1. Take variable-length sequences in a batch
# 2. Pad them to the longest sequence in that specific batch
# 3. Create attention masks automatically
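For dynamic padding to pay off end to end, the tokenization step should skip padding entirely and the dataset should keep only the columns the model expects. Here is a hedged sketch of that variant, assuming the raw dataset has 'text' and 'label' columns (adjust the names to your data; tokenize_without_padding is just an illustrative helper name):

# Tokenize without padding; the collator pads each batch on the fly
def tokenize_without_padding(examples):
    return tokenizer(examples['text'], truncation=True, max_length=128)

tokenized_dataset = dataset.map(tokenize_without_padding, batched=True)

# Drop the raw text column and rename 'label' to 'labels' (what the model expects)
tokenized_dataset = tokenized_dataset.remove_columns(['text'])
tokenized_dataset = tokenized_dataset.rename_column('label', 'labels')
tokenized_dataset.set_format('torch')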
How the Data Collator Works
The data collator transforms a batch of samples with varying lengths into a properly padded tensor. Here's what happens:
# Input: 4 samples with different lengths
Sample 1: [101, 2054, 2003, 102] # length 4
Sample 2: [101, 1045, 2293, 3269, 102] # length 5
Sample 3: [101, 6920, 102] # length 3
Sample 4: [101, 2023, 2003, 1037, 6240, 6251, 102] # length 7
# Output: Padded batch (all length 7)
input_ids: [
[101, 2054, 2003, 102, 0, 0, 0], # padded with 0s
[101, 1045, 2293, 3269, 102, 0, 0], # padded with 0s
[101, 6920, 102, 0, 0, 0, 0], # padded with 0s
[101, 2023, 2003, 1037, 6240, 6251, 102] # no padding needed
]
attention_mask: [
[1, 1, 1, 1, 0, 0, 0], # 1=real token, 0=padding
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1]
]
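You can reproduce this behavior by calling the collator on a couple of tokenized examples yourself; the sentences below are arbitrary:

# Two sequences of different lengths, tokenized without padding
features = [
    tokenizer('a short sentence'),
    tokenizer('a noticeably longer sentence with quite a few more tokens'),
]

batch = data_collator(features)
print(batch['input_ids'].shape)   # (2, length of the longest sequence in this batch)
print(batch['attention_mask'])    # 1 = real token, 0 = padding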
3. Handling Class Imbalance
Real-world datasets often have imbalanced classes. We can handle this by computing class weights and using them in our loss function.
import torch
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
def calculate_class_weights(train_dataset, device):
    """
    Calculates class weights for handling imbalanced datasets.
    """
    # Extract labels from the training set (this assumes train_dataset is a
    # torch Subset whose underlying dataset exposes a .labels attribute)
    train_labels = [train_dataset.dataset.labels[i]
                    for i in train_dataset.indices]

    # Compute class weights using sklearn
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_labels),
        y=np.array(train_labels)
    )

    # Convert to a PyTorch tensor and move it to the target device
    class_weights_tensor = torch.tensor(
        class_weights,
        dtype=torch.float
    )
    return class_weights_tensor.to(device)
# Calculate weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_weights = calculate_class_weights(train_dataset, device)
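To see what the 'balanced' heuristic produces, here is a tiny standalone example with made-up labels; each weight is n_samples / (n_classes * class_count), so the minority class gets the larger weight:

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

toy_labels = np.array([0, 0, 0, 0, 1, 1])  # 4 samples of class 0, 2 of class 1
toy_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(toy_labels),
    y=toy_labels
)
print(toy_weights)  # [0.75 1.5] -> 6 / (2 * 4) and 6 / (2 * 2)

These weights are handed to CrossEntropyLoss in the training loop below, so mistakes on under-represented classes contribute more to the loss.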
4. Loading the Model
We use AutoModelForSequenceClassification, which automatically adds a classification head on top of DistilBERT.
from transformers import AutoModelForSequenceClassification
# Load pretrained model with classification head
model = AutoModelForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=num_classes  # e.g., 3 for 3-class classification
)
# Move to device
model = model.to(device)
# The model architecture:
# Input (batch_size, seq_len)
# → DistilBERT encoder (768-dim hidden states)
# → Classification head (768 → num_classes)
# → Output logits (batch_size, num_classes)
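A quick sanity check of that shape flow, using the tokenizer and model already loaded above (the example strings are arbitrary):

# Push a dummy batch through the model and confirm the logits shape
sample = tokenizer(
    ['a first example', 'a second example'],
    padding=True,
    truncation=True,
    return_tensors='pt'
).to(device)

with torch.no_grad():
    logits = model(**sample).logits

print(logits.shape)  # torch.Size([2, num_classes])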
5. Creating DataLoaders
DataLoaders handle batching, shuffling, and parallel data loading.
from torch.utils.data import DataLoader
def create_dataloaders(train_dataset, val_dataset, batch_size, data_collator):
    """
    Creates training and validation DataLoaders.
    """
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,              # Shuffle training data
        collate_fn=data_collator,  # Use our custom collator
        num_workers=4              # Parallel data loading
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,             # Don't shuffle validation data
        collate_fn=data_collator,
        num_workers=4
    )
    return train_loader, val_loader
# Create loaders
train_loader, val_loader = create_dataloaders(
    train_dataset,
    val_dataset,
    batch_size=32,
    data_collator=data_collator
)
Understanding DataLoaders
A DataLoader is an iterator that serves batches of data. Think of it as a primed pipeline that only loads data when you actually iterate through it:
# train_loader is just the pipeline (no data loaded yet)
train_loader = DataLoader(...)
# Data is loaded when you iterate
for batch in train_loader:
    # First iteration: samples 0-31
    # Second iteration: samples 32-63
    # And so on...
    inputs = batch['input_ids']
    labels = batch['labels']
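One consequence worth internalizing: len(train_loader) counts batches, not samples, and you can grab a single batch for inspection without writing a full loop:

print(len(train_loader))  # number of batches per epoch (ceil of num_samples / 32)

first_batch = next(iter(train_loader))
print(first_batch['input_ids'].shape)       # (32, longest sequence in this batch)
print(first_batch['attention_mask'].shape)  # same shape as input_ids
print(first_batch['labels'].shape)          # (32,)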
6. Training Loop
Now we implement the training loop with proper loss calculation using class weights.
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
# Setup
optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = CrossEntropyLoss(weight=class_weights)
num_epochs = 3
# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

    for batch in progress_bar:
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Calculate loss with class weights
        loss = loss_fn(outputs.logits, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track loss
        total_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

    avg_loss = total_loss / len(train_loader)
    print(f'Epoch {epoch+1} - Average Loss: {avg_loss:.4f}')

    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = loss_fn(outputs.logits, labels)
            val_loss += loss.item()

            # Calculate accuracy
            predictions = torch.argmax(outputs.logits, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)

    avg_val_loss = val_loss / len(val_loader)
    accuracy = correct / total
    print(f'Validation Loss: {avg_val_loss:.4f}, Accuracy: {accuracy:.4f}')
7. Evaluation Metrics
Beyond accuracy, we should look at precision, recall, and F1-score, especially for imbalanced datasets.
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
def evaluate_model(model, dataloader, device):
    """
    Comprehensive evaluation with multiple metrics.
    """
    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Classification report: per-class precision, recall, and F1
    print(classification_report(all_labels, all_predictions))

    # Confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Confusion Matrix')
    plt.show()
# Run evaluation
evaluate_model(model, val_loader, device)
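If you want a single headline number in addition to the full report, the macro-averaged F1 is a reasonable choice for imbalanced data because it treats every class equally regardless of how many samples it has. A toy example with made-up predictions (in practice you would reuse the lists collected in evaluate_model):

from sklearn.metrics import classification_report

# Hypothetical labels and predictions for illustration only
y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 1, 1, 2]

report = classification_report(y_true, y_pred, output_dict=True)
print(report['macro avg']['f1-score'])     # unweighted mean of per-class F1
print(report['weighted avg']['f1-score'])  # per-class F1 weighted by class support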
Key Takeaways
- Dynamic Padding: Use DataCollatorWithPadding for efficient memory usage instead of padding all sequences to the maximum length
- Class Imbalance: Calculate and apply class weights to prevent the model from being biased toward majority classes
- Proper Evaluation: Look beyond accuracy; use F1-score, precision, and recall, especially for imbalanced datasets
- Device Management: Always move both the model and the data to the same device (CPU/GPU)
- Gradient Management: Remember to call optimizer.zero_grad() before each backward pass
Next Steps
This implementation provides a solid foundation, but there are several ways to improve it:
- Add learning rate scheduling for better convergence (see the sketch after this list)
- Implement early stopping to prevent overfitting (also sketched below)
- Use gradient accumulation for larger effective batch sizes
- Add model checkpointing to save best models
- Experiment with different optimizers (AdamW vs Adam vs SGD)
- Try different pretrained models (RoBERTa, ALBERT, etc.)
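As a starting point for the first two items, here is a hedged sketch of how a linear warmup schedule and patience-based early stopping could slot into the training loop from section 6. get_linear_schedule_with_warmup comes from transformers; the 10% warmup, the patience of 2, the 'best-checkpoint' directory name, and validation loss as the stopping criterion are illustrative choices, not recommendations.

from transformers import get_linear_schedule_with_warmup

num_training_steps = num_epochs * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),  # 10% linear warmup (arbitrary)
    num_training_steps=num_training_steps
)

best_val_loss = float('inf')
patience, patience_counter = 2, 0  # stop after 2 epochs with no improvement

for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        loss = loss_fn(model(input_ids=input_ids, attention_mask=attention_mask).logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()  # advance the learning rate schedule every optimizer step

    # Compute validation loss exactly as in section 6
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            val_loss += loss_fn(model(input_ids=input_ids, attention_mask=attention_mask).logits, labels).item()
    val_loss /= len(val_loader)

    # Early stopping with checkpointing of the best model so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        model.save_pretrained('best-checkpoint')  # also covers the checkpointing item
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f'Stopping early at epoch {epoch+1}')
            break

Note that the scheduler is stepped once per optimizer step (per batch), not once per epoch, which is the usual convention for transformer warmup schedules.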