
ModelParallelStrategy fails with non-distributed checkpoint. #21357

@AShedko

Bug description

When training a model compiled with torch.compile under ModelParallelStrategy, saving a non-distributed checkpoint fails.

The failure comes from a mismatch in FQN (fully qualified name) paths between get_optimizer_state_dict and rekey_optim_state_dict: get_optimizer_state_dict replaces the _orig_mod component of tensor paths with '', while rekey_optim_state_dict builds its lookup table with _get_param_to_fqn, which does not.
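
For context, torch.compile wraps a module in an OptimizedModule that keeps the original module under the _orig_mod attribute, so every parameter FQN gains an _orig_mod component. A minimal illustration:

import torch
import torch.nn as nn

# torch.compile returns an OptimizedModule that stores the original
# module as `_orig_mod`, so parameter FQNs gain that extra component.
net = torch.compile(nn.Sequential(nn.Linear(2, 2)))
print([name for name, _ in net.named_parameters()])
# -> ['_orig_mod.0.weight', '_orig_mod.0.bias']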

Monkey-patching _get_param_to_fqn to strip the _orig_mod string from the paths works as a band-aid fix.
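
A minimal sketch of that band-aid, assuming rekey_optim_state_dict resolves _get_param_to_fqn through the fully_sharded_data_parallel module namespace (an assumption about torch internals, not a supported API):

import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_mod

# Assumption: rekey_optim_state_dict looks up _get_param_to_fqn in this
# module's namespace, so rebinding the name here is enough to patch it.
_orig_get_param_to_fqn = fsdp_mod._get_param_to_fqn

def _patched_get_param_to_fqn(model):
    # Strip the "_orig_mod." component that torch.compile inserts so the
    # FQNs line up with the ones get_optimizer_state_dict produces.
    return {
        param: fqn.replace("_orig_mod.", "")
        for param, fqn in _orig_get_param_to_fqn(model).items()
    }

fsdp_mod._get_param_to_fqn = _patched_get_param_to_fqn

Applied before trainer.fit, this lets the repro below save the checkpoint, but the proper fix is to make the two FQN views consistent inside the strategy.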

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

"""
Simple Lightning module compiled with torch.compile and trained with ModelParallelStrategy.
This script demonstrates:
- Using torch.compile with a Lightning module
- Training with ModelParallelStrategy (save_distributed_checkpoint=False)
- Training on random data for 2 steps
- Saving the model
"""

from __future__ import annotations

import os

import torch
import torch.nn as nn
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.utils.data import DataLoader, TensorDataset

class SimpleCompiledModule(LightningModule):
    """Simple neural network module that will be compiled with torch.compile."""

    def __init__(self, input_size: int = 128, hidden_size: int = 256, output_size: int = 64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
        )
        self.loss_fn = nn.MSELoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def configure_model(self):
        """Configure the model with torch.compile."""
        if self.model is not None:
            # Compile the model for better performance
            self.model = torch.compile(self.model)


def create_random_dataloader(batch_size: int = 32, num_batches: int = 2) -> DataLoader:
    """Create a dataloader with random data."""
    # Generate random input and target data
    input_size = 128
    output_size = 64
    total_samples = batch_size * num_batches

    x = torch.randn(total_samples, input_size)
    y = torch.randn(total_samples, output_size)

    dataset = TensorDataset(x, y)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)


def main():
    """Main function to train and save the compiled model."""
    print("=" * 80)
    print("Training Simple Compiled Model with ModelParallelStrategy")
    print("=" * 80)

    # Create the Lightning module
    model = SimpleCompiledModule()
    print("\n✓ Created SimpleCompiledModule")

    # Create random data loader
    train_loader = create_random_dataloader(batch_size=32, num_batches=2)
    print("✓ Created random data loader (2 batches)")

    # Create ModelParallelStrategy with save_distributed_checkpoint=False
    strategy = ModelParallelStrategy(
        data_parallel_size=1,
        tensor_parallel_size=1,
        save_distributed_checkpoint=False,  # This is the key parameter
    )
    print("✓ Created ModelParallelStrategy with save_distributed_checkpoint=False")

    # Create trainer with the strategy
    trainer = Trainer(
        max_steps=2,  # Train for exactly 2 steps
        accelerator="auto",
        devices=1,
        strategy=strategy,
        enable_checkpointing=True,
        default_root_dir="./checkpoints",
        logger=False,  # Disable logger for simplicity
        enable_progress_bar=True,
        enable_model_summary=True,
    )
    print("✓ Created Trainer (max_steps=2)")

    # Train the model
    print("\nStarting training...")
    trainer.fit(model, train_loader)
    print("✓ Training completed (2 steps)")

    # Save the model
    save_path = "./checkpoints/compiled_model.ckpt"
    trainer.save_checkpoint(save_path)
    print(f"\n✓ Model saved to: {save_path}")

    print("\n" + "=" * 80)
    print("Training and saving completed successfully!")
    print("=" * 80)

    # Verify the checkpoint was created
    if os.path.exists(save_path):
        size_mb = os.path.getsize(save_path) / (1024 * 1024)
        print(f"\nCheckpoint file size: {size_mb:.2f} MB")
    else:
        print("\n⚠ Warning: Checkpoint file not found!")


if __name__ == "__main__":
    main()

Error messages and logs

Error message:

[rank0]:   File "/home/shedko/test_compiled_model_parallel.py", line 157, in <module>
[rank0]:     main()
[rank0]:   File "/home/shedko/test_compiled_model_parallel.py", line 134, in main
[rank0]:     trainer.fit(model, train_loader)
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 560, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 48, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 598, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1011, in _run
[rank0]:     results = self._run_stage()
[rank0]:               ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1055, in _run_stage
[rank0]:     self.fit_loop.run()
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 217, in run
[rank0]:     self.on_advance_end()
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 473, in on_advance_end
[rank0]:     call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 228, in _call_callback_hooks
[rank0]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 380, in on_train_epoch_end
[rank0]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 470, in _save_topk_checkpoint
[rank0]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 812, in _save_none_monitor_checkpoint
[rank0]:     self._save_checkpoint(trainer, filepath)
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 473, in _save_checkpoint
[rank0]:     trainer.save_checkpoint(filepath, self.save_weights_only)
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1395, in save_checkpoint
[rank0]:     checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 456, in dump_checkpoint
[rank0]:     optimizer_state = trainer.strategy.optimizer_state(optimizer)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/lightning/pytorch/strategies/model_parallel.py", line 290, in optimizer_state
[rank0]:     state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1786, in rekey_optim_state_dict
[rank0]:     new_osd["state"] = {
[rank0]:                        ^
[rank0]:   File "/home/shedko/miniconda3/envs/prod_env/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1787, in <dictcomp>
[rank0]:     param_name_to_param_id[param_name]: param_state
[rank0]:     ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
[rank0]: KeyError: 'model.0.weight'
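
The KeyError matches the mismatch described above: the optimizer state produced by get_optimizer_state_dict is keyed by the stripped path model.0.weight, while the param_name_to_param_id table that rekey_optim_state_dict builds via _get_param_to_fqn presumably still carries the compiled path model._orig_mod.0.weight, so the lookup fails.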

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A10G
    - NVIDIA A10G
    - NVIDIA A10G
    - NVIDIA A10G
    - available: True
    - version: 12.8
  • Lightning:
    - amzn-greenland-torchx-launcher: 1.0.21
    - fast-pytorch-kmeans: 0.2.2
    - lightning: 2.5.6
    - lightning-utilities: 0.15.2
    - lion-pytorch: 0.2.3
    - open-clip-torch: 3.2.0
    - pytorch-lightning: 2.5.6
    - pytorch-memlab: 0.3.0
    - pytorch-optimizer: 3.8.2
    - s3torchconnector: 1.4.3
    - s3torchconnectorclient: 1.4.3
    - torch: 2.8.0+cu128
    - torchaudio: 2.8.0+cu128
    - torchelastic: 0.2.2
    - torchmetrics: 1.8.2
    - torchvision: 0.23.0+cu128
    - torchx-nightly: 2025.11.6
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.11.13
    - release: 6.1.158-178.288.amzn2023.x86_64
    - version: #1 SMP PREEMPT_DYNAMIC Mon Nov 3 18:38:36 UTC 2025

More info

No response

cc @ethanwharris @lantiga

Metadata

    Labels

    bug (Something isn't working), checkpointing (Related to checkpointing), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
