A lightweight, educational implementation of distributed deep learning strategies including DDP, ZeRO-3, and FSDP (Fully Sharded Data Parallel). This project provides clean, understandable implementations of modern distributed training techniques with a focus on clarity and learning.
- Three Distributed Strategies: Complete implementations of DDP, ZeRO-3, and FSDP
- Educational Focus: Clean, well-documented code designed for learning and understanding
- Production-Ready: Efficient implementations with proper memory management and communication optimization
- Modular Design: Easy to extend and customize for different use cases
- CUDA Optimized: Built-in support for GPU acceleration and distributed computing
- PyTorch Native: Leverages PyTorch's distributed primitives for maximum compatibility
Strategy | Sharding Type | Memory Distribution | Communication Pattern | Best For |
---|---|---|---|---|
DDP | None (full replication) | Full model replication | All-reduce gradients | Small to medium models |
ZeRO-3 | Inter-tensor | Whole tensors per rank | Broadcast parameters | Large models with uneven layers |
FSDP | Intra-tensor | Tensor slices per rank | All-gather/Reduce-scatter | Very large models, even distribution |
- DDP
  - Memory: Each rank holds a full copy of the model.
  - Communication: Gradients are all-reduced across ranks.
  - Overhead: Minimal (one gradient all-reduce per step), but full replication limits it to smaller models.
- ZeRO-3
  - Memory: Parameters are distributed across ranks by inter-tensor sharding: each tensor is stored whole on one owner rank.
  - Communication: Parameters are broadcast from their owner; gradients are reduced back to the owner.
  - Overhead: Parameters are synchronized on demand, which suits heterogeneous workloads.
- FSDP
  - Memory: Each parameter tensor is split along dim 0 across all ranks.
  - Communication: All-gather for forward/backward, reduce-scatter for gradients.
  - Overhead: Memory and communication are balanced evenly across ranks, which suits very large models.
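To make the inter-tensor vs. intra-tensor distinction concrete, here is a pure-Python sketch that only computes which rank stores what. The parameter names and shapes are hypothetical, and the greedy size-balancing rule is an illustrative assumption, not necessarily how `zero3_partition_tensors` places tensors. FSDP implementations typically pad tensors so every shard has equal size; this sketch simply lets trailing shards be shorter.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, int]

# Hypothetical parameter shapes (rows, cols); the sizes are illustrative only.
PARAMS: Dict[str, Shape] = {
    "wte": (8, 4),
    "attn.qkv": (4, 12),
    "attn.proj": (4, 4),
    "mlp.fc": (4, 16),
}


def numel(shape: Shape) -> int:
    return shape[0] * shape[1]


def inter_tensor_shard(params: Dict[str, Shape], world_size: int) -> Dict[int, List[str]]:
    """ZeRO-3 style: every tensor lives whole on exactly one owner rank.
    Greedy heuristic: largest tensors first, each to the currently least-loaded rank."""
    owners: Dict[int, List[str]] = {r: [] for r in range(world_size)}
    load = [0] * world_size
    for name in sorted(params, key=lambda n: numel(params[n]), reverse=True):
        rank = load.index(min(load))
        owners[rank].append(name)
        load[rank] += numel(params[name])
    return owners


def intra_tensor_shard(params: Dict[str, Shape], world_size: int) -> Dict[int, Dict[str, Tuple[int, int]]]:
    """FSDP style: every tensor is split along dim 0; rank r keeps rows [start, stop)."""
    shards: Dict[int, Dict[str, Tuple[int, int]]] = {r: {} for r in range(world_size)}
    for name, (rows, _cols) in params.items():
        chunk = -(-rows // world_size)  # ceil division; trailing shards may be short or empty
        for r in range(world_size):
            start = min(r * chunk, rows)
            shards[r][name] = (start, min(start + chunk, rows))
    return shards
```

With two ranks, the ZeRO-3 style layout gives each rank two whole tensors (80 elements each here), while the FSDP style layout gives each rank half the rows of all four tensors.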
- Python 3.8+
- PyTorch 2.0+
- CUDA 11.0+ (for GPU support)
- NCCL (for multi-GPU communication)
```bash
pip install torch torchvision torchaudio
pip install tqdm
git clone <repository-url>
cd Tiny-FSDP
```
All training scripts support the following parameters:
Parameter | Options | Default | Description |
---|---|---|---|
--model | gpt2, gpt2_medium, gpt2_large, gpt2_xl | gpt2 | Model size |
--lr | float | 1e-5 | Learning rate |
--steps | int | 100 | Number of training steps |
--weight_decay | float | 1e-1 | Weight decay for the optimizer |
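For reference, the options in the table above map onto a standard `argparse` parser like the one below. This is a sketch reconstructed from the table only; the actual `train.py` scripts may define their parsers differently.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    # Mirrors the documented options; defaults copied from the table.
    parser = argparse.ArgumentParser(description="Tiny-FSDP training (sketch)")
    parser.add_argument("--model", choices=["gpt2", "gpt2_medium", "gpt2_large", "gpt2_xl"],
                        default="gpt2", help="Model size")
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("--steps", type=int, default=100, help="Number of training steps")
    parser.add_argument("--weight_decay", type=float, default=1e-1,
                        help="Weight decay for the optimizer")
    return parser


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # argv=None reads sys.argv; pass an explicit list to parse a custom command line.
    return build_parser().parse_args(argv)
```

For example, `parse(["--model", "gpt2_medium", "--lr", "2e-5", "--steps", "200"])` yields the same settings as the customized DDP command further down.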
```python
from tiny_fsdp.core import SGD, AdamW
from example.model import GPT2Model, GPTConfigs

model = GPT2Model(GPTConfigs.gpt2)
optimizer = AdamW(model.named_parameters(), lr=1e-4)

# Standard training loop (`dataloader` and `criterion` are user-supplied)
for batch, target in dataloader:
    output = model(batch)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
```
```bash
# DDP
torchrun --nproc_per_node=2 example/ddp/train.py
# With custom parameters
torchrun --nproc_per_node=2 example/ddp/train.py --model gpt2_medium --lr 2e-5 --steps 200

# ZeRO-3
torchrun --nproc_per_node=2 example/zero3/train.py
# With custom parameters
torchrun --nproc_per_node=2 example/zero3/train.py --model gpt2_large --lr 1e-4 --steps 50

# FSDP
torchrun --nproc_per_node=2 example/fsdp/train.py
# With custom parameters
torchrun --nproc_per_node=2 example/fsdp/train.py --model gpt2_xl --lr 5e-6 --steps 500

# Single device (no torchrun needed)
python example/single_device/train.py --model gpt2 --lr 1e-4 --steps 100
```
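```python
import os

import torch
import torch.distributed as dist
from tiny_fsdp.core import DDP, DDPAdamW
from example.model import GPT2Model, GPTConfigs

# Initialize distributed training (torchrun sets LOCAL_RANK for each process)
dist.init_process_group(backend='nccl')
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Wrap model with DDP
config = GPTConfigs.gpt2
model = GPT2Model(config)
model = DDP(model)

# Use DDP-aware optimizer
optimizer = DDPAdamW(model.named_parameters(), lr=1e-4)

# Training loop (`dataloader` is user-supplied; here the model returns the loss)
model.require_backward_grad_sync = True  # Enable gradient sync
for batch in dataloader:
    loss = model(batch)
    loss.backward()  # Gradients automatically all-reduced
    optimizer.step()
```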
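Why averaging gradients is enough for DDP can be checked with a toy example: for a mean loss and equal-sized per-rank batches, the average of the per-rank gradients equals the full-batch gradient, so every replica applies the same update and the weights never need to be sent. This is a pure-Python simulation with a one-weight linear model; it assumes the all-reduce averages (PyTorch's own `DistributedDataParallel` divides the summed gradients by the world size) and does not reflect how `tiny_fsdp`'s `DDP` hooks are implemented.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def local_grad(w: float, batch: Batch) -> float:
    """d/dw of the mean squared error mean((w*x - y)^2) over one rank's micro-batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values: List[float]) -> List[float]:
    """Simulated all-reduce: sum across ranks, divide by world size; every rank gets the result."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def ddp_step(w: float, batch: Batch, world_size: int, lr: float) -> List[float]:
    """Each rank starts from the same weight w and sees an equal slice of the global batch."""
    shard = len(batch) // world_size  # assumes the batch divides evenly across ranks
    grads = [local_grad(w, batch[r * shard:(r + 1) * shard]) for r in range(world_size)]
    synced = all_reduce_mean(grads)
    # Identical gradients give identical updates, so replicas stay in sync.
    return [w - lr * g for g in synced]
```

With `data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0)]` and `w = 0.5`, the two ranks compute local gradients of -5.5 and -40.5, whose average (-23.0) is exactly the full-batch gradient.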
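```python
from collections import OrderedDict

import torch
import torch.distributed as dist
from tiny_fsdp.core import Zero3, Zero3AdamW, zero3_partition_tensors
from example.model import GPT2Model, GPTConfigs

config = GPTConfigs.gpt2
world_size = dist.get_world_size()  # assumes init_process_group() was already called
ranks_map = [f"cuda:{i}" for i in range(world_size)]

# Partition model parameters across ranks (the meta-device model only supplies shapes)
with torch.device('meta'):
    model = GPT2Model(config)
parts, _ = zero3_partition_tensors(
    OrderedDict(model.named_parameters()),
    ranks_map=ranks_map,
    evenness_priority=0,
)

# Wrap with Zero3
model = GPT2Model(config)
model = Zero3(model, parts)

# Use Zero3-aware optimizer
optimizer = Zero3AdamW(
    model.module.named_parameters(),
    lr=1e-4,
    param_part_table=parts,
    ranks_map=ranks_map,
)
```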
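The ZeRO-3 broadcast/reduce cycle can be simulated in a few lines. In this pure-Python sketch, scalars stand in for tensors, each rank's `targets` dict is hypothetical per-rank data for a toy loss, and gradients are averaged (rather than summed) onto the owner; all of these are assumptions, and a real implementation would use `torch.distributed` collectives instead of list comprehensions.

```python
from typing import Dict, List

Params = Dict[str, float]


def zero3_step(owned: List[Params], targets: List[Params], lr: float) -> List[Params]:
    """One simulated ZeRO-3 step.

    owned[r]   -- the parameters rank r stores (each name appears on exactly one rank)
    targets[r] -- rank r's data for the toy loss sum((p - t)^2)
    """
    world_size = len(owned)
    owner = {name: r for r in range(world_size) for name in owned[r]}

    # 1. Broadcast: every rank temporarily materialises every parameter from its owner.
    full = [{name: owned[owner[name]][name] for name in owner} for _ in range(world_size)]

    # 2. Local backward on each rank's own data, using the gathered copy.
    grads = [{n: 2.0 * (full[r][n] - targets[r][n]) for n in owner} for r in range(world_size)]

    # 3. Reduce: average each gradient onto its owner only; non-owners drop theirs.
    new_owned: List[Params] = [{} for _ in range(world_size)]
    for name, r_own in owner.items():
        g = sum(grads[r][name] for r in range(world_size)) / world_size
        # 4. Only the owner updates this parameter (and would hold its optimizer state).
        new_owned[r_own][name] = owned[r_own][name] - lr * g
    # The broadcast copies in `full` are discarded, so steady-state storage stays ~P/N.
    return new_owned
```

For instance, `zero3_step([{"a": 1.0}, {"b": 2.0}], [{"a": 0.0, "b": 0.0}, {"a": 0.0, "b": 1.0}], lr=0.5)` leaves rank 0 holding only `a` (now 0.0) and rank 1 holding only `b` (now 0.5).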
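```python
import torch.distributed as dist
from tiny_fsdp.core import FSDP, FSDPAdamW
from example.model import GPT2Model, GPTConfigs

config = GPTConfigs.gpt2
rank = dist.get_rank()  # assumes init_process_group() was already called
world_size = dist.get_world_size()

# Wrap model with FSDP (automatic parameter sharding)
model = GPT2Model(config)
model = FSDP(model, world_size=world_size, rank=rank)

# Use FSDP-aware optimizer (works with sharded parameters)
optimizer = FSDPAdamW(model.named_parameters(), lr=1e-4)

# Training loop (`dataloader` and `criterion` are user-supplied)
for batch, target in dataloader:
    output = model(batch)  # Parameters auto-gathered
    loss = criterion(output, target)
    loss.backward()  # Gradients reduce-scattered to local shards
    optimizer.step()  # Update local shards
```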
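The gather, compute, reduce-scatter, update sequence in the loop above can be traced with a small simulation. Here 1-D Python lists stand in for parameter tensors split along dim 0, `targets` is hypothetical per-rank data for a toy loss, and the reduce-scatter averages rather than sums; these are assumptions of the sketch, not details confirmed for `tiny_fsdp`'s `FSDP`.

```python
from typing import List

Vec = List[float]  # one rank's dim-0 shard of a (1-D) parameter


def all_gather(shards: List[Vec]) -> List[Vec]:
    """Every rank receives the concatenation of all shards, in rank order."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def reduce_scatter_mean(grads: List[Vec]) -> List[Vec]:
    """Average the full-size gradients element-wise, then give rank r only its slice."""
    world_size = len(grads)
    avg = [sum(col) / world_size for col in zip(*grads)]
    chunk = len(avg) // world_size  # assumes even division (FSDP implementations usually pad)
    return [avg[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def fsdp_step(shards: List[Vec], targets: List[Vec], lr: float) -> List[Vec]:
    """One step of the toy loss sum((p - t)^2), where targets[r] is rank r's data."""
    full = all_gather(shards)  # forward/backward need the whole parameter
    grads = [[2.0 * (p - t) for p, t in zip(full[r], targets[r])] for r in range(len(shards))]
    grad_shards = reduce_scatter_mean(grads)  # each rank keeps only its gradient slice
    # Update the local shard only; the gathered copies in `full` can be freed now.
    return [[p - lr * g for p, g in zip(s, gs)] for s, gs in zip(shards, grad_shards)]
```

With shards `[[1.0, 2.0], [3.0, 4.0]]`, targets `[[0.0] * 4, [2.0] * 4]`, and `lr=0.5`, both ranks end up with `[1.0, 1.0]`, the same values an unsharded full-batch update would give for those positions.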
```
tiny_fsdp/core/
├── ddp/       # Distributed Data Parallel
├── zero3/     # ZeRO-3 Implementation
├── fsdp/      # Fully Sharded Data Parallel
├── module/    # Base modules (Linear, LayerNorm, Embedding)
├── optim/     # Optimizers (SGD, AdamW)
└── utils/     # Utilities and helpers
```
Each strategy implements:
- Module Wrappers: Custom Linear, LayerNorm, Embedding layers
- Model Wrapper: High-level model container
- Optimizers: Strategy-specific parameter update logic
- Communication: Efficient parameter and gradient synchronization
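For a model with P parameters across N ranks (steady state, excluding activations and parameters gathered temporarily during forward/backward):
Strategy | Parameters/Rank | Gradients/Rank | Optimizer States/Rank |
---|---|---|---|
DDP | P | P | P |
ZeRO-3 | ~P/N | ~P/N | ~P/N |
FSDP | P/N | P/N | P/N |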
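The table translates directly into bytes. This sketch assumes fp32 storage and AdamW's two moment buffers (`exp_avg`, `exp_avg_sq`), and ignores activations, buffers, and temporarily gathered parameters; for ZeRO-3 the sharded figure is only approximate, since whole tensors rarely divide evenly across ranks.

```python
def bytes_per_rank(num_params: int, world_size: int, sharded: bool,
                   bytes_per_value: int = 4, optim_states: int = 2) -> float:
    """Steady-state bytes on one rank for parameters + gradients + optimizer states."""
    copies = 2 + optim_states  # one parameter copy, one gradient copy, plus optimizer states
    values = num_params / world_size if sharded else num_params
    return values * copies * bytes_per_value


P = 117_000_000  # GPT-2 size quoted in the benchmark section
print(f"DDP:     {bytes_per_rank(P, 2, sharded=False) / 1e9:.2f} GB per rank")
print(f"Sharded: {bytes_per_rank(P, 2, sharded=True) / 1e9:.2f} GB per rank")
```

Under these assumptions DDP needs about 1.87 GB (10^9 bytes) per rank for model state on two GPUs and the sharded strategies about half that; measured usage also includes activations and the parameters gathered during compute, so real numbers will differ.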
Strategy | Forward Pass | Backward Pass | Optimizer Step |
---|---|---|---|
DDP | None | All-reduce(P) | None |
ZeRO-3 | Broadcast(P) | Reduce(P) | None |
FSDP | All-gather(P) | All-gather(P) + Reduce-scatter(P) | None |
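The collectives in the table can be turned into per-rank traffic under one common assumption: bandwidth-optimal ring algorithms, where an all-gather or reduce-scatter over P elements sends (N-1)/N·P elements per rank and an all-reduce (a reduce-scatter followed by an all-gather) sends twice that. The sketch also assumes FSDP all-gathers parameters in both forward and backward, as the FSDP description earlier states, and leaves ZeRO-3 out because its broadcast cost depends on the broadcast algorithm used.

```python
def ring_volume(num_params: int, world_size: int, op: str) -> float:
    """Elements each rank sends for one collective over num_params elements (ring algorithms)."""
    n = world_size
    if op in ("all_gather", "reduce_scatter"):
        return (n - 1) / n * num_params
    if op == "all_reduce":  # reduce-scatter followed by all-gather
        return 2 * (n - 1) / n * num_params
    raise ValueError(f"unknown collective: {op}")


def per_step_volume(strategy: str, num_params: int, world_size: int) -> float:
    if strategy == "ddp":
        return ring_volume(num_params, world_size, "all_reduce")
    if strategy == "fsdp":
        # forward all-gather + backward all-gather + gradient reduce-scatter
        return (2 * ring_volume(num_params, world_size, "all_gather")
                + ring_volume(num_params, world_size, "reduce_scatter"))
    raise ValueError(f"unknown strategy: {strategy}")
```

For P = 1000 and N = 4 this gives 1500 elements per rank for DDP and 2250 for FSDP: under these assumptions FSDP moves about 1.5× the data of DDP per step, in exchange for dividing parameter, gradient, and optimizer memory by N.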
To add support for new PyTorch modules:
```python
# 1. Implement the module wrapper
#    (`base_module` stands for the strategy's base-module package and
#    `MyCustomModule` for your new layer type; both are placeholders)
class MyCustomModule(base_module.MyCustomModule):
    def forward_callback(self, ctx, *args):
        # Custom forward logic with parameter sync
        pass

    def backward_callback(self, ctx, grad_output):
        # Custom backward logic with gradient handling
        pass

# 2. Register in the strategy wrapper
_supported_modules = {
    nn.MyCustomModule: MyCustomModule,  # original module type -> strategy wrapper
    # ... other modules
}
```
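How a `_supported_modules` mapping gets consumed is not shown above; one plausible pattern is a recursive walk that swaps every registered module type for its strategy wrapper. The sketch uses hypothetical stand-in classes (`Module`, `Linear`, `Wrapped`) instead of `torch.nn` so it runs anywhere; it illustrates the pattern, not Tiny-FSDP's actual wrapping code.

```python
from typing import Callable, Dict, Type


class Module:
    """Minimal stand-in for torch.nn.Module: a named tree of child modules."""

    def __init__(self, **children: "Module") -> None:
        self.children: Dict[str, "Module"] = dict(children)


class Linear(Module):
    pass


class MyCustomModule(Module):  # placeholder for the new layer type
    pass


class Wrapped(Module):
    """Placeholder strategy wrapper that keeps a reference to the original module."""

    def __init__(self, inner: Module) -> None:
        super().__init__()
        self.inner = inner


def swap_supported(module: Module,
                   registry: Dict[Type[Module], Callable[[Module], Module]]) -> Module:
    # Recurse first so nested children are swapped, then replace this node if registered.
    for name, child in list(module.children.items()):
        module.children[name] = swap_supported(child, registry)
    # Exact-type lookup: subclasses of a registered type are not matched automatically.
    factory = registry.get(type(module))
    return factory(module) if factory else module
```

For example, `swap_supported(Module(fc=Linear(), custom=MyCustomModule()), {Linear: Wrapped, MyCustomModule: Wrapped})` returns the same container with both children replaced by `Wrapped` instances.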
```python
# Enable gradient checkpointing
model.gradient_checkpointing = True
# Tune communication overlap
model.enable_async_communication = True
# Configure precision
model.half()  # Use FP16 for memory efficiency
```
Tested on GPT-2 (117M parameters) across 2 RTX 4090s:
Strategy | Memory/GPU | Training Speed | Convergence |
---|---|---|---|
DDP | ~2.1GB | 4.9 it/s | Baseline |
ZeRO-3 | ~1.8GB | 4.2 it/s | Same as DDP |
FSDP | ~1.8GB | 4.5 it/s | Same as DDP |
Results may vary based on model architecture, batch size, and hardware configuration.
We welcome contributions! Please see our contributing guidelines:
- Fork the repository
- Create a feature branch
- Add tests for new functionality
- Ensure all tests pass
- Submit a pull request
```bash
pip install -e .
pip install pytest black flake8
pytest tests/
```
This implementation is designed for learning. Key educational features:
- Clear Code: Well-commented, readable implementations
- Minimal Dependencies: Focus on core concepts without complexity
- Comparative Analysis: Easy to compare different strategies
- Documentation: Comprehensive docs and examples
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.