Labels: bug, checkpointing, needs triage, ver: 2.5.x
Bug description
When running a torch.compile-d model with ModelParallelStrategy, saving a non-distributed checkpoint fails.
The failure comes from a mismatch in FQN paths between get_optimizer_state_dict and rekey_optim_state_dict: get_optimizer_state_dict replaces the _orig_mod segment with '' in the tensor paths, while rekey_optim_state_dict uses _get_param_to_fqn, which does not.
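For context, the extra path segment comes from torch.compile wrapping the submodule in an OptimizedModule. The snippet below is a minimal standalone illustration (independent of Lightning; the module/attribute names are made up for the example):

import torch
import torch.nn as nn

# torch.compile wraps the submodule, so parameter FQNs gain an "_orig_mod." segment.
holder = nn.Module()
holder.model = torch.compile(nn.Sequential(nn.Linear(2, 2)))
print([name for name, _ in holder.named_parameters()])
# ['model._orig_mod.0.weight', 'model._orig_mod.0.bias']
# get_optimizer_state_dict strips "_orig_mod." (yielding 'model.0.weight'),
# while _get_param_to_fqn keeps it, hence the KeyError in the traceback below.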
Monkey-patching _get_param_to_fqn to strip the _orig_mod string from the paths works as a band-aid solution (a rough sketch follows below).
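For reference, a rough sketch of that band-aid, assuming _get_param_to_fqn is the module-level helper in torch/distributed/fsdp/fully_sharded_data_parallel.py that the traceback points at and that it returns a parameter-to-FQN mapping (both are assumptions; it is a private API and may change between torch releases):

import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_module

# Apply before trainer.fit()/save_checkpoint(). Wraps the private helper so the FQNs
# it returns match the "_orig_mod."-free keys produced by get_optimizer_state_dict.
_original_get_param_to_fqn = fsdp_module._get_param_to_fqn

def _get_param_to_fqn_without_compile_prefix(model):
    return {
        param: fqn.replace("_orig_mod.", "")
        for param, fqn in _original_get_param_to_fqn(model).items()
    }

fsdp_module._get_param_to_fqn = _get_param_to_fqn_without_compile_prefix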
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
"""
Simple Lightning module compiled with torch.compile and trained with ModelParallelStrategy.
This script demonstrates:
- Using torch.compile with a Lightning module
- Training with ModelParallelStrategy (_save_distributed_checkpoint=False)
- Training on random data for 2 steps
- Saving the model
"""
from __future__ import annotations
import torch
import torch.nn as nn
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.utils.data import DataLoader, TensorDataset
class SimpleCompiledModule(LightningModule):
"""Simple neural network module that will be compiled with torch.compile."""
def __init__(self, input_size: int = 128, hidden_size: int = 256, output_size: int = 64):
super().__init__()
self.model = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, output_size),
)
self.loss_fn = nn.MSELoss()
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.loss_fn(y_hat, y)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
def configure_model(self):
"""Configure the model with torch.compile."""
if self.model is not None:
# Compile the model for better performance
self.model = torch.compile(self.model)
def create_random_dataloader(batch_size: int = 32, num_batches: int = 2) -> DataLoader:
"""Create a dataloader with random data."""
# Generate random input and target data
input_size = 128
output_size = 64
total_samples = batch_size * num_batches
x = torch.randn(total_samples, input_size)
y = torch.randn(total_samples, output_size)
dataset = TensorDataset(x, y)
return DataLoader(dataset, batch_size=batch_size, shuffle=False)
def main():
"""Main function to train and save the compiled model."""
print("=" * 80)
print("Training Simple Compiled Model with ModelParallelStrategy")
print("=" * 80)
# Create the Lightning module
model = SimpleCompiledModule()
print("\n✓ Created SimpleCompiledModule")
# Create random data loader
train_loader = create_random_dataloader(batch_size=32, num_batches=2)
print("✓ Created random data loader (2 batches)")
# Create ModelParallelStrategy with _save_distributed_checkpoint=False
strategy = ModelParallelStrategy(
data_parallel_size=1,
tensor_parallel_size=1,
save_distributed_checkpoint=False, # This is the key parameter
)
print("✓ Created ModelParallelStrategy with save_distributed_checkpoint=False")
# Create trainer with the strategy
trainer = Trainer(
max_steps=2, # Train for exactly 2 steps
accelerator="auto",
devices=1,
strategy=strategy,
enable_checkpointing=True,
default_root_dir="./checkpoints",
logger=False, # Disable logger for simplicity
enable_progress_bar=True,
enable_model_summary=True,
)
print("✓ Created Trainer (max_steps=2)")
# Train the model
print("\nStarting training...")
trainer.fit(model, train_loader)
print("✓ Training completed (2 steps)")
# Save the model
save_path = "./checkpoints/compiled_model.ckpt"
trainer.save_checkpoint(save_path)
print(f"\n✓ Model saved to: {save_path}")
print("\n" + "=" * 80)
print("Training and saving completed successfully!")
print("=" * 80)
# Verify the checkpoint was created
import os
if os.path.exists(save_path):
size_mb = os.path.getsize(save_path) / (1024 * 1024)
print(f"\nCheckpoint file size: {size_mb:.2f} MB")
else:
print("\n⚠ Warning: Checkpoint file not found!")
if __name__ == "__main__":
main()Error messages and logs
Error message:
[rank0]: File "/home/shedko/test_compiled_model_parallel.py", line 157, in <module>
[rank0]: main()
[rank0]: File "/home/shedko/test_compiled_model_parallel.py", line 134, in main
[rank0]: trainer.fit(model, train_loader)
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 560, in fit
[rank0]: call._call_and_handle_interrupt(
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
[rank0]: return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]: return function(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 598, in _fit_impl
[rank0]: self._run(model, ckpt_path=ckpt_path)
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1011, in _run
[rank0]: results = self._run_stage()
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1055, in _run_stage
[rank0]: self.fit_loop.run()
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 217, in run
[rank0]: self.on_advance_end()
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 473, in on_advance_end
[rank0]: call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 228, in _call_callback_hooks
[rank0]: fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 380, in on_train_epoch_end
[rank0]: self._save_topk_checkpoint(trainer, monitor_candidates)
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 470, in _save_topk_checkpoint
[rank0]: self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 812, in _save_none_monitor_checkpoint
[rank0]: self._save_checkpoint(trainer, filepath)
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 473, in _save_checkpoint
[rank0]: trainer.save_checkpoint(filepath, self.save_weights_only)
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1395, in save_checkpoint
[rank0]: checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 456, in dump_checkpoint
[rank0]: optimizer_state = trainer.strategy.optimizer_state(optimizer)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/strategies/model_parallel.py", line 290, in optimizer_state
[rank0]: state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1786, in rekey_optim_state_dict
[rank0]: new_osd["state"] = {
[rank0]: ^
[rank0]: File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1787, in <dictcomp>
[rank0]: param_name_to_param_id[param_name]: param_state
[rank0]: ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
[rank0]: KeyError: 'model.0.weight'
Environment
Current environment
- CUDA:
    - GPU:
        - NVIDIA A10G
        - NVIDIA A10G
        - NVIDIA A10G
        - NVIDIA A10G
    - available: True
    - version: 12.8
- Lightning:
    - amzn-greenland-torchx-launcher: 1.0.21
    - fast-pytorch-kmeans: 0.2.2
    - lightning: 2.5.6
    - lightning-utilities: 0.15.2
    - lion-pytorch: 0.2.3
    - open-clip-torch: 3.2.0
    - pytorch-lightning: 2.5.6
    - pytorch-memlab: 0.3.0
    - pytorch-optimizer: 3.8.2
    - s3torchconnector: 1.4.3
    - s3torchconnectorclient: 1.4.3
    - torch: 2.8.0+cu128
    - torchaudio: 2.8.0+cu128
    - torchelastic: 0.2.2
    - torchmetrics: 1.8.2
    - torchvision: 0.23.0+cu128
    - torchx-nightly: 2025.11.6
- System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.11.13
    - release: 6.1.158-178.288.amzn2023.x86_64
    - version: #1 SMP PREEMPT_DYNAMIC Mon Nov 3 18:38:36 UTC 2025
More info
No response