Gradient accumulation lets you train models with effectively larger batch sizes than your GPU’s memory can hold by splitting a large batch into smaller micro-batches, processing them sequentially, and accumulating their gradients before performing a single weight update.
Let’s see it in action. Imagine we want to simulate a batch size of 256, but our GPU can only handle a batch size of 32.
```python
import torch
import torch.nn as nn
import torch.optim as optim

# --- Model and Data Setup ---
model = nn.Linear(10, 2)  # Simple model
criterion = nn.MSELoss()  # default reduction='mean' averages over each micro-batch
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Simulate data
input_data = torch.randn(256, 10)  # Effective batch size of 256
target_data = torch.randn(256, 2)

# --- Gradient Accumulation Setup ---
effective_batch_size = 256
micro_batch_size = 32
num_accumulation_steps = effective_batch_size // micro_batch_size  # 256 // 32 = 8

# --- Training Loop ---
optimizer.zero_grad()  # Zero gradients once at the start
for i in range(0, effective_batch_size, micro_batch_size):
    # Extract micro-batch
    inputs = input_data[i : i + micro_batch_size]
    targets = target_data[i : i + micro_batch_size]

    # Forward pass
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Normalize loss by the number of accumulation steps
    loss = loss / num_accumulation_steps

    # Backward pass (gradients are added into each parameter's .grad)
    loss.backward()

    # Perform weight update only after accumulating gradients from all micro-batches
    if (i // micro_batch_size) == num_accumulation_steps - 1:
        optimizer.step()
        optimizer.zero_grad()  # Zero gradients for the next effective batch

print("Training step completed with gradient accumulation.")
```
This code snippet demonstrates the core idea. We loop through our data in chunks of micro_batch_size and perform a forward and backward pass for each chunk. We call optimizer.step() only after processing num_accumulation_steps of these chunks. Crucially, we divide the loss by num_accumulation_steps before backward(). Because nn.MSELoss averages over each micro-batch, this makes the accumulated gradients equal to the average gradient over the entire effective_batch_size. Each update therefore has the same scale as a single step on the full batch at the same learning rate. Without this normalization, the gradients would be num_accumulation_steps times too large, leading to overly large updates and unstable training.
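The claim about normalization can be checked numerically. The following is a minimal sketch, assuming the same nn.Linear model and nn.MSELoss criterion as above, with random synthetic data and a fixed seed. It compares the gradient accumulated over eight micro-batches against a single backward pass over all 256 samples, with and without dividing by num_accumulation_steps. The helper accumulated_grad is a hypothetical name used only for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(10, 2)
criterion = nn.MSELoss()  # mean reduction over each micro-batch

inputs = torch.randn(256, 10)
targets = torch.randn(256, 2)
micro_batch_size = 32
num_accumulation_steps = inputs.shape[0] // micro_batch_size  # 8

def accumulated_grad(normalize: bool) -> torch.Tensor:
    """Hypothetical helper: accumulate over micro-batches, return a copy of weight.grad."""
    model.zero_grad()
    for i in range(0, inputs.shape[0], micro_batch_size):
        loss = criterion(model(inputs[i:i + micro_batch_size]),
                         targets[i:i + micro_batch_size])
        if normalize:
            loss = loss / num_accumulation_steps
        loss.backward()  # adds into model.weight.grad
    return model.weight.grad.clone()

# Reference: one backward pass over the full batch of 256 samples
model.zero_grad()
criterion(model(inputs), targets).backward()
full_batch_grad = model.weight.grad.clone()

normalized = accumulated_grad(normalize=True)
unnormalized = accumulated_grad(normalize=False)

print(torch.allclose(normalized, full_batch_grad, atol=1e-6))  # True
print(torch.allclose(unnormalized, num_accumulation_steps * full_batch_grad, atol=1e-5))  # True
```

The normalized accumulation matches the full-batch gradient up to floating-point rounding. The unnormalized version is exactly num_accumulation_steps times larger.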
The fundamental problem gradient accumulation solves is the memory constraint of GPUs. Modern deep learning models, especially in computer vision and natural language processing, are often trained with large batch sizes. Larger batches give lower-variance gradient estimates and need fewer optimizer steps per epoch. Their effect on generalization is mixed, however, and very large batches can hurt it. The cost is memory: a single batch passed through a deep neural network consumes significant GPU memory for the activations and intermediate results the backward pass needs. If your desired batch size exceeds the available GPU memory, you’ll encounter out-of-memory (OOM) errors.
Gradient accumulation circumvents this by breaking the large batch into smaller "micro-batches," each small enough to fit into GPU memory. The process involves:
- Forward Pass: Compute the output and loss for a single micro-batch, and divide the loss by the number of accumulation steps.
- Backward Pass: Compute the gradients of the scaled loss with respect to the model parameters. Instead of immediately updating the model weights, add these gradients to a running sum. In PyTorch, loss.backward() does this automatically by adding into each parameter’s .grad.
- Repeat: Repeat the forward and backward passes for each micro-batch until the total number of samples processed equals the desired effective_batch_size.
- Weight Update: After processing all micro-batches for the effective batch, use the accumulated gradients to update the model’s weights via optimizer.step(). Because each loss was divided by the number of accumulation steps, these gradients are already the average over the effective batch. Dividing the summed gradients once, just before the step, would be equivalent. Then zero the gradients with optimizer.zero_grad() to start the next effective batch.
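The "running sum" in the backward step is built into PyTorch's autograd. A small sketch makes this visible. It assumes only a single scalar parameter w, chosen for illustration, and shows that repeated backward() calls add into .grad rather than overwrite it. Setting .grad to None matches what optimizer.zero_grad() does by default in recent PyTorch versions.

```python
import torch

# A single parameter w; the loss 3 * w has gradient 3 with respect to w.
w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()
print(w.grad)  # tensor(3.)

# A second backward pass adds to the existing .grad instead of overwriting it.
(3 * w).backward()
print(w.grad)  # tensor(6.)

# Clearing the gradient starts a new sum for the next effective batch.
w.grad = None
(3 * w).backward()
print(w.grad)  # tensor(3.)
```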
The key is that optimizer.step() is called less frequently than in standard, non-accumulated training. With an effective_batch_size of 256 and a micro_batch_size of 32, you perform 8 forward/backward passes before a single optimizer.step(). This simulates training with a batch size of 256, but the activation memory at any given moment is only that of a batch of 32. The model’s weights, gradients, and optimizer state occupy the same memory either way.
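Over a whole dataset, the same pattern is usually written as a loop over a DataLoader that yields micro-batches, with a step counter deciding when to update. This is a minimal sketch under assumed settings: a synthetic dataset of 1024 samples, micro-batches of 32, and 8 accumulation steps, so the effective batch size is 256. It runs 32 forward/backward passes and performs 4 optimizer steps. The sketch assumes the dataset size is a multiple of the effective batch size. Otherwise the last few micro-batches would accumulate gradients that never trigger a step.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

model = nn.Linear(10, 2)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Synthetic dataset of 1024 samples, served in micro-batches of 32.
dataset = TensorDataset(torch.randn(1024, 10), torch.randn(1024, 2))
loader = DataLoader(dataset, batch_size=32, shuffle=False)

num_accumulation_steps = 8  # effective batch size = 32 * 8 = 256
optimizer_steps = 0

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(loader, start=1):
    loss = criterion(model(inputs), targets) / num_accumulation_steps
    loss.backward()  # gradients accumulate in each parameter's .grad
    if step % num_accumulation_steps == 0:
        optimizer.step()       # one update per effective batch
        optimizer.zero_grad()  # start the next effective batch
        optimizer_steps += 1

print(len(loader), optimizer_steps)  # 32 4
```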
The num_accumulation_steps is calculated as effective_batch_size // micro_batch_size. For example, if you want an effective batch size of 1024 and your GPU can handle a micro-batch size of 64, you’ll need 1024 // 64 = 16 accumulation steps.
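Integer division silently drops any remainder when effective_batch_size is not a multiple of micro_batch_size. For example, 250 // 32 gives 7 steps covering only 224 samples. A small guard can catch this early. The function name accumulation_steps below is hypothetical, a sketch rather than any library API.

```python
def accumulation_steps(effective_batch_size: int, micro_batch_size: int) -> int:
    """Return how many micro-batches make up one effective batch.

    Raises ValueError when the effective batch size is not a multiple of the
    micro-batch size, because // would silently drop the remainder.
    """
    if effective_batch_size % micro_batch_size != 0:
        raise ValueError(
            f"effective_batch_size={effective_batch_size} is not a multiple "
            f"of micro_batch_size={micro_batch_size}"
        )
    return effective_batch_size // micro_batch_size

print(accumulation_steps(256, 32))   # 8
print(accumulation_steps(1024, 64))  # 16
```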
The learning rate (lr) in your optimizer is critical. With gradient accumulation, the effective batch size is micro_batch_size * num_accumulation_steps. The optimizer sees gradients averaged over that many samples, just as if you had trained with that batch size directly, so the usual batch-size considerations for the learning rate apply. A common heuristic, the linear scaling rule, scales the learning rate in proportion to the batch size. If your effective batch is N times larger than the batch size your learning rate was tuned for, you might increase the learning rate by a factor of N. This heuristic is most often applied to SGD and is less reliable for adaptive optimizers such as Adam. Some practitioners instead keep the learning rate tuned for the micro_batch_size and rely on the lower-variance gradient estimate from the larger effective batch. Either way, dividing the loss by num_accumulation_steps before backward() is essential. Without it, the gradients grow by a factor of num_accumulation_steps.
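The linear scaling heuristic amounts to one line of arithmetic. The sketch below uses a hypothetical helper, linearly_scaled_lr, and assumes a learning rate of 0.01 tuned for a batch of 32. Accumulating 8 micro-batches gives an effective batch of 256, so the heuristic suggests 0.01 * 256 / 32 = 0.08. Treat the result as a starting point for tuning, not a guaranteed optimum.

```python
def linearly_scaled_lr(base_lr: float, base_batch_size: int, effective_batch_size: int) -> float:
    """Hypothetical helper: linear scaling heuristic, lr proportional to batch size."""
    return base_lr * effective_batch_size / base_batch_size

# lr 0.01 tuned for batch 32; accumulate 8 micro-batches -> effective batch 256
print(linearly_scaled_lr(0.01, 32, 256))  # 0.08
```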
The memory savings are direct: activation memory, the part that grows with batch size, is proportional to micro_batch_size rather than effective_batch_size. This allows training very large models, or using very large effective batch sizes, on hardware with limited VRAM. For instance, a researcher training a massive language model might have only a few A100 GPUs, each with 40GB or 80GB of VRAM. Without gradient accumulation, they might be limited to a micro-batch size of 1 or 2. Gradient accumulation lets them simulate batch sizes of 32, 64, or more, which can be important for stable training of such models.
When implementing gradient accumulation, you must ensure that optimizer.zero_grad() is called only once at the beginning of processing an entire effective batch, and then again only after optimizer.step() is called. If you call optimizer.zero_grad() after each micro-batch’s backward pass, you will discard the accumulated gradients, defeating the purpose of the technique.
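The effect of misplacing zero_grad() can be shown directly. This minimal sketch assumes the same toy model and random data as the main example. It compares the correct loop with a buggy one that clears gradients before every micro-batch. The buggy loop ends up with only the last micro-batch's scaled gradient instead of the average over all 256 samples. The function name run is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(10, 2)
criterion = nn.MSELoss()
inputs, targets = torch.randn(256, 10), torch.randn(256, 2)
micro_batch_size, num_accumulation_steps = 32, 8

def run(zero_every_micro_batch: bool) -> torch.Tensor:
    """Hypothetical helper: one effective batch, returns a copy of weight.grad."""
    model.zero_grad()
    for i in range(0, 256, micro_batch_size):
        if zero_every_micro_batch:
            model.zero_grad()  # BUG: discards gradients from earlier micro-batches
        loss = criterion(model(inputs[i:i + micro_batch_size]),
                         targets[i:i + micro_batch_size]) / num_accumulation_steps
        loss.backward()
    return model.weight.grad.clone()

correct = run(zero_every_micro_batch=False)
buggy = run(zero_every_micro_batch=True)

# Gradient from the last micro-batch alone, scaled the same way
model.zero_grad()
(criterion(model(inputs[-32:]), targets[-32:]) / num_accumulation_steps).backward()
last_only = model.weight.grad.clone()

print(torch.allclose(buggy, last_only))  # True
print(torch.allclose(buggy, correct))    # False
```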
One subtle point is how the loss is handled. The loss calculated on a micro-batch is not the final loss for the effective batch. To scale the gradients correctly, divide the loss by num_accumulation_steps before calling loss.backward(). This averages the loss over the micro-batches. The accumulated gradients then represent the average gradient over the entire effective batch, matching what you would get from training with that batch size directly. The equivalence is exact only when every micro-batch has the same number of samples and the criterion uses mean reduction, the default for nn.MSELoss. If the last micro-batch is smaller, or the loss is summed rather than averaged, each micro-batch’s loss needs a different weight.
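When the samples do not split evenly into micro-batches, one option is to weight each micro-batch's loss by its share of the effective batch instead of dividing by num_accumulation_steps. This is a minimal sketch, assuming a mean-reduced criterion (nn.MSELoss) and a synthetic batch of 250 samples. With micro-batches of 32, those samples split into seven full micro-batches and one of 26. The weighted accumulation reproduces the full-batch gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(10, 2)
criterion = nn.MSELoss()  # mean over the elements of each micro-batch

# 250 samples do not split evenly into micro-batches of 32: the last one has 26.
inputs, targets = torch.randn(250, 10), torch.randn(250, 2)
total = inputs.shape[0]
micro_batch_size = 32

model.zero_grad()
for i in range(0, total, micro_batch_size):
    x, y = inputs[i:i + micro_batch_size], targets[i:i + micro_batch_size]
    # Weight by this micro-batch's share of the samples, not 1 / num_steps.
    loss = criterion(model(x), y) * (x.shape[0] / total)
    loss.backward()
weighted = model.weight.grad.clone()

# Reference: full-batch gradient
model.zero_grad()
criterion(model(inputs), targets).backward()
full = model.weight.grad.clone()

print(torch.allclose(weighted, full, atol=1e-6))  # True
```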
The next hurdle you’ll encounter is learning rate scheduling and its interaction with gradient accumulation, as determining the optimal learning rate schedule for an effectively large batch size can be complex.