FSDP (Fully Sharded Data Parallel) is a PyTorch feature that shards model parameters, gradients, and optimizer states across multiple GPUs, allowing you to train models that are too large to fit on a single GPU.

Here’s a look at how FSDP works and how to configure it for distributed fine-tuning.

The Problem FSDP Solves

Traditionally, when fine-tuning large language models (LLMs) on multiple GPUs, you might use Data Parallelism (DP) or Distributed Data Parallelism (DDP). With DP/DDP, each GPU holds a full replica of the model, gradients, and optimizer states. This works well until that combined state exceeds the memory capacity of a single GPU. For example, a 70B-parameter model needs about 140 GB just for its weights in FP16 (2 bytes per parameter), and more again for gradients and optimizer states.
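
The arithmetic behind that figure can be sketched as follows. The byte counts per parameter are assumptions for a typical mixed-precision Adam setup (FP16 weights and gradients, plus FP32 master weights, momentum, and variance), not measurements, and activations are ignored entirely.

```python
# Rough memory needed by ONE full replica (as in DDP), excluding activations.
# Byte counts per parameter are illustrative assumptions, not measurements.
BYTES_WEIGHTS_FP16 = 2
BYTES_GRADS_FP16 = 2
BYTES_ADAM_FP32 = 12  # FP32 master weights (4) + momentum (4) + variance (4)


def replica_memory_gb(num_params: int) -> dict:
    """Return an estimate in GB (10^9 bytes) for each kind of training state."""
    gb = 1e9
    return {
        "weights": num_params * BYTES_WEIGHTS_FP16 / gb,
        "gradients": num_params * BYTES_GRADS_FP16 / gb,
        "optimizer": num_params * BYTES_ADAM_FP32 / gb,
    }


estimate = replica_memory_gb(70_000_000_000)
total_gb = sum(estimate.values())

for name, size in estimate.items():
    print(f"{name:>10}: {size:8.1f} GB")
print(f"{'total':>10}: {total_gb:8.1f} GB")
```

Under these assumptions the weights are only a small part of the total, which is why replicating everything on every GPU stops working long before the weights alone fill a device.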

FSDP addresses this by sharding these components across the available GPUs. Instead of each GPU having a full copy, each GPU only holds a shard of the model parameters, gradients, and optimizer states. During the forward and backward passes, FSDP dynamically gathers the necessary parameters for computation and then discards them, keeping only the local shard. This dramatically reduces the memory footprint per GPU, enabling the fine-tuning of massive models on clusters of GPUs.
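
The gather-then-discard cycle can be illustrated with a toy simulation in plain Python. This is only a sketch of the idea: the `shard` and `all_gather` helpers are hypothetical stand-ins for FSDP's internal flattening and the real collective, and lists of floats stand in for tensors.

```python
def shard(params, world_size):
    """Split a flat parameter list into equal shards, zero-padding the tail."""
    shard_len = -(-len(params) // world_size)  # ceiling division
    padded = params + [0.0] * (shard_len * world_size - len(params))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards, original_len):
    """Stand-in for the all-gather collective: rebuild the full parameter list."""
    full = [p for s in shards for p in s]
    return full[:original_len]  # drop the padding again


params = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
world_size = 4
shards = shard(params, world_size)

# Steady state: each simulated rank stores only its own shard.
resident = [len(s) for s in shards]
print("parameters resident per rank:", resident)

# Around one layer's computation: gather, compute, then free the full copy.
full = all_gather(shards, len(params))
output = sum(full)  # placeholder for the layer's forward pass
del full            # only the local shard remains afterwards
```

The full parameter list exists only briefly around the computation; between those moments, per-rank storage is roughly the model size divided by the number of ranks.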

Configuring FSDP

Let’s walk through a practical example of configuring FSDP for fine-tuning a model.

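First, initialize the distributed environment, typically with torch.distributed.init_process_group. The snippet assumes the script is started by a launcher such as torchrun, which sets the LOCAL_RANK environment variable for each process.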

import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize the distributed environment
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
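
Reading LOCAL_RANK directly raises a KeyError when the script is run without a launcher. One possible way to make that less brittle is a small helper that reads the variables torchrun exports and falls back to single-process values. The helper name `read_dist_env` is hypothetical, and the fallback defaults are an assumption about how you want a plain `python` run to behave.

```python
import os


def read_dist_env(env=os.environ):
    """Read the rank variables exported by torchrun.

    Falls back to single-process values when a variable is missing
    (an assumed convention for running the script without a launcher).
    """
    return {
        "rank": int(env.get("RANK", 0)),              # global rank across all nodes
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # rank within this node
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }


info = read_dist_env()
print(info)
```

With a launch such as `torchrun --nproc_per_node=4 train.py` (where `train.py` is a placeholder script name), each of the four processes would see a different `local_rank` and a `world_size` of 4.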

Next, you’ll load your model and tokenizer. For this example, we’ll use a smaller model, but the principles apply to larger ones.

model_name = "facebook/opt-125m" # Replace with your model
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Now, the core of FSDP configuration involves wrapping your model with the FullyShardedDataParallel class. You need to decide how submodules are grouped into FSDP units, because the parameters of each unit are gathered and freed together.

# Define the FSDP wrapping policy
# This example shards layers based on their size, which is often a good
# starting point for transformer models. More complex policies can be defined.
def_auto_wrap_policy = functools.partial(
    size_based_auto_wrap_policy, min_num_params=1000
)

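# Wrap the model with FSDP
model = FSDP(
    model,
    auto_wrap_policy=def_auto_wrap_policy,
    # The model was loaded on CPU; let FSDP shard it onto this process's GPU
    device_id=torch.cuda.current_device(),
    # Other FSDP parameters can be configured here (MixedPrecision, ShardingStrategy,
    # and CPUOffload can be imported from torch.distributed.fsdp), e.g.,
    # mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    # sharding_strategy=ShardingStrategy.FULL_SHARD,
    # cpu_offload=CPUOffload(offload_params=True),
)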

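The auto_wrap_policy is crucial. size_based_auto_wrap_policy is a convenient way to automatically wrap submodules whose parameter count reaches a threshold; each wrapped submodule becomes its own FSDP unit whose parameters are sharded, gathered, and freed together. The threshold of 1000 used here is very low and suits only a small demonstration model, since it puts almost every layer in its own unit; larger models usually call for a much larger value. For transformer models, it’s common to wrap entire transformer blocks instead, for example with transformer_auto_wrap_policy from the same module.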
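
To see how the threshold changes the grouping, here is a simplified, pure-Python model of a size-based wrapping decision. It is a sketch of the general idea (wrap bottom-up whenever enough not-yet-wrapped parameters accumulate), not FSDP's actual implementation; the `Node` class, the `plan_wrapping` function, and the per-layer parameter counts (estimated for a hidden size of 768 and a feed-forward size of 3072) are all illustrative assumptions.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Node:
    """Toy stand-in for a module: a name, its own parameters, and children."""
    name: str
    own_params: int = 0
    children: list[Node] = field(default_factory=list)

    def total(self) -> int:
        return self.own_params + sum(c.total() for c in self.children)


def plan_wrapping(node: Node, min_num_params: int, wrapped: list[str]) -> int:
    """Return how many parameters under `node` are already claimed by a unit."""
    total = node.total()
    if total < min_num_params:
        return 0  # too small to wrap, and so is everything below it
    claimed = sum(plan_wrapping(c, min_num_params, wrapped) for c in node.children)
    if total - claimed >= min_num_params:
        wrapped.append(node.name)  # remaining parameters justify a new unit
        return total
    return claimed


def make_block(i: int) -> Node:
    # Estimated parameter counts (weights + biases) for one transformer block
    sizes = {
        "q_proj": 590_592, "k_proj": 590_592, "v_proj": 590_592, "out_proj": 590_592,
        "fc1": 2_362_368, "fc2": 2_360_064, "ln1": 1_536, "ln2": 1_536,
    }
    return Node(f"block{i}", children=[Node(f"block{i}.{n}", s) for n, s in sizes.items()])


model = Node("model", children=[make_block(0), make_block(1)])

fine_grained: list[str] = []
plan_wrapping(model, 1_000, fine_grained)

per_block: list[str] = []
plan_wrapping(model, 5_000_000, per_block)

print(len(fine_grained), "units with a threshold of 1,000")
print(per_block, "with a threshold of 5,000,000")
```

In this toy model, the low threshold gives every individual layer its own unit, while the higher one yields one unit per block. Many small units mean many small collectives, so a block-sized grouping tends to be the more practical choice.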

You can also specify sharding_strategy. The most common is FULL_SHARD, which shards parameters, gradients, and optimizer states. Other strategies are also available: SHARD_GRAD_OP shards gradients and optimizer states but keeps the gathered parameters in memory between the forward and backward passes, and NO_SHARD keeps everything replicated and behaves like DDP.
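
A rough way to compare the strategies is to estimate the training state each GPU must hold at its peak. The sketch below assumes 2 bytes per parameter for weights and gradients and 12 for Adam state, treats SHARD_GRAD_OP as holding full parameters during computation, and ignores activations and the temporarily gathered unit under FULL_SHARD. The table and function names are hypothetical.

```python
# True means the state is split across ranks; False means every rank holds all of it.
SHARDED = {
    "FULL_SHARD":    {"params": True,  "grads": True,  "optim": True},
    "SHARD_GRAD_OP": {"params": False, "grads": True,  "optim": True},
    "NO_SHARD":      {"params": False, "grads": False, "optim": False},
}

# Assumed bytes per parameter for each kind of state (illustrative only)
BYTES_PER_PARAM = {"params": 2, "grads": 2, "optim": 12}


def per_gpu_gb(strategy: str, num_params: int, world_size: int) -> float:
    """Estimate peak training-state memory per GPU in GB (10^9 bytes)."""
    total_bytes = 0.0
    for state, nbytes in BYTES_PER_PARAM.items():
        divisor = world_size if SHARDED[strategy][state] else 1
        total_bytes += num_params * nbytes / divisor
    return total_bytes / 1e9


for name in SHARDED:
    print(f"{name:>13}: {per_gpu_gb(name, 7_000_000_000, 8):6.2f} GB per GPU")
```

For a 7B-parameter model on 8 GPUs, these assumptions put FULL_SHARD at an eighth of the NO_SHARD footprint, with SHARD_GRAD_OP in between.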

mixed_precision allows you to control the precision of computations and communication, which can save memory and speed up training. For instance, setting mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) runs computation in BF16 while reducing gradients in FP32, which is a good balance on GPUs that support BF16.
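
One reason to keep gradient reduction in higher precision is that summing many small values in BF16 loses accuracy quickly. The sketch below emulates BF16 rounding with the standard library and compares a BF16 running sum with a higher-precision one. It is only loosely analogous to what happens inside a real reduction, the `to_bf16` helper is a hypothetical emulation, and Python's 64-bit floats stand in for the higher-precision accumulator.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # rounding bias for the dropped 16 bits
    bits &= 0xFFFF0000                   # keep sign, exponent, and 7 mantissa bits
    return struct.unpack(">f", struct.pack(">I", bits))[0]


# 1000 identical small "gradient contributions", already stored in BF16
grads = [to_bf16(0.001)] * 1000
exact = 1000 * grads[0]

low = 0.0
for g in grads:
    low = to_bf16(low + g)  # accumulator rounded to BF16 after every addition

high = 0.0
for g in grads:
    high += g               # accumulator kept in higher precision

print(f"exact sum:            {exact:.6f}")
print(f"BF16 accumulation:    {low:.6f}")
print(f"high-precision accum: {high:.6f}")
```

Once the BF16 running sum grows large relative to each addend, further additions round away, so the low-precision total drifts far from the exact value while the higher-precision one stays close.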

cpu_offload allows offloading parameters to CPU RAM when they are not needed for computation, further reducing GPU memory usage, though at the cost of slower training.

The Training Loop

The training loop itself looks very similar to a standard PyTorch training loop. FSDP handles parameter gathering and gradient synchronization automatically, and the optimizer then updates each rank's local shard.

from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

# Dummy dataset and dataloader for demonstration
class DummyDataset(Dataset):
    def __init__(self, num_samples=100, seq_len=512):
        self.num_samples = num_samples
        self.seq_len = seq_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        input_ids = torch.randint(0, tokenizer.vocab_size, (self.seq_len,), dtype=torch.long)
        labels = input_ids.clone()
        return {"input_ids": input_ids, "attention_mask": torch.ones_like(input_ids), "labels": labels}

dataset = DummyDataset()
# For distributed training, use DistributedSampler
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
dataloader = DataLoader(dataset, batch_size=8, sampler=sampler) # Adjust batch_size

# Create the optimizer AFTER wrapping the model with FSDP, so that it
# tracks the sharded parameters FSDP exposes rather than the originals.
# Each rank then keeps optimizer state only for its own shard.
optimizer = AdamW(model.parameters(), lr=1e-5)

num_epochs = 1
for epoch in range(num_epochs):
    sampler.set_epoch(epoch) # Important for DistributedSampler
    for batch in dataloader:
        input_ids = batch["input_ids"].to(local_rank)
        attention_mask = batch["attention_mask"].to(local_rank)
        labels = batch["labels"].to(local_rank)

        optimizer.zero_grad()

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # FSDP handles the backward pass synchronization
        loss.backward()
        optimizer.step()

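        # Log from global rank 0 only; this is that rank's local loss,
        # not an average over all ranks
        if dist.get_rank() == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")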

dist.destroy_process_group()

The key takeaway here is that loss.backward() and optimizer.step() work as usual. FSDP hooks into the backward pass to reduce and scatter gradients across processes, so each rank's optimizer step updates only the shard of parameters that rank owns.
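
What happens to gradients between loss.backward() and optimizer.step() can be sketched with a toy reduce-scatter in plain Python. The helper names are hypothetical, lists stand in for tensors, and plain SGD is swapped in for AdamW to keep the update to one line.

```python
def reduce_scatter_mean(per_rank_grads, world_size):
    """Stand-in for reduce-scatter: average gradients across ranks, then give
    each rank only the slice that matches its parameter shard."""
    n = len(per_rank_grads[0])
    averaged = [sum(g[i] for g in per_rank_grads) / world_size for i in range(n)]
    shard_len = n // world_size  # assumes n divides evenly, for brevity
    return [averaged[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def sgd_step(param_shard, grad_shard, lr):
    """Plain SGD on one rank's shard (a simplification of AdamW)."""
    return [p - lr * g for p, g in zip(param_shard, grad_shard)]


world_size = 2
param_shards = [[1.0, 2.0], [3.0, 4.0]]  # rank 0 and rank 1 each own two parameters

# Each rank computed gradients for ALL four parameters on its own batch.
per_rank_grads = [[1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]]

grad_shards = reduce_scatter_mean(per_rank_grads, world_size)
param_shards = [sgd_step(p, g, lr=0.5) for p, g in zip(param_shards, grad_shards)]

print("gradient shards:", grad_shards)
print("updated parameter shards:", param_shards)
```

Each simulated rank ends up with the averaged gradient for its own slice only, and the optimizer never needs to see the rest of the model.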

The Surprising Efficiency of FULL_SHARD

Many people instinctively shy away from FULL_SHARD, thinking it’s the most complex or slowest strategy. However, for very large models, FULL_SHARD is often the only strategy that fits the model into memory, and its performance can be surprisingly good. Instead of each GPU holding a full model and communicating only gradients, FULL_SHARD ensures each GPU persistently holds only its piece of the model. During the forward and backward passes, it gathers just the parameters needed for the current computation step, then frees them. This does mean more communication than DDP, but much of it can overlap with computation, and the memory it frees can go toward larger batches or longer sequences. When memory is the bottleneck, that trade often yields better overall throughput than less aggressive sharding.
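
The communication cost can be put in rough numbers. A common back-of-the-envelope analysis counts an all-reduce as a reduce-scatter plus an all-gather, each moving about one model's worth of data per GPU. The sketch below applies that assumption; the function name is hypothetical and the figures ignore overlap with computation, latency, and network topology.

```python
def comm_volume(strategy: str, num_params: int) -> int:
    """Approximate data moved per GPU per training step, in parameter counts."""
    if strategy == "DDP":
        # Gradient all-reduce ~ reduce-scatter + all-gather
        return 2 * num_params
    if strategy == "FULL_SHARD":
        # All-gather for forward + all-gather for backward + gradient reduce-scatter
        return 3 * num_params
    raise ValueError(f"unknown strategy: {strategy}")


num_params = 7_000_000_000
ratio = comm_volume("FULL_SHARD", num_params) / comm_volume("DDP", num_params)
print(f"FULL_SHARD moves about {ratio:.1f}x the data of DDP per step")
```

Under this estimate the overhead is a constant factor of about 1.5 rather than something that grows with the number of GPUs, which helps explain why the memory savings often outweigh it.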

Next Steps

After successfully configuring FSDP for basic fine-tuning, the next logical step is to explore more advanced sharding strategies and mixed-precision settings to optimize memory usage and training speed for your specific hardware and model size.
