The most surprising truth about supervised fine-tuning (SFT) is that it often doesn’t improve a model’s reasoning ability as much as it improves its style and adherence to instructions.

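Let’s see this in action. We’ll fine-tune meta-llama/Llama-2-7b-hf on a small slice of the IMDB movie-review dataset. The model is gated on the Hugging Face Hub, so you need to accept its license and log in (for example with huggingface-cli login) before it can be downloaded. First, install transformers, datasets, trl, peft (for LoRA), and accelerate (required by device_map="auto"):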

pip install transformers datasets trl peft accelerate

Here’s the Python code to set up the SFT process:

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer
from peft import LoraConfig

# 1. Load Dataset
dataset = load_dataset("imdb", split="train[:1%]") # Using a small subset of IMDB for demo

# 2. Load Model and Tokenizer
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Set padding token if not already set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16, # Use bfloat16 for efficiency
    device_map="auto" # Automatically distribute across available GPUs
)

# 3. Configure LoRA (for efficient fine-tuning)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)

# 4. Define Training Arguments
training_args = TrainingArguments(
    output_dir="./llama2-7b-sft-imdb",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=1,
    logging_steps=10,
    save_steps=50,
    fp16=False, # Set to False if using bfloat16
    bf16=True, # Enable bfloat16 training
    warmup_steps=10,
    report_to="none"
)

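# 5. Initialize SFT Trainer
# Note: this call uses the older TRL keyword style. Newer TRL releases expect
# dataset_text_field, max_seq_length (later renamed max_length) and packing on an
# SFTConfig passed as args, and take the tokenizer as processing_class instead.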
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    dataset_text_field="text", # The column in your dataset containing the text
    peft_config=lora_config,
    max_seq_length=512, # Truncate sequences to 512 tokens
    packing=False # For simplicity, not packing sequences
)

# 6. Train
print("Starting training...")
trainer.train()
print("Training finished!")

# 7. Save Model
trainer.save_model("./llama2-7b-sft-imdb-final")
print("Model saved!")

This script loads a base model (meta-llama/Llama-2-7b-hf), takes the first 1% of the IMDB training split (250 of its 25,000 reviews, treated as plain text for the model to continue in the same style), configures LoRA for efficient parameter updates, and then uses TRL’s SFTTrainer to perform the fine-tuning. dataset_text_field="text" tells the trainer which column of the dataset holds the training text. max_seq_length=512 truncates longer sequences to 512 tokens to bound memory use, and packing=False keeps each example in its own sequence instead of concatenating several short examples into one.
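To make the packing=False versus packing=True distinction concrete, here is a simplified pure-Python sketch that works on already-tokenized examples. The helper names truncate_examples and pack_examples are hypothetical, and the packing scheme shown (join examples with an EOS id, cut the stream into fixed-length blocks, drop the remainder) is one common approach, not TRL’s own implementation, which has changed across releases.

```python
from typing import List


def truncate_examples(examples: List[List[int]], max_len: int) -> List[List[int]]:
    """packing=False style: each example stays separate and is cut to max_len tokens."""
    return [ids[:max_len] for ids in examples]


def pack_examples(examples: List[List[int]], max_len: int, eos_id: int) -> List[List[int]]:
    """Simplified packing=True style: concatenate examples (each followed by an EOS
    separator) into one token stream, then slice it into max_len-sized blocks.
    The trailing partial block is dropped in this sketch."""
    stream: List[int] = []
    for ids in examples:
        stream.extend(ids + [eos_id])
    n_full = len(stream) // max_len
    return [stream[i * max_len:(i + 1) * max_len] for i in range(n_full)]


# Toy token ids standing in for three tokenized reviews of different lengths.
examples = [[1, 2, 3], [4, 5, 6, 7, 8, 9], [10, 11]]
print(truncate_examples(examples, max_len=4))
print(pack_examples(examples, max_len=4, eos_id=0))
```

Without packing, the short examples leave padding slots unused and the long one loses its tail; with packing, every block is full, but a block can span the boundary between two unrelated examples.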

The core of this process is the SFTTrainer. It handles tokenization, batching, and the forward/backward passes internally, optimizing the standard next-token prediction (cross-entropy) objective. The peft_config enables LoRA, which freezes the pre-trained weights and injects small, trainable low-rank matrices alongside the targeted projection layers. This drastically reduces the number of parameters to update, which makes fine-tuning a 7B model feasible on a single large-memory GPU rather than a cluster, and tends to limit (though not eliminate) catastrophic forgetting of the model’s original capabilities. The TrainingArguments control the optimization process: the effective batch size is per_device_train_batch_size × gradient_accumulation_steps × number of GPUs (2 × 4 = 8 on one GPU), learning_rate and num_train_epochs set the optimization schedule, and bf16=True enables bfloat16 mixed-precision training for speed and memory savings.
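The claim that LoRA drastically reduces the number of trainable parameters can be checked with back-of-the-envelope arithmetic. This sketch assumes Llama-2-7B’s published dimensions (hidden size 4096, MLP intermediate size 11008, 32 decoder layers; k_proj and v_proj are full 4096×4096 because the 7B model does not use grouped-query attention) and counts the A and B matrices LoRA adds for each module listed in target_modules. The function names are illustrative.

```python
# Assumed Llama-2-7B dimensions (from its published model config).
HIDDEN, INTERMEDIATE, LAYERS = 4096, 11008, 32


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) beside a frozen d_out x d_in weight."""
    return r * (d_in + d_out)


def trainable_lora_params(r: int) -> int:
    """Total LoRA parameters when targeting all attention and MLP projections."""
    attn = 4 * lora_params(HIDDEN, HIDDEN, r)        # q_proj, k_proj, v_proj, o_proj
    mlp = 2 * lora_params(HIDDEN, INTERMEDIATE, r)   # gate_proj, up_proj
    mlp += lora_params(INTERMEDIATE, HIDDEN, r)      # down_proj
    return LAYERS * (attn + mlp)


def effective_batch_size(per_device: int, grad_accum: int, num_devices: int) -> int:
    """Examples contributing to each optimizer step."""
    return per_device * grad_accum * num_devices


print(trainable_lora_params(r=16))        # r=16 as in lora_config
print(effective_batch_size(2, 4, 1))      # script settings on a single GPU
print(32 / 16)                            # LoRA scaling factor lora_alpha / r
```

With r=16 this works out to 39,976,960 trainable parameters (about 40 million), well under 1% of the roughly 6.7 billion frozen base weights, and halving r halves the count. The same settings give an effective batch size of 8 on one GPU.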

The mental model for SFT is simple: you show the model examples of the output format and style you want. If you want a summarization model, you feed it pairs of (long_text, summary). If you want a chatbot, you feed it (user_prompt, bot_response). The model learns to predict the next token in the desired output sequence, conditioning on the input prompt. (With a single text column, as in the IMDB script, the loss covers every token; restricting it to the response means masking the prompt tokens out of the labels.) It’s essentially learning a new "writing style" or "persona" from the provided examples.
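For (prompt, response) data, a common variant masks the prompt positions so that only the response contributes to the loss. Below is a minimal sketch, assuming the prompt and response are already tokenized; build_example and the toy token ids are hypothetical. It uses -100 because that is the default ignore_index of PyTorch’s cross-entropy loss, and it keeps labels aligned with input_ids because Hugging Face causal LM models shift labels by one position internally.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_example(prompt_ids: List[int], response_ids: List[int],
                  eos_id: int) -> Dict[str, List[int]]:
    """Concatenate prompt + response + EOS and mask the prompt out of the labels,
    so the loss is computed only on the response tokens and the final EOS."""
    input_ids = prompt_ids + response_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]
    return {"input_ids": input_ids, "labels": labels}


# Toy ids standing in for a tokenized user prompt and bot response.
example = build_example(prompt_ids=[5, 6, 7], response_ids=[8, 9], eos_id=2)
print(example["input_ids"])
print(example["labels"])
```

Training on the unmasked version instead also teaches the model to reproduce prompts, which is harmless for plain-text data like IMDB reviews but usually unwanted for chat or instruction data.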

What most people don’t realize is that the quality and format of the training data are paramount. If your dataset consists of simple, repetitive instructions, the model will become very good at following those specific instructions in that specific style. However, if the reasoning or knowledge needed to carry out those instructions isn’t already present in the base model from pre-training, SFT alone won’t instill it. It’s like teaching someone to mimic a famous orator’s cadence and tone perfectly without improving the substance of their own ideas. The dataset_text_field is the crucial handle here: it is the direct pipeline into what the model actually sees as its training target.

Once fine-tuning succeeds, the next hurdle is evaluating whether the model has actually improved on the target task beyond stylistic imitation. That question often leads to preference-based methods such as Reinforcement Learning from Human Feedback (RLHF) or Direct Preference Optimization (DPO).
