Added averaging script for torch dist #10834
Conversation
Signed-off-by: jubick1337 <[email protected]>
0d75d95 to 7b7eeef (Compare)
import shutil
import tarfile

import numpy as np
Check notice (Code scanning / CodeQL): Unused import
import tarfile

import numpy as np
import tensorstore  # need to import it for bf16 support
Check notice (Code scanning / CodeQL): Unused import
import tensorstore  # need to import it for bf16 support
import torch
import torch.distributed as dist
import zarr
Check notice (Code scanning / CodeQL): Unused import
sharded_state_dict=averaged_state_dict,
checkpoint_dir=ckpt_path,
sharded_strategy=TorchDistSaveShardedStrategy(backend="torch_dist", version=1),
validate_access_integrity=False,
why is it False?
I don't know, it has worked perfectly so far. Should I change it to True?
Yes please, it will help avoid potential issues in the future.
save(
    sharded_state_dict=averaged_state_dict,
    checkpoint_dir=ckpt_path,
    sharded_strategy=TorchDistSaveShardedStrategy(backend="torch_dist", version=1),
You can just pass sharded_strategy=torch_dist and consequently make it configurable in the CLI (allowing e.g. 'zarr' if someone wants 'zarr' for some reason; it's deprecated, but people sometimes fall back to it).
There's a script for the zarr format here. I guess they could be merged into one, but so far all of the formats have had separate scripts.
I think having one script instead of two would be better going forward, but I'm not a user of those scripts, so maybe there are details I'm missing.
We don't have to remove the other one; we could just add an option to this one for a smooth transition.
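For illustration, a minimal sketch of the configurable-strategy suggestion above. The flag name is hypothetical, and it assumes that megatron.core's dist_checkpointing.save accepts a (backend, version) tuple and resolves it to the default strategy for that backend, as recent megatron.core versions do; averaged_state_dict and ckpt_path are the variables already used by the script.

import argparse

from megatron.core.dist_checkpointing import save

# Hypothetical flag name; the actual script may expose this differently.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--checkpoint_format",
    choices=["torch_dist", "zarr"],
    default="torch_dist",
    help="Sharded checkpoint backend ('zarr' is deprecated but still accepted).",
)
args = parser.parse_args()

# averaged_state_dict and ckpt_path come from the script's existing averaging code.
save(
    sharded_state_dict=averaged_state_dict,
    checkpoint_dir=ckpt_path,
    # Assumption: a (backend, version) tuple lets dist_checkpointing pick the
    # default save strategy for the chosen backend instead of hard-coding one class.
    sharded_strategy=(args.checkpoint_format, 1),
    validate_access_integrity=True,  # per the review discussion above
)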
files = glob.glob(f"{args.untarred_nemo_folder}/*.model") + glob.glob(
    f"{args.untarred_nemo_folder}/*.vocab.json"
)
logging.info(f"Copying other files: {files}")
Do we want to copy the original common.pt (which contains all the PTL training-related data like params, loops state, etc.)?
This pattern shouldn't catch .pt files but has to copy the tokenizer. I'm not sure if common.pt is needed.
If the averaged ckpts are used for inference/fine-tuning then probably not 👍
This PR is stale because it has been open for 14 days with no activity. Remove the stale label, comment, or update the PR, or it will be closed in 7 days.
This PR was closed because it has been inactive for 7 days since being marked as stale.
What does this PR do?
Adds a script that runs checkpoint averaging for the torch distributed (torch_dist) checkpoint format.
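Conceptually, checkpoint averaging is an element-wise mean of the same parameter across N checkpoints. The sketch below illustrates only that arithmetic with plain PyTorch state dicts; it ignores the sharded torch_dist loading/saving that the script actually handles, and the paths are placeholders.

import torch

# Placeholder checkpoint paths; the real script operates on torch_dist directories.
ckpt_paths = ["ckpt_step_1000.pt", "ckpt_step_2000.pt", "ckpt_step_3000.pt"]

avg_state_dict = None
for path in ckpt_paths:
    state_dict = torch.load(path, map_location="cpu")
    if avg_state_dict is None:
        # Accumulate in float32 to avoid precision loss with bf16/fp16 weights.
        avg_state_dict = {k: v.clone().float() for k, v in state_dict.items()}
    else:
        for k, v in state_dict.items():
            avg_state_dict[k] += v.float()

# Divide the running sums by the number of checkpoints to get the mean.
for k in avg_state_dict:
    avg_state_dict[k] /= len(ckpt_paths)

torch.save(avg_state_dict, "averaged_checkpoint.pt")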
Collection: [Note which collection this PR will affect]
Changelog
Usage
# Add a code snippet demonstrating how to use this
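# A hypothetical invocation; the script path and any flags other than
# --untarred_nemo_folder (which appears in the diff above) are illustrative only.
python scripts/checkpoint_averaging/distributed_checkpoint_averaging.py \
    --untarred_nemo_folder /path/to/untarred_nemo_checkpoint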
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The contributor guidelines list the specific people who can review PRs for various areas.
Additional Information