The "best" checkpoint after fine-tuning isn’t necessarily the one with the highest score on your validation set; it’s the one that performs best on the specific task you care about, which might be a different distribution or metric entirely.
Let’s say you’ve fine-tuned a large language model (LLM) for sentiment analysis on customer reviews. Your goal is to identify negative reviews to route them to customer support. You’ve trained it on a dataset and now have several checkpoints from different epochs.
Here’s a hypothetical training run and how you might evaluate checkpoints:
Training Setup:
- Base Model: meta-llama/Llama-2-7b-hf
- Dataset: 10,000 customer reviews, labeled positive/negative.
- Fine-tuning Task: Binary classification (positive/negative sentiment).
- Optimizer: AdamW
- Learning Rate: 2e-5
- Batch Size: 16
- Epochs: 5
- Validation Set: 1,000 reviews, held out from training.
- Checkpoints Saved: After each epoch.
Hypothetical Validation Scores (Accuracy):
- Epoch 1: 85.2%
- Epoch 2: 88.5%
- Epoch 3: 90.1%
- Epoch 4: 90.5%
- Epoch 5: 90.3%
At first glance, the checkpoint from Epoch 4 looks like the winner because it has the highest validation accuracy. But if your primary goal is to catch all negative reviews (minimize false negatives, or maximize recall for the "negative" class), accuracy alone might be misleading.
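To make the distinction concrete, here is a minimal sketch of how accuracy, precision, and recall for the "negative" class are computed from raw labels. The helper name `negative_class_metrics` and the ten toy labels are hypothetical, chosen only for illustration; they are not taken from the training run above.

```python
def negative_class_metrics(y_true, y_pred, negative_label=0):
    """Return accuracy plus precision and recall for the negative class."""
    pairs = list(zip(y_true, y_pred))
    # True positives: negative reviews correctly flagged as negative
    tp = sum(t == negative_label and p == negative_label for t, p in pairs)
    # False positives: positive reviews wrongly flagged as negative
    fp = sum(t != negative_label and p == negative_label for t, p in pairs)
    # False negatives: negative reviews the model missed
    fn = sum(t == negative_label and p != negative_label for t, p in pairs)
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall}


# Hypothetical labels: 0 = negative, 1 = positive
y_true = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
y_pred = [0, 0, 1, 1, 1, 1, 1, 1, 1, 0]

metrics = negative_class_metrics(y_true, y_pred)
print(metrics)
```

On this toy data the model is right 70% of the time overall, yet it catches only half of the negative reviews, which is exactly the gap that accuracy hides.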
Let’s dive into the actual performance on the validation set, looking at precision and recall for the "negative" class.
Evaluation Metrics (Validation Set):
| Checkpoint | Accuracy | Precision (Negative) | Recall (Negative) | F1 Score (Negative) |
|---|---|---|---|---|
| Epoch 1 | 85.2% | 80.1% | 75.5% | 77.7% |
| Epoch 2 | 88.5% | 85.5% | 82.0% | 83.7% |
| Epoch 3 | 90.1% | 88.2% | 85.8% | 87.0% |
| Epoch 4 | 90.5% | 89.0% | 86.5% | 87.7% |
| Epoch 5 | 90.3% | 88.5% | 87.0% | 87.7% |
Here, Epoch 4 and Epoch 5 have the same F1 score for the negative class (87.7%). However, Epoch 5 has slightly higher recall (87.0%) than Epoch 4 (86.5%). If your business priority is to let as few negative reviews as possible slip through unnoticed, Epoch 5 might be the better choice, even though its overall accuracy is marginally lower than Epoch 4's: it identifies a higher percentage of the actual negative reviews.
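The comparison can be reproduced with a few lines of Python. This is only a sketch: the numbers are copied from the hypothetical table above, and the dictionary layout and key names (`epoch_1`, `precision`, and so on) are assumptions made for the example.

```python
# Negative-class precision and recall, copied from the hypothetical table
table = {
    "epoch_1": {"accuracy": 0.852, "precision": 0.801, "recall": 0.755},
    "epoch_2": {"accuracy": 0.885, "precision": 0.855, "recall": 0.820},
    "epoch_3": {"accuracy": 0.901, "precision": 0.882, "recall": 0.858},
    "epoch_4": {"accuracy": 0.905, "precision": 0.890, "recall": 0.865},
    "epoch_5": {"accuracy": 0.903, "precision": 0.885, "recall": 0.870},
}


def f1(precision, recall):
    """F1 is the harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


for row in table.values():
    row["f1"] = f1(row["precision"], row["recall"])

# The "winner" depends on which column you rank by
best_by_accuracy = max(table, key=lambda name: table[name]["accuracy"])
best_by_recall = max(table, key=lambda name: table[name]["recall"])

print(f"Best by accuracy: {best_by_accuracy}")
print(f"Best by recall (negative): {best_by_recall}")
```

Ranking by accuracy picks `epoch_4`, while ranking by negative-class recall picks `epoch_5`; the two F1 values agree to one decimal place, so F1 alone does not separate them.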
The key is to align your checkpoint selection with your business objective and the specific metric that quantifies success for that objective. This often means looking beyond aggregate accuracy to class-specific metrics like precision, recall, or custom evaluation functions.
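One possible "custom evaluation function" is the F-beta score, which weights recall more heavily than precision when beta is greater than 1. The sketch below applies it to the negative-class precision and recall of Epochs 4 and 5 from the table above; choosing beta = 2 is an assumption made for illustration, not a recommendation from the training run.

```python
def f_beta(precision, recall, beta=2.0):
    """F-beta score; beta > 1 weights recall more heavily than precision."""
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# Negative-class precision and recall from the hypothetical table
f2_epoch_4 = f_beta(0.890, 0.865)
f2_epoch_5 = f_beta(0.885, 0.870)

print(f"F2 Epoch 4: {f2_epoch_4:.4f}")
print(f"F2 Epoch 5: {f2_epoch_5:.4f}")
```

With beta = 2, Epoch 5 scores higher than Epoch 4, so a recall-weighted metric breaks the F1 tie in the direction the business goal suggests.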
Consider a scenario where you’re fine-tuning a model to detect fraudulent transactions. You might have millions of legitimate transactions and only a few thousand fraudulent ones. A model that predicts "not fraud" for everything would have near-perfect accuracy but would be useless. In such a case, recall for the "fraud" class is paramount. You’d choose the checkpoint that maximizes this, even if it means a slight dip in overall accuracy or precision (leading to more false positives, which are easier to handle than missed fraud).
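The accuracy trap described here is easy to demonstrate. The sketch below scores a "model" that always predicts "not fraud"; the class counts (100,000 legitimate, 200 fraudulent) are made-up numbers chosen only to mimic a heavily imbalanced dataset.

```python
# Hypothetical class counts for an imbalanced fraud dataset
n_legit = 100_000
n_fraud = 200

# 0 = legitimate, 1 = fraud
y_true = [0] * n_legit + [1] * n_fraud
# A degenerate model that predicts "not fraud" for every transaction
y_pred = [0] * len(y_true)

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
caught_fraud = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
fraud_recall = caught_fraud / n_fraud

print(f"Accuracy: {accuracy:.4f}")
print(f"Recall (fraud): {fraud_recall:.4f}")
```

Accuracy comes out above 99% while recall for the fraud class is zero, which is why a checkpoint ranked by accuracy alone can be worthless for this kind of task.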
The counterintuitive part of selecting a fine-tuned checkpoint is that the top validation score can be a red herring: the checkpoint that best serves your goal may only become visible in the class-specific metrics that map directly to that goal.
Let’s explore how to programmatically load and evaluate these checkpoints using the Hugging Face transformers library. Assume your checkpoints are saved in a directory like ./checkpoints/, one subdirectory per epoch.
```python
import os

import torch
from sklearn.metrics import accuracy_score, classification_report
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assume these are defined from your training script
model_name = "meta-llama/Llama-2-7b-hf"
num_labels = 2  # Negative, Positive
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The Llama 2 tokenizer has no padding token by default, so padding=True would
# raise an error; reusing the EOS token is a common workaround
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


# Load your validation dataset (replace with your actual dataset loading)
# For demonstration, let's create dummy data
class DummyDataset:
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {"text": self.texts[idx], "label": self.labels[idx]}


# Example dummy validation data
validation_texts = [
    "This product is amazing, I love it!",
    "The service was terrible, I'm very unhappy.",
    "It works as expected, no complaints.",
    "Worst experience ever, would not recommend.",
    "So-so, could be better.",
    "Absolutely fantastic! Highly satisfied.",
    "Disappointing quality, broke after a week.",
    "Great value for money.",
    "Completely useless, a waste of money.",
    "It's okay, nothing special.",
]
# Corresponding labels (0: NEGATIVE, 1: POSITIVE)
validation_labels = [1, 0, 1, 0, 1, 1, 0, 1, 0, 1]
validation_dataset = DummyDataset(validation_texts, validation_labels)

# Metrics are computed with scikit-learn below. datasets.load_metric is
# deprecated and has been removed from recent datasets releases; the Hugging
# Face evaluate library is its replacement:
# import evaluate
# f1_metric = evaluate.load("f1")
# f1_metric.compute(predictions=..., references=..., average="macro")


# Function to evaluate a checkpoint
def evaluate_checkpoint(checkpoint_path, tokenizer, dataset):
    model = AutoModelForSequenceClassification.from_pretrained(
        checkpoint_path,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
    )
    # Keep the model's padding token in sync with the tokenizer
    model.config.pad_token_id = tokenizer.pad_token_id
    model.eval()  # Set model to evaluation mode

    predictions = []
    true_labels = []
    for example in dataset:
        inputs = tokenizer(example["text"], return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = model(**inputs)
        # One example per forward pass, so the logits have shape (1, num_labels)
        predicted_class_id = outputs.logits.argmax(dim=-1).item()
        predictions.append(predicted_class_id)
        true_labels.append(example["label"])

    # Calculate overall accuracy and per-class precision/recall/F1
    accuracy = accuracy_score(true_labels, predictions)
    report = classification_report(
        true_labels,
        predictions,
        labels=[0, 1],
        target_names=["NEGATIVE", "POSITIVE"],
        output_dict=True,
        zero_division=0,
    )
    return {
        "accuracy": accuracy,
        "precision_negative": report["NEGATIVE"]["precision"],
        "recall_negative": report["NEGATIVE"]["recall"],
        "f1_negative": report["NEGATIVE"]["f1-score"],
        "precision_positive": report["POSITIVE"]["precision"],
        "recall_positive": report["POSITIVE"]["recall"],
        "f1_positive": report["POSITIVE"]["f1-score"],
    }


checkpoint_dir = "./checkpoints/"  # Replace with your actual checkpoint directory

# Demonstration only: simulate a checkpoint directory structure like
# ./checkpoints/epoch_1/, ./checkpoints/epoch_2/, ...
# This saves the same base model (with an untrained classification head) five
# times, so the scores it produces are meaningless. Skip this step if you
# already have real checkpoints.
dummy_model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=num_labels, id2label=id2label, label2id=label2id
)
for i in range(1, 6):
    os.makedirs(os.path.join(checkpoint_dir, f"epoch_{i}"), exist_ok=True)
    dummy_model.save_pretrained(os.path.join(checkpoint_dir, f"epoch_{i}"))

# List checkpoint directories only after they exist.
# Note: sorted() is lexicographic, which is fine for epoch_1 through epoch_5.
checkpoint_epochs = sorted(
    d for d in os.listdir(checkpoint_dir) if os.path.isdir(os.path.join(checkpoint_dir, d))
)

# Iterate through checkpoints and evaluate
print("Evaluating checkpoints...")
all_results = {}
for epoch_dir in checkpoint_epochs:
    checkpoint_path = os.path.join(checkpoint_dir, epoch_dir)
    print(f"Evaluating: {checkpoint_path}")
    # Use the same validation dataset for every checkpoint so scores are comparable
    results = evaluate_checkpoint(checkpoint_path, tokenizer, validation_dataset)
    all_results[epoch_dir] = results
    print(f"  Results: {results}")

# Find the best checkpoint based on a specific metric, e.g., recall for negative class
# (on a tie, the earliest checkpoint wins because the comparison is strict)
best_checkpoint = None
best_recall_negative = -1.0
for epoch, metrics in all_results.items():
    if metrics["recall_negative"] > best_recall_negative:
        best_recall_negative = metrics["recall_negative"]
        best_checkpoint = epoch

print("\n--- Summary ---")
for epoch, metrics in all_results.items():
    print(
        f"{epoch}: Accuracy={metrics['accuracy']:.4f}, "
        f"Recall(Negative)={metrics['recall_negative']:.4f}, "
        f"F1(Negative)={metrics['f1_negative']:.4f}"
    )
print(
    f"\nBest checkpoint based on maximizing Recall for Negative class: "
    f"{best_checkpoint} (Recall: {best_recall_negative:.4f})"
)

# You would then load this best checkpoint for inference
# best_model_path = os.path.join(checkpoint_dir, best_checkpoint)
# final_model = AutoModelForSequenceClassification.from_pretrained(best_model_path)
# final_tokenizer = AutoTokenizer.from_pretrained(best_model_path)  # or from original model_name if tokenizer wasn't saved
```
The critical insight is that "best" is contextual. If your model is used in an A/B test where you want to serve users the most engaging content (e.g., highest click-through rate prediction), you might prioritize a metric that correlates with that engagement, even if it means a slightly lower overall prediction accuracy on a generic validation set. The goal is to have a robust evaluation methodology that mirrors your production use case, rather than blindly trusting validation set accuracy.