Vision-language models can learn to reason about images and text simultaneously, but fine-tuning them for specific tasks often leads to catastrophic forgetting.
Let’s see this in action. Imagine we have a base model, like CLIP, trained on a massive dataset of image-text pairs. We want to adapt it for a specific task: visual question answering (VQA). A typical VQA dataset might have images and questions like "What color is the car?" with answers like "red."
```python
# Naive full fine-tuning of CLIP for VQA (simplified).
# 'dummy_data' below stands in for a real dataset of (image, question, answer) records.
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

MODEL_NAME = "openai/clip-vit-base-patch32"

# Load the pre-trained processor (the model itself is loaded inside CLIPForVQA below)
processor = CLIPProcessor.from_pretrained(MODEL_NAME)


class VQADataset(Dataset):
    """Turns (image, question, answer) records into CLIP inputs plus an integer answer label."""

    def __init__(self, records, processor, answer_to_id):
        self.records = records
        self.processor = processor
        # VQA is often framed as classification over a fixed answer vocabulary,
        # so each text answer is mapped to a class index.
        self.answer_to_id = answer_to_id

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        item = self.records[idx]
        inputs = self.processor(
            text=[item["question"]],
            images=item["image"],
            return_tensors="pt",
            padding="max_length",  # fixed length so the default collate_fn can stack items
            truncation=True,
            max_length=77,  # CLIP's text context length
        )
        return {
            # squeeze(0) drops the batch dimension the processor adds
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.answer_to_id[item["answer"]]),
        }


# Dummy dataset for illustration
dummy_data = [
    {"image": Image.new("RGB", (224, 224), color="red"), "question": "What color is the car?", "answer": "red"},
    {"image": Image.new("RGB", (224, 224), color="blue"), "question": "What color is the sky?", "answer": "blue"},
]
answer_to_id = {"red": 0, "blue": 1}  # hypothetical answer vocabulary

vqa_dataset = VQADataset(dummy_data, processor, answer_to_id)
vqa_dataloader = DataLoader(vqa_dataset, batch_size=2, shuffle=True)


class CLIPForVQA(torch.nn.Module):
    """CLIP plus a classification head over a fixed answer vocabulary (a simplified VQA setup)."""

    def __init__(self, model_name=MODEL_NAME, num_answers=100):  # num_answers is hypothetical
        super().__init__()
        self.clip = CLIPModel.from_pretrained(model_name)
        proj_dim = self.clip.config.projection_dim
        # Fuse image and question by concatenating their projected embeddings;
        # real VQA models combine the two modalities much more deeply (or generate text).
        self.vqa_head = torch.nn.Linear(2 * proj_dim, num_answers)
        self.num_answers = num_answers

    def forward(self, pixel_values, input_ids, attention_mask, labels=None):
        outputs = self.clip(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # image_embeds / text_embeds are CLIP's projected, joint-space embeddings
        fused = torch.cat([outputs.image_embeds, outputs.text_embeds], dim=-1)
        logits = self.vqa_head(fused)

        loss = None
        if labels is not None:
            # Standard supervised objective: cross-entropy over answer class indices
            loss = torch.nn.functional.cross_entropy(logits, labels)
        return {"logits": logits, "loss": loss}


# One output class per answer in the vocabulary
vqa_model = CLIPForVQA(num_answers=len(answer_to_id))
optimizer = torch.optim.AdamW(vqa_model.parameters(), lr=5e-5)

# Simplified training loop: every CLIP parameter is updated
vqa_model.train()
for epoch in range(1):  # one epoch for the demo
    for batch in vqa_dataloader:
        optimizer.zero_grad()
        outputs = vqa_model(
            pixel_values=batch["pixel_values"],
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        loss = outputs["loss"]  # forward() returns a dict, not an object with attributes
        loss.backward()
        optimizer.step()
        print(f"Batch loss: {loss.item():.4f}")

print("Fine-tuning complete (simplified).")
```
The core problem is that vision-language models like CLIP learn a joint embedding space by contrasting matching (positive) image-text pairs against non-matching (negative) ones. This contrastive objective pushes the model toward general representations of visual concepts and their textual descriptions. Fine-tuning on a downstream task like VQA asks the model to optimize a different objective: answering questions about an image. If you train it directly on a VQA dataset with a standard supervised loss, its parameters shift to optimize the VQA objective, often at the expense of its original, general-purpose multimodal understanding. The result is catastrophic forgetting: the model becomes good at VQA but much worse at the tasks it was originally trained for, such as image-text retrieval.
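To make the pre-training objective concrete, here is a minimal sketch of a CLIP-style symmetric contrastive (InfoNCE) loss. It assumes row `i` of the image batch and row `i` of the text batch form a matching pair and treats every other row as a negative. The fixed temperature of 0.07 is a simplification for illustration; CLIP learns this scale during training. The embeddings are random toy tensors, not real CLIP outputs.

```python
import torch
import torch.nn.functional as F


def clip_contrastive_loss(image_emb: torch.Tensor, text_emb: torch.Tensor,
                          temperature: float = 0.07) -> torch.Tensor:
    """Symmetric InfoNCE loss: row i of each tensor is a positive pair, all other rows are negatives."""
    # Normalize so dot products become cosine similarities
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    # (batch, batch) similarity matrix, sharpened by the temperature
    logits = image_emb @ text_emb.T / temperature
    targets = torch.arange(logits.size(0), device=logits.device)
    # Average the image-to-text and text-to-image cross-entropies
    loss_i2t = F.cross_entropy(logits, targets)
    loss_t2i = F.cross_entropy(logits.T, targets)
    return (loss_i2t + loss_t2i) / 2


torch.manual_seed(0)
images = torch.randn(8, 512)
aligned_texts = images + 0.01 * torch.randn(8, 512)  # each caption sits next to its image
mismatched_texts = aligned_texts.roll(1, dims=0)     # every pair is now wrong

loss_aligned = clip_contrastive_loss(images, aligned_texts)
loss_mismatched = clip_contrastive_loss(images, mismatched_texts)
print(loss_aligned.item(), loss_mismatched.item())
```

The loss is small only when each image is closer to its own caption than to every other caption in the batch, which is exactly the relative structure that a different fine-tuning objective can erode.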
A useful mental model for these models is a shared embedding space. Imagine two parallel universes, one for images and one for text. A pre-trained vision-language model builds a bridge between these universes, ensuring that an image of a dog and the text "a photo of a dog" land in roughly the same spot in a high-dimensional abstract space. When fine-tuning, you move this bridge, or parts of it, to a new location that is optimal for a new kind of exchange between the universes, such as answering questions about images. The problem is that the original bridge might get distorted or even collapse.
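The "bridge" can be made concrete with nearest-neighbor lookup in the shared space. This sketch uses hand-made 3-D toy vectors standing in for real embeddings (with a CLIP model you would typically obtain them from `get_image_features` and `get_text_features`); `nearest_texts` is a hypothetical helper, not a library function.

```python
import torch
import torch.nn.functional as F


def nearest_texts(image_emb: torch.Tensor, text_emb: torch.Tensor) -> torch.Tensor:
    """For each image, return the index of the most similar text by cosine similarity."""
    sims = F.normalize(image_emb, dim=-1) @ F.normalize(text_emb, dim=-1).T
    return sims.argmax(dim=-1)


# Toy shared space: each caption points roughly the same way as its image
image_emb = torch.tensor([[1.0, 0.1, 0.0],   # dog photo
                          [0.0, 1.0, 0.2],   # car photo
                          [0.1, 0.0, 1.0]])  # beach photo
text_emb = torch.tensor([[0.0, 0.9, 0.1],    # "a photo of a car"
                         [0.2, 0.1, 0.8],    # "a photo of a beach"
                         [0.9, 0.0, 0.1]])   # "a photo of a dog"

matches = nearest_texts(image_emb, text_emb)
print(matches)  # tensor([2, 0, 1])
```

Each image lands nearest its own caption; if fine-tuning moves one side of the bridge without the other, these lookups start returning the wrong captions.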
The key levers you control are:
- The Fine-tuning Dataset: Its size, diversity, and how closely it mirrors the target task.
- The Fine-tuning Objective/Loss Function: What you’re optimizing for during training. For VQA, this might be a classification loss over candidate answers, a generative loss, or a retrieval-style loss.
- The Model Architecture: Whether you’re fine-tuning the entire model, just the final layers, or using adapter modules.
- Hyperparameters: Learning rate, batch size, number of epochs, and regularization techniques.
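The architecture lever can be sketched in a few lines. Below, a frozen toy backbone stands in for a pre-trained encoder (a hypothetical placeholder, not CLIP itself), and only a small bottleneck adapter plus a task head are trained. The bottleneck adapter shown is one common design among several; the sizes (512-dim features, 16-dim bottleneck, 10 answers) are arbitrary assumptions.

```python
import torch
from torch import nn


class BottleneckAdapter(nn.Module):
    """Residual down-project/up-project block; only these weights are trained."""

    def __init__(self, dim: int, bottleneck: int = 16):
        super().__init__()
        self.down = nn.Linear(dim, bottleneck)
        self.up = nn.Linear(bottleneck, dim)
        # Zero-init the up-projection so the adapter starts as an identity map
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.up(torch.relu(self.down(x)))


def count_trainable(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Hypothetical stand-in for a pre-trained encoder
backbone = nn.Sequential(nn.Linear(512, 512), nn.GELU(), nn.Linear(512, 512))
for p in backbone.parameters():
    p.requires_grad = False  # pre-trained weights stay fixed, so they cannot drift

adapter = BottleneckAdapter(512)
head = nn.Linear(512, 10)  # 10 hypothetical answer classes
model = nn.Sequential(backbone, adapter, head)

# Only adapter and head parameters are handed to the optimizer
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)

x = torch.randn(4, 512)
with torch.no_grad():
    h = backbone(x)
print(count_trainable(backbone), count_trainable(model))
```

Because the backbone is untouched, its original behavior is recoverable simply by bypassing the adapter, and the zero-initialized adapter means training starts from the pre-trained representation rather than a perturbed one.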
A common finding, surprising at first, is that simply training on a VQA dataset with standard cross-entropy over predicted answers can severely degrade performance on the original retrieval task. The representations shaped by contrastive learning are organized for comparing whole images with whole captions, not for direct classification or generation, and the VQA gradients reshape them toward the new purpose. The model learns to map an image and a question to an answer, but it may lose the ability to match an image to its description, or vice versa. The contrastive loss encourages global alignment, while VQA often requires local, specific reasoning.
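One practical way to catch this degradation is to measure image-to-text Recall@k on a held-out set of image-caption pairs before and after fine-tuning. The sketch below assumes row `i` of each embedding tensor is a matching pair. The toy tensors are constructed by hand to show the bookkeeping (perfect alignment before, every caption shifted by one after); they are not measured results.

```python
import torch
import torch.nn.functional as F


def recall_at_k(image_emb: torch.Tensor, text_emb: torch.Tensor, k: int = 1) -> float:
    """Fraction of images whose matching caption is among their k most similar texts."""
    sims = F.normalize(image_emb, dim=-1) @ F.normalize(text_emb, dim=-1).T
    topk = sims.topk(k, dim=-1).indices                 # (N, k) best text indices per image
    targets = torch.arange(sims.size(0)).unsqueeze(-1)  # (N, 1) index of the true caption
    return (topk == targets).any(dim=-1).float().mean().item()


# Toy illustration: before fine-tuning each caption sits on its image;
# afterwards the text embeddings have drifted so every image is nearest the wrong caption.
image_emb = torch.eye(4)
text_before = torch.eye(4)
text_after = torch.eye(4).roll(1, dims=0)

r_before = recall_at_k(image_emb, text_before)
r_after = recall_at_k(image_emb, text_after)
print(f"R@1 before: {r_before:.2f}, after: {r_after:.2f}, drop: {r_before - r_after:.2f}")
```

Tracking this number alongside VQA accuracy during fine-tuning makes the trade-off visible instead of discovering it after deployment.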
The most surprising aspect for many is how deeply the contrastive pre-training objective influences the model’s learned representations. It’s not just about learning "dog" and "dog image" are similar; it’s about learning a sophisticated metric space where relative similarities and differences are paramount. When you switch to a task like VQA, which might optimize for absolute correctness of a specific answer rather than relative similarity to a correct caption, you’re asking the model to operate in a different "geometry" of this space. This is why simple fine-tuning often fails to preserve the original capabilities.
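A toy illustration (not a proof) of the "different geometry" point: cosine similarities between images and captions are unchanged if the whole space is rotated, because only relative positions matter. A fixed linear classification head, by contrast, reads absolute coordinates, so its outputs change under the same rotation. The random embeddings and the 5-class head are arbitrary assumptions for the demonstration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, n = 64, 8
image_emb = F.normalize(torch.randn(n, dim), dim=-1)
text_emb = F.normalize(image_emb + 0.1 * torch.randn(n, dim), dim=-1)

# Random orthogonal matrix: a rotation (or reflection) of the entire space
q, _ = torch.linalg.qr(torch.randn(dim, dim))

# Contrastive view: image-caption similarities depend only on relative geometry
sims = image_emb @ text_emb.T
sims_rotated = (image_emb @ q) @ (text_emb @ q).T  # q @ q.T == I, so this equals sims

# Classification view: a fixed linear head depends on absolute coordinates
head = torch.nn.Linear(dim, 5)
with torch.no_grad():
    logits = head(image_emb)
    logits_rotated = head(image_emb @ q)

print(torch.allclose(sims, sims_rotated, atol=1e-5))
print(torch.allclose(logits, logits_rotated, atol=1e-5))
```

The contrastive objective never constrained where in absolute terms the embeddings sit, only how they relate to one another; a classification head imposes exactly that kind of absolute constraint, so training it end to end can move representations along directions the pre-training never protected.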
The next challenge is efficiently adapting these large models to multiple downstream tasks without requiring separate, fully fine-tuned copies for each.