Fine-tuning an LLM on a narrow task can overwrite much of what it learned during pre-training, making it forget its original general knowledge.

Let’s see how this happens. Imagine we have a base LLM, Llama 2 7B, and we want to fine-tune it for summarization. We’ll use a tiny dataset of instruction/input/output records:

[
  {"instruction": "Summarize the following text:", "input": "The quick brown fox jumps over the lazy dog. This sentence is famous for containing all letters of the English alphabet.", "output": "A sentence containing all letters of the English alphabet is: The quick brown fox jumps over the lazy dog."},
  {"instruction": "Summarize the following text:", "input": "Artificial intelligence is a rapidly evolving field. It has the potential to revolutionize many industries.", "output": "AI is a fast-growing field with the potential to transform industries."}
]
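
Training frameworks usually do not feed these JSON fields to the model separately. Instead, each record is rendered into one prompt string and paired with the expected output. The sketch below assumes an Alpaca-style template. That template is a common convention for instruction/input/output data, but it is an assumption here, and the template your framework applies may differ. format_example and load_examples are hypothetical helper names.

```python
import json

# Alpaca-style template: an assumed convention, your framework's template may differ
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)


def format_example(record: dict) -> tuple[str, str]:
    """Render one record into (prompt, target); many setups compute the loss only on the target."""
    prompt = PROMPT_TEMPLATE.format(instruction=record["instruction"], input=record["input"])
    return prompt, record["output"]


def load_examples(path: str) -> list[tuple[str, str]]:
    """Read a JSON list of instruction/input/output records and format each one."""
    with open(path, encoding="utf-8") as f:
        return [format_example(record) for record in json.load(f)]
```

Frameworks often mask the prompt tokens out of the loss, so the model is trained only to produce the response.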

We’ll use axolotl for fine-tuning. A minimal full fine-tuning config might look like this (the alpaca dataset type expects the instruction/input/output fields used above):

base_model: meta-llama/Llama-2-7b-hf
model_type: LlamaForCausalLM
datasets:
  - path: /path/to/your/dataset.json
    ds_type: json
    type: alpaca
optimizer: adamw_torch
learning_rate: 0.00002
micro_batch_size: 2
gradient_accumulation_steps: 4
max_steps: 500
warmup_steps: 50
logging_steps: 10
save_steps: 100
output_dir: ./fine-tuned-llama

With only two examples, 500 steps at an effective batch size of 2 × 4 = 8 amount to roughly 2,000 passes over the data. After that, we might have a model that’s good at summarizing our specific examples. But if we then ask it to write a poem, it might produce gibberish or, worse, respond with something like: "I cannot fulfill this request. My purpose is to summarize text." This is catastrophic forgetting.

The core problem is that during fine-tuning, the model’s weights are adjusted to minimize the loss on the new, specific task. If the new task’s data distribution is significantly different from the pre-training data, and the fine-tuning process is too aggressive (high learning rate, too many epochs/steps), the model can overwrite the general knowledge encoded in its weights. It essentially "learns" that the new task is the only thing it should know.
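
Here is a deliberately tiny illustration of this mechanism. It is a toy, not an LLM: a model with a single shared weight w is first trained on task A, then fine-tuned only on task B. Gradient descent on task B never sees task A’s loss, so nothing stops it from moving w away from the value task A needs. Real LLMs have far more capacity to share between tasks, so forgetting is partial rather than total, but the cause is the same. The functions and tasks below are hypothetical.

```python
# Toy model (not an LLM): y = w * x with one shared weight w.


def mse(w: float, data: list[tuple[float, float]]) -> float:
    """Mean squared error of y = w * x on a list of (x, y) pairs."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)


def train(w: float, data: list[tuple[float, float]], lr: float = 0.05, steps: int = 200) -> float:
    """Plain gradient descent on the MSE of this dataset only."""
    for _ in range(steps):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


xs = [0.5, 1.0, 1.5, 2.0]
task_a = [(x, 2.0 * x) for x in xs]   # task A needs w = 2
task_b = [(x, -1.0 * x) for x in xs]  # task B needs w = -1

w = train(0.0, task_a)          # "pre-training" on task A
loss_a_before = mse(w, task_a)  # near zero
w = train(w, task_b)            # "fine-tuning" on task B only
loss_a_after = mse(w, task_a)   # large: task A's solution was overwritten
print(f"task A loss before fine-tuning: {loss_a_before:.4f}, after: {loss_a_after:.4f}")
```

The task A loss goes from essentially zero to about 16.9, because w has moved all the way from 2 to -1.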

One of the most widely used ways to reduce this is Parameter-Efficient Fine-Tuning (PEFT). Instead of updating all of the model’s millions or billions of parameters, PEFT methods freeze the vast majority of the pre-trained weights and train only a small number of added or selected parameters. This limits how much of the original knowledge can be overwritten, although it does not rule out forgetting entirely.
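
To put "a small number of trainable parameters" in numbers, here is back-of-the-envelope arithmetic for LoRA, which is introduced below. It assumes rank-8 adapters on the query and value projections (q_proj, v_proj) of Llama-2-7B, which has a hidden size of 4096 and 32 decoder layers. The base parameter count is approximate.

```python
def lora_params(d: int, k: int, r: int) -> int:
    """Trainable parameters LoRA adds to one d x k weight: B is d x r and A is r x k."""
    return d * r + r * k


HIDDEN = 4096          # Llama-2-7B hidden size
NUM_LAYERS = 32        # Llama-2-7B decoder layers
BASE_PARAMS = 6.74e9   # approximate parameter count of Llama-2-7B
RANK = 8

# q_proj and v_proj are both HIDDEN x HIDDEN in Llama-2-7B
trainable = NUM_LAYERS * 2 * lora_params(HIDDEN, HIDDEN, RANK)
full_update = NUM_LAYERS * 2 * HIDDEN * HIDDEN  # cost of fully updating those same matrices
print(f"LoRA trainable parameters: {trainable:,} "
      f"({100 * trainable / BASE_PARAMS:.2f}% of the model; "
      f"{full_update // trainable}x fewer than updating q_proj and v_proj directly)")
```

This comes to 4,194,304 trainable parameters, about 0.06% of the model. That is 256 times fewer than fully updating those two projection matrices in every layer.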

Here are common techniques (mostly PEFT methods, plus the regularization-based EWC) and how they work:

  1. LoRA (Low-Rank Adaptation): This is arguably the most popular PEFT method. LoRA hypothesizes that the weight updates during fine-tuning have a low intrinsic rank. Instead of updating a large weight matrix $W \in \mathbb{R}^{d \times k}$ directly, LoRA adds two smaller matrices $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ with rank $r \ll \min(d, k)$, so that the update is $\Delta W = BA$. Only $A$ and $B$ are trained, while $W$ remains frozen. This cuts the trainable parameters for that matrix from $dk$ to $r(d + k)$.

    • Diagnosis/Check: Before fine-tuning, evaluate the base model on a diverse set of general knowledge tasks (e.g., question answering, common sense reasoning, creative writing) and the target task. After fine-tuning with a full-model approach, re-evaluate. If performance on general tasks degrades significantly while target task performance improves, catastrophic forgetting has occurred.
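    • Fix: Integrate LoRA into your fine-tuning setup. For example, in axolotl, you’d add a LoRA configuration:
      adapter: lora
      lora_target_modules:
        - q_proj
        - v_proj
      lora_r: 8
      lora_dropout: 0.05
      lora_alpha: 16
      
      Here, adapter: lora switches axolotl from full fine-tuning to LoRA, and lora_r=8 means the low-rank matrices have a rank of 8. lora_alpha=16 is a scaling factor: the update $BA$ is multiplied by lora_alpha / lora_r = 2. lora_target_modules restricts the adapters to the attention query (q_proj) and value (v_proj) projections.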
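    • Why it works: Because only $A$ and $B$ are trained, the original weights $W$ are never overwritten. The fine-tuning signal is channeled through these small adapter matrices, and their low rank tends to limit how far the adapted layer $W + BA$ drifts from the original. Forgetting is reduced rather than ruled out, but you can always recover the base model exactly by removing the adapter.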
  2. QLoRA: An extension of LoRA that further reduces memory by quantizing the frozen base model to 4-bit precision (the NF4 data type) while the LoRA adapters are trained in higher precision. This makes it possible to fine-tune much larger models on a single GPU or on consumer hardware.

    • Diagnosis/Check: Same as LoRA. Forgetting is caused by parameter updates, not by memory limits. QLoRA inherits LoRA’s protection against the former and also reduces the latter.
    • Fix: Configure QLoRA in your training script. In axolotl:
      adapter: qlora
      load_in_4bit: true
      bf16: true
      bnb_config_kwargs:
        bnb_4bit_quant_type: nf4
        bnb_4bit_use_double_quant: true
      # ... then add the lora_* settings from the LoRA example above
      
    • Why it works: It keeps LoRA’s protection of the frozen base weights and adds the memory savings of quantization. This lets you apply parameter-efficient fine-tuning to larger models on the same hardware.
  3. Adapter Layers: Like LoRA, adapter layers keep the original weights frozen. They insert small, trainable bottleneck networks (often called "adapters") within the transformer blocks, usually after the attention and feed-forward sub-layers. Each adapter is typically a down-projection, a nonlinearity, and an up-projection with a residual connection.

    • Diagnosis/Check: Same as LoRA.
    • Fix: Use a library such as adapter-transformers (AdapterHub’s library, now continued as the adapters package), or add adapter modules in your own framework. Implementation details vary, but the principle is the same: insert these small modules and train only them.
    • Why it works: The vast majority of the model’s parameters remain unchanged. The adapters learn task-specific transformations, allowing adaptation without disturbing pre-trained knowledge.
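  4. Prompt Tuning / Prefix Tuning: These methods keep the entire LLM frozen and learn only a small set of continuous vectors. Prompt tuning learns "soft prompt" embeddings that are prepended to the input embeddings. Prefix tuning learns prefix vectors that are prepended to the keys and values of every attention layer. In both cases, the learned vectors are optimized to steer the frozen model’s behavior for the specific task.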

    • Diagnosis/Check: Same as LoRA.
    • Fix: Implement prompt tuning. This involves creating a trainable embedding matrix that is prepended to the input embeddings while every base-model weight stays frozen. For instance, you might train a 20-token soft prompt.
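      # Conceptual example using PyTorch with a Hugging Face-style causal LM
      import torch
      import torch.nn as nn
      
      class PromptTuningModel(nn.Module):
          def __init__(self, base_model, num_virtual_tokens=20):
              super().__init__()
              self.base_model = base_model
              # Freeze every base-model weight: only the soft prompt below is trained
              for param in self.base_model.parameters():
                  param.requires_grad = False
              self.num_virtual_tokens = num_virtual_tokens
              hidden_size = base_model.config.hidden_size
              # Small random init; initializing from real token embeddings is a common alternative
              self.virtual_tokens = nn.Parameter(torch.randn(1, num_virtual_tokens, hidden_size) * 0.02)
      
          def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
              batch_size, seq_len = input_ids.shape
              # Get (frozen) embeddings for the real input tokens
              input_embeds = self.base_model.get_input_embeddings()(input_ids)
              # Expand the soft prompt to the batch size and match the embedding dtype
              prompt = self.virtual_tokens.expand(batch_size, -1, -1).to(input_embeds.dtype)
              inputs_embeds = torch.cat([prompt, input_embeds], dim=1)
      
              # The soft prompt is always attended to; default to attending to every input token
              if attention_mask is None:
                  attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long, device=input_ids.device)
              prompt_mask = torch.ones(batch_size, self.num_virtual_tokens,
                                       dtype=attention_mask.dtype, device=attention_mask.device)
              full_attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
      
              # Pad labels at the front so they line up with the longer sequence; -100 is ignored by the loss
              if labels is not None:
                  prompt_labels = torch.full((batch_size, self.num_virtual_tokens), -100,
                                             dtype=labels.dtype, device=labels.device)
                  labels = torch.cat([prompt_labels, labels], dim=1)
      
              # Pass the combined embeddings to the frozen model
              return self.base_model(inputs_embeds=inputs_embeds, attention_mask=full_attention_mask,
                                     labels=labels, **kwargs)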
      
    • Why it works: The base model is entirely unchanged. The learned prompt embeddings effectively steer the model’s internal computations towards the desired output for the specific task, without altering its general knowledge.
  5. Parameter Freezing: A simpler approach where you manually freeze certain layers of the model and only train the remaining ones. For example, you might freeze the first N layers and train the last M layers.

    • Diagnosis/Check: Same as LoRA.
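    • Fix: In PyTorch, iterate over model.named_parameters() and set param.requires_grad = False for every parameter whose name belongs to a layer you want to freeze. Then pass only the parameters that still require gradients to the optimizer. In Keras (TensorFlow), set layer.trainable = False on those layers instead.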
    • Why it works: By preventing gradients from flowing into the frozen layers, their weights are preserved. This retains the general knowledge captured in those layers.
  6. Elastic Weight Consolidation (EWC): A more advanced, regularization-based technique that penalizes changes to weights that were important for previously learned tasks. Unlike the methods above, it can update every weight. It estimates each weight’s importance from the Fisher Information Matrix (in practice, its diagonal) computed on the original task.

    • Diagnosis/Check: Same as LoRA.
    • Fix: Implement a custom training loop that adds an EWC term to the standard task loss. First compute the diagonal of the Fisher Information Matrix on the original task (or a representative dataset). Then add $\frac{\lambda}{2} \sum_i F_i (w_i - w_i^*)^2$ to the objective, where $F_i$ is the Fisher diagonal entry for weight $w_i$, $w_i^*$ is that weight’s value after the original training, and $\lambda$ controls how strongly old knowledge is protected.
    • Why it works: It directly tries to balance learning the new task with retaining old knowledge by making it "costly" to deviate from important pre-trained weights.
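
The Diagnosis/Check step shared by all six techniques comes down to a before/after comparison. The sketch below assumes you already have one score per task (higher is better, on a 0 to 1 scale) from whatever evaluation harness you use. The name diagnose_forgetting and the default max_drop threshold of 0.02 are illustrative choices, not a standard.

```python
def diagnose_forgetting(before: dict[str, float], after: dict[str, float],
                        target_task: str, max_drop: float = 0.02):
    """Compare per-task scores before and after fine-tuning.

    Returns (forgot, target_gain, regressions), where regressions maps each general
    task whose score fell by more than max_drop to the size of that drop.
    """
    target_gain = after[target_task] - before[target_task]
    regressions = {
        task: before[task] - after[task]
        for task in before
        if task != target_task and before[task] - after[task] > max_drop
    }
    # The forgetting signature: the target task improved while general tasks regressed
    forgot = target_gain > 0 and bool(regressions)
    return forgot, target_gain, regressions
```

Run the same comparison after each candidate technique, so you can see which one keeps general-task regressions smallest for a given target-task gain.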
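
To make item 1 concrete, here is a minimal PyTorch sketch of LoRA applied to a single nn.Linear layer. It is not the peft library’s implementation. The class name LoRALinear is hypothetical, the small random initialization of $A$ is a simplification, and a real setup would wrap the q_proj/v_proj layers inside the model rather than a standalone layer. Initializing $B$ to zero means $BA = 0$ at the start, so training begins from exactly the base model’s behavior.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen nn.Linear W plus a trainable low-rank update: y = W x + (alpha / r) * B A x."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.05):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # W (and its bias) stay frozen
        self.scaling = alpha / r
        self.dropout = nn.Dropout(dropout)
        # A: r x in_features (small random init); B: out_features x r (zeros), so BA = 0 initially
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x @ A^T @ B^T applies BA without ever materializing the full d x k update
        update = self.dropout(x) @ self.lora_A.T @ self.lora_B.T
        return self.base(x) + self.scaling * update

    @torch.no_grad()
    def merged_weight(self) -> torch.Tensor:
        """Effective weight W + (alpha / r) * BA, e.g. for folding the adapter in after training."""
        return self.base.weight + self.scaling * (self.lora_B @ self.lora_A)
```

Because merged_weight returns $W + \frac{\alpha}{r} BA$, the adapter can be folded into the base weight after training, or simply dropped to get the original layer back.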
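
The memory argument behind item 2 is simple arithmetic. This sketch counts only the bytes needed to store the weights of a nominal 7-billion-parameter model at different precisions. It ignores activations, the LoRA parameters and their optimizer state, and the small overhead of quantization constants, so treat the numbers as lower bounds.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Bytes needed to store the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 7e9  # nominal size of a 7B model

for label, bits in [("32-bit float", 32), ("16-bit (bf16/fp16)", 16), ("4-bit (NF4)", 4)]:
    print(f"{label:>18}: {weight_memory_gb(NUM_PARAMS, bits):5.1f} GB")
```

The 4-bit copy of the frozen weights needs about 3.5 GB, versus 14 GB at 16-bit precision: a 4x saving before any training memory is counted.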
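
For item 3, here is a sketch of one bottleneck adapter in PyTorch: a down-projection, a nonlinearity, and an up-projection added back to the input through a residual connection. The class name, the GELU activation, and the default bottleneck size of 64 are assumptions. Inserting the module after each attention and feed-forward sub-layer is model-specific and not shown. Zero-initializing the up-projection makes the adapter start as an exact identity mapping, so the frozen model’s behavior is unchanged before training.

```python
import torch
import torch.nn as nn


class BottleneckAdapter(nn.Module):
    """Down-project -> nonlinearity -> up-project, added back to the input as a residual."""

    def __init__(self, hidden_size: int, bottleneck: int = 64):
        super().__init__()
        self.down = nn.Linear(hidden_size, bottleneck)
        self.act = nn.GELU()
        self.up = nn.Linear(bottleneck, hidden_size)
        # Zero the up-projection so the adapter initially returns its input unchanged
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states + self.up(self.act(self.down(hidden_states)))
```

Each adapter adds roughly 2 × hidden_size × bottleneck parameters, which is small compared with a transformer block when the bottleneck is much narrower than the hidden size.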
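
Item 5 in PyTorch works like this: walk named_parameters(), work out which layer each parameter belongs to from its name, and switch off gradients for the early layers. The regular expression assumes Hugging Face Llama-style names such as model.layers.0.self_attn.q_proj.weight. Check model.named_parameters() for your architecture before relying on it. freeze_first_n_layers and trainable_parameters are hypothetical helpers.

```python
import re

import torch.nn as nn

# Matches "layers.<index>." at the start of a name or after a dot, e.g. "model.layers.3.mlp.up_proj.weight"
LAYER_INDEX = re.compile(r"(?:^|\.)layers\.(\d+)\.")


def freeze_first_n_layers(model: nn.Module, n: int) -> int:
    """Set requires_grad=False on every parameter in layers 0..n-1; return how many tensors were frozen."""
    frozen = 0
    for name, param in model.named_parameters():
        match = LAYER_INDEX.search(name)
        if match and int(match.group(1)) < n:
            param.requires_grad = False
            frozen += 1
    return frozen


def trainable_parameters(model: nn.Module) -> list[nn.Parameter]:
    """The parameters that should be handed to the optimizer."""
    return [p for p in model.parameters() if p.requires_grad]
```

For example, calling freeze_first_n_layers(model, 24) and then torch.optim.AdamW(trainable_parameters(model), lr=2e-5) would train only the last 8 of Llama-2-7B’s 32 decoder layers. The embeddings, final norm, and output head would also stay trainable, because the pattern does not match them.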
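
Here is a PyTorch sketch of the EWC penalty from item 6. It approximates each $F_i$ by the average squared gradient of the loss over batches of data representative of the original task. This is an empirical-Fisher approximation at batch rather than per-example granularity, chosen here for brevity. The sketch also snapshots the original weights as $w_i^*$ and returns $\frac{\lambda}{2} \sum_i F_i (w_i - w_i^*)^2$. diagonal_fisher and ewc_penalty are hypothetical helpers, and the value of $\lambda$ must be tuned.

```python
import torch
import torch.nn as nn


def diagonal_fisher(model: nn.Module, data_loader, loss_fn) -> dict[str, torch.Tensor]:
    """Average squared gradient per trainable parameter over (inputs, targets) batches."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    was_training = model.training
    model.eval()
    num_batches = 0
    for inputs, targets in data_loader:
        model.zero_grad()
        loss_fn(model(inputs), targets).backward()
        for n, p in model.named_parameters():
            if n in fisher and p.grad is not None:
                fisher[n] += p.grad.detach() ** 2
        num_batches += 1
    model.zero_grad()  # do not leak these gradients into the next training step
    model.train(was_training)
    return {n: f / max(num_batches, 1) for n, f in fisher.items()}


def ewc_penalty(model: nn.Module, fisher: dict[str, torch.Tensor],
                anchor: dict[str, torch.Tensor], lam: float):
    """(lam / 2) * sum_i F_i * (w_i - w_i*)^2, with the original weights w* stored in anchor."""
    penalty = 0.0
    for n, p in model.named_parameters():
        if n in fisher:
            penalty = penalty + (fisher[n] * (p - anchor[n]) ** 2).sum()
    return lam / 2 * penalty


# Sketch of use during fine-tuning (original_task_loader and lam are placeholders):
# anchor = {n: p.detach().clone() for n, p in model.named_parameters()}
# fisher = diagonal_fisher(model, original_task_loader, loss_fn)
# loss = loss_fn(model(x), y) + ewc_penalty(model, fisher, anchor, lam)
```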

The next challenge you’ll likely face is balancing PEFT efficiency with performance. Sometimes, a very small number of trainable parameters might not be enough to reach peak performance on a complex downstream task, leading to a trade-off between forgetting and task mastery.
