-
Notifications
You must be signed in to change notification settings - Fork 29
[Megatron-LM] Fix torch_dist ckpt format saving #83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: rocm_dev
Are you sure you want to change the base?
[Megatron-LM] Fix torch_dist ckpt format saving #83
Conversation
# This PR https://github.com/pytorch/pytorch/pull/143359 introduced breaking change to saving checkpoints | ||
# in torch_dist format. This is a workaround to fix the issue. | ||
from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms | ||
from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms, SerializationFormat |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the pytorch version that requires this change? And would this change break the previous pytorch version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- This change has been introduced in pytorch 2.8.0 dev by this PR Add a param for save format in Storage Writer pytorch/pytorch#150025
- Yes, I think the introduced new param
SerializationFormat
was introduced starting latest 2.8 version, it may break for prior versions. We can add this change in an if block, which checks for torch version. Only passSerializationFormat
param if torch.version >= 2.8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wenchenvincent I have updated the logic to add an if condition to check for presence of serialization_format param. It should not break the previous pytorch version now. Can you review again?
# See LICENSE for license information. | ||
################################################################################# | ||
set -e | ||
# set -e |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@vidushi8 I think there is a better way to do this. You mentioned that there is a command to detect for multi-node that would fail on single node. You can change it to the following:
<command to detect multi-node> || true
This way, it won't fail and break the run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved the multi-node NCCL_IB_HCA logic under the if condition NNODES>1
this looks good to me. @wenchenvincent since you've looked at everything before, I'll wait to see if you approve of the latest changes. |
Marking as stale. No activity in 60 days. |
Problem: Checkpoint saving is erroring out for torch_dist format because of a missing serialization_format argument, which was introduced recently in upstream pytorch.
Error:
[rank4]: File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/strategies/torch.py", line 706, in finalize_fn [rank4]: save_state_dict_async_finalize(*save_state_dict_ret) [rank4]: File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/strategies/state_dict_saver.py", line 144, in save_state_dict_async_finalize [rank4]: write_results = storage_writer.retrieve_write_results() [rank4]: File "/workspace/Megatron-LM/megatron/core/dist_checkpointing/strategies/filesystem_async.py", line 337, in retrieve_write_results [rank4]: raise RuntimeError(f'Worker failure: {write_results_or_exc}') from write_results_or_exc [rank4]: RuntimeError: Worker failure: _write_item() missing 1 required positional argument: 'serialization_format' [rank5]: TypeError: _write_item() missing 1 required positional argument: 'serialization_format'
Solution: Add serialization_format argument in
_write_item()