A learning rate schedule is one of the most powerful, yet least understood, knobs for stable LLM fine-tuning.

Let’s watch it in action. Imagine we’re fine-tuning Llama 2 7B on a custom dataset for code generation. We start with a base learning rate of 1e-5.

```python
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model

# Load model and tokenizer (the Llama 2 checkpoints are gated on the Hugging Face Hub)
# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Load your dataset (tokenize it before passing it to the Trainer)
# dataset = load_dataset("your_dataset_name")

# Configure LoRA: only the small adapter matrices are trained
# lora_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     lora_dropout=0.05,
#     bias="none",
#     task_type="CAUSAL_LM"
# )
# model = get_peft_model(model, lora_config)

# Define training arguments with linear warmup followed by cosine decay
training_args = TrainingArguments(
    output_dir="./llama2-7b-finetuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,  # effective batch = 4 x 8 = 32 per GPU
    learning_rate=1e-5,             # peak LR, reached at the end of warmup
    num_train_epochs=3,
    logging_steps=10,
    save_steps=500,
    lr_scheduler_type="cosine",  # <<< The key scheduler type
    warmup_steps=200,            # <<< And the warmup duration
    fp16=True,                   # mixed precision; requires a CUDA GPU
    # ... other args
)

# Initialize Trainer
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=dataset["train"],
#     # ... other args
# )

# Start training
# trainer.train()
```

The pairing of lr_scheduler_type="cosine" and warmup_steps=200 is what we’re focusing on. It’s not just about a single learning rate; it’s about how that rate changes over time.

The core problem in LLM fine-tuning is instability. Large models have billions of parameters, and every update shifts all of the trainable ones at once; a step that is too large, especially early in training, can cause catastrophic forgetting of what the model learned during pretraining or outright divergence (loss spikes or NaNs). The learning rate schedule is one of the primary mechanisms to combat this. A warmup-plus-decay schedule starts low, increases gradually, and then decays, guiding the model through the parameter space more gently.

Here’s how the cosine decay schedule works under the hood:

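  1. Warmup: For the first warmup_steps (e.g., 200 steps), the learning rate increases linearly from 0 to the peak learning_rate (e.g., 1e-5). This prevents large, disruptive updates at the very beginning, when the optimizer’s running statistics (such as Adam’s moment estimates) are based on only a handful of gradients and a full-size step could push the pretrained weights far from where they started.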
  2. Decay: After warmup, the learning rate follows half a cosine curve. It starts at the peak (1e-5) and decreases smoothly toward final_lr over the remaining training steps. Cosine decay is a widely used default for LLM training and often performs at least as well as simple linear decay. With W = warmup_steps, the rate at step t (for t ≥ W) is lr = final_lr + 0.5 * (initial_lr - final_lr) * (1 + cos(pi * (t - W) / (T - W))), where initial_lr is the peak set by learning_rate, T is the total number of training steps, and final_lr is typically 0. At t = W this gives initial_lr; at t = T it gives final_lr.
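To make the two phases concrete, here is a minimal pure-Python sketch of the schedule described above. The function name lr_at_step is hypothetical; it follows the same shape as the linear-warmup-plus-cosine schedule selected by lr_scheduler_type="cosine", but it is an illustration, not the library’s own code.

```python
import math


def lr_at_step(step: int, peak_lr: float, warmup_steps: int,
               total_steps: int, final_lr: float = 0.0) -> float:
    """Learning rate at a given optimizer step: linear warmup, then cosine decay."""
    if step < warmup_steps:
        # Warmup: rise linearly from 0 toward peak_lr.
        return peak_lr * step / max(1, warmup_steps)
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    # Half a cosine: 1.0 at the start of decay, 0.0 at the end.
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return final_lr + (peak_lr - final_lr) * cosine


if __name__ == "__main__":
    # Same settings as the example: peak 1e-5, 200 warmup steps, 4000 total steps.
    for s in (0, 100, 200, 2100, 4000):
        print(s, f"{lr_at_step(s, 1e-5, 200, 4000):.3e}")
```

Halfway through warmup the rate is half the peak, at the end of warmup it equals the peak, and halfway through the decay phase it is back to half the peak.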

The magic of the cosine schedule is its balance. The warmup phase protects against initial shock, while the smooth decay allows the model to converge to a good minimum without overshooting or oscillating. It’s like slowly easing into a turn on a racetrack rather than slamming on the brakes and then flooring it.
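One way to see what “smooth” buys you is to compare the cosine multiplier with plain linear decay over the same decay phase. This is a small illustrative sketch; both helper functions are hypothetical and take progress through the decay phase, from 0 (end of warmup) to 1 (last step).

```python
import math


def cosine_factor(progress: float) -> float:
    """Multiplier on the peak LR under cosine decay (progress in [0, 1])."""
    return 0.5 * (1.0 + math.cos(math.pi * progress))


def linear_factor(progress: float) -> float:
    """Multiplier on the peak LR under linear decay, for comparison."""
    return 1.0 - progress


for p in (0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0):
    print(f"progress={p:.2f}  cosine={cosine_factor(p):.3f}  linear={linear_factor(p):.3f}")
```

Cosine stays closer to the peak early in the decay (about 0.85 vs. 0.75 a quarter of the way through) and sits lower near the end (about 0.15 vs. 0.25 at three quarters). Its slope is zero at both ends, so the learning rate leaves the peak and approaches its minimum gradually rather than at a constant rate.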

The exact values for warmup_steps and learning_rate are crucial. A common heuristic is to set warmup_steps to a small fraction (e.g., 5-10%) of the total training steps: with 4000 total training steps, 200-400 warmup steps is a good starting point. The right learning_rate depends heavily on the model, task, and dataset size. For full fine-tuning of models like Llama 2, peak values between 1e-5 and 5e-5 are typical; LoRA adapters are usually trained with noticeably higher rates, often on the order of 1e-4.
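To apply the 5-10% rule you first need the total number of optimizer steps, which depends on dataset size, effective batch size, and epochs. The sketch below is an approximation under stated assumptions: the helper schedule_lengths is hypothetical, the 40,000-example dataset is made up for illustration, and the Trainer’s own count can differ slightly because of how it rounds partial batches.

```python
import math


def schedule_lengths(num_examples: int, per_device_batch: int, grad_accum: int,
                     num_devices: int, num_epochs: int,
                     warmup_fraction: float) -> tuple[int, int]:
    """Approximate (total optimizer steps, warmup steps) for a training run."""
    # One optimizer step consumes this many examples.
    effective_batch = per_device_batch * grad_accum * num_devices
    # Round up so a final partial batch still counts as a step.
    steps_per_epoch = math.ceil(num_examples / effective_batch)
    total_steps = steps_per_epoch * num_epochs
    # Round warmup up so a non-zero fraction never yields 0 warmup steps.
    warmup_steps = math.ceil(total_steps * warmup_fraction)
    return total_steps, warmup_steps


# Hypothetical dataset size; batch settings match the TrainingArguments example.
total, warmup = schedule_lengths(40_000, per_device_batch=4, grad_accum=8,
                                 num_devices=1, num_epochs=3, warmup_fraction=0.05)
print(f"total steps: {total}, warmup steps (5%): {warmup}")
```

With the example’s batch settings (4 × 8 on one GPU, 3 epochs), this hypothetical dataset gives 3750 optimizer steps, so 5% warmup is 188 steps and the 200 used above is roughly 5.3%. Many versions of TrainingArguments also accept a warmup_ratio field for specifying the fraction directly; check the documentation for your installed version.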

The one thing most people don’t realize is how much the end point of the schedule matters. In most frameworks, including the Hugging Face Trainer, the learning_rate argument sets the peak rate reached at the end of warmup, and the standard cosine schedule then decays all the way to 0 unless you configure a floor; some frameworks expose that floor as a minimum learning rate. For very short fine-tuning runs or specific tasks, you might want the LR to decay to a small, non-zero value (e.g., 1e-6) to maintain some plasticity in the final steps, rather than spending the tail of the run at a learning rate too small to change the model.
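Adding a floor only changes where the cosine curve ends. The following sketch, built around a hypothetical function warmup_cosine_lr, decays to min_lr instead of 0 and compares the tail of the same 4000-step run with and without a 1e-6 floor. If you use the Hugging Face Trainer, recent versions of transformers offer lr_scheduler_type="cosine_with_min_lr" configured through lr_scheduler_kwargs (for example {"min_lr": 1e-6}); confirm this against your installed version’s documentation.

```python
import math


def warmup_cosine_lr(step: int, peak_lr: float, warmup_steps: int,
                     total_steps: int, min_lr: float = 0.0) -> float:
    """Linear warmup to peak_lr, then cosine decay down to min_lr (not to 0)."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    # The cosine term runs from 1 to 0, so the LR runs from peak_lr to min_lr.
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))


# Same run, two end points: decay to 0 vs. decay to a 1e-6 floor.
for floor in (0.0, 1e-6):
    tail = [warmup_cosine_lr(s, 1e-5, 200, 4000, floor) for s in range(3800, 4001)]
    print(f"min_lr={floor:.0e}: final LR={tail[-1]:.2e}, "
          f"mean LR over last 200 steps={sum(tail) / len(tail):.2e}")
```

Without the floor, the last steps of the run use learning rates close to zero; with it, the learning rate never drops below 1e-6.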

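If your fine-tuning run crashes with a CUDA out of memory error, reduce per_device_train_batch_size and raise gradient_accumulation_steps to compensate, enable gradient checkpointing, or quantize the base model (as in QLoRA). Keeping the effective batch size (per-device batch × accumulation steps × number of GPUs) constant also keeps the total number of optimizer steps constant, so a fixed warmup_steps value still covers the same fraction of training.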
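When you shrink the per-device batch to fit in memory, you can scale gradient accumulation up by the same factor so the effective batch size, and with it the schedule’s step count, stays the same. A minimal sketch with a hypothetical helper, rebalance_accumulation, assuming a single GPU:

```python
def rebalance_accumulation(per_device_batch: int, grad_accum: int,
                           new_per_device_batch: int) -> int:
    """Return gradient_accumulation_steps that preserves the effective batch size."""
    effective_batch = per_device_batch * grad_accum
    # Check for a non-positive size first so the modulo below never divides by zero.
    if new_per_device_batch <= 0 or effective_batch % new_per_device_batch != 0:
        raise ValueError(
            f"{new_per_device_batch} does not evenly divide the "
            f"effective batch size {effective_batch}"
        )
    return effective_batch // new_per_device_batch


# Settings from the TrainingArguments example: 4 x 8 = 32 examples per optimizer step.
print(rebalance_accumulation(4, 8, new_per_device_batch=2))
print(rebalance_accumulation(4, 8, new_per_device_batch=1))
```

Halving the per-device batch from 4 to 2 means doubling accumulation from 8 to 16; the run takes longer in wall-clock time, but the learning rate schedule sees the same number of optimizer steps.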
