Conversation

vidushi8

Problem: Checkpoint saving fails for the torch_dist format because of a missing serialization_format argument, which was introduced recently in upstream PyTorch.

Error:

[rank4]:   File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/strategies/torch.py", line 706, in finalize_fn
[rank4]:     save_state_dict_async_finalize(*save_state_dict_ret)
[rank4]:   File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/strategies/state_dict_saver.py", line 144, in save_state_dict_async_finalize
[rank4]:     write_results = storage_writer.retrieve_write_results()
[rank4]:   File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/strategies/filesystem_async.py", line 337, in retrieve_write_results
[rank4]:     raise RuntimeError(f'Worker failure: {write_results_or_exc}') from write_results_or_exc
[rank4]: RuntimeError: Worker failure: _write_item() missing 1 required positional argument: 'serialization_format'
[rank5]: TypeError: _write_item() missing 1 required positional argument: 'serialization_format'

Solution: Pass the serialization_format argument to _write_item().

@vidushi8 vidushi8 marked this pull request as ready for review June 12, 2025 01:20
@vidushi8 vidushi8 requested a review from wenchenvincent June 16, 2025 19:23
 # This PR https://github.com/pytorch/pytorch/pull/143359 introduced a breaking change to saving checkpoints
 # in torch_dist format. This is a workaround to fix the issue.
-from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms
+from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms, SerializationFormat
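One way to keep this import from breaking older PyTorch builds (a sketch, not the exact PR diff; the `HAVE_SERIALIZATION_FORMAT` flag is an illustrative name) is to guard it with try/except:

```python
# Sketch: guard the new import so the module still loads on PyTorch < 2.8,
# where SerializationFormat does not exist yet.
try:
    from torch.distributed.checkpoint.filesystem import SerializationFormat
    HAVE_SERIALIZATION_FORMAT = True
except ImportError:
    # Older PyTorch (or torch not installed): fall back gracefully.
    SerializationFormat = None
    HAVE_SERIALIZATION_FORMAT = False
```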
Collaborator

What's the pytorch version that requires this change? And would this change break the previous pytorch version?

Author

  1. This change was introduced in PyTorch 2.8.0 dev by this PR: Add a param for save format in Storage Writer pytorch/pytorch#150025
  2. Yes, since the new SerializationFormat param was introduced starting with 2.8, this change may break prior versions. We can wrap it in an if block that checks the torch version and only pass the SerializationFormat param when torch.version >= 2.8.

Author

@wenchenvincent I have updated the logic to add an if condition that checks for the presence of the serialization_format param. It should not break previous PyTorch versions now. Can you review again?

# See LICENSE for license information.
#################################################################################
-set -e
+# set -e
Collaborator

@vidushi8 I think there is a better way to do this. You mentioned that there is a multi-node detection command that would fail on a single node. You can change it to the following:

<command to detect multi-node> || true

This way, it won't fail and break the run.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the multi-node NCCL_IB_HCA logic under the NNODES>1 condition.
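Combining both suggestions, a minimal sketch of the pattern (the `ibv_devinfo` probe and variable names here are illustrative, not the actual script contents):

```shell
#!/bin/sh
# The HCA probe runs only on multi-node jobs, and `|| true` keeps a
# failing probe from aborting the script under `set -e`.
set -e
NNODES="${NNODES:-1}"

if [ "$NNODES" -gt 1 ]; then
    # `ibv_devinfo -l` lists InfiniBand devices; it may fail on hosts
    # without IB hardware, hence the `|| true` guard.
    HCA_LIST="$(ibv_devinfo -l 2>/dev/null | tr '\n' ',')" || true
    [ -n "$HCA_LIST" ] && export NCCL_IB_HCA="$HCA_LIST"
fi
echo "NNODES=$NNODES"
```

On a single node the probe never runs, and on multi-node hosts a failed probe simply leaves NCCL_IB_HCA unset instead of killing the run.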

@zstreet87
Collaborator

This looks good to me. @wenchenvincent, since you've looked at everything before, I'll wait to see if you approve of the latest changes.


Marking as stale. No activity in 60 days.

@github-actions github-actions bot added the stale label Sep 26, 2025