Fix model saving bug post training with tensor parallel in Accelerate #36434
Conversation
@kmehant Wondering what your thoughts are?
cc @ArthurZucker who's also doing a big TP refactor right now!
@ArthurZucker @kmehant Seems like I'm failing a couple of tests, but I'm struggling to find the root cause. Wondering if you two can kindly take a look?
Same problem with me. T_T #36433
gathered_state_dict = {}
for key, value in state_dict.items():
    if hasattr(value, "_local_tensor"):
        gathered_state_dict[key] = value.to_local().cpu()
Note: we might want to do something closer to https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/distributed/tensor/_api.py#L572
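For reference, a minimal sketch of what that linked helper roughly amounts to (my reading of it, not code from this PR): each sharded DTensor is redistributed to a replicated placement and its local copy is taken, which is essentially what `DTensor.full_tensor()` does under the hood. It assumes a PyTorch build where `torch.distributed.tensor` is public, and the function name is made up for illustration.

```python
import torch
from torch.distributed.tensor import DTensor, Replicate


def gather_dtensor(value: torch.Tensor) -> torch.Tensor:
    """Materialize a (possibly sharded) DTensor into a plain local tensor."""
    if isinstance(value, DTensor):
        # Redistribute to Replicate() on every mesh dimension so each rank
        # holds the full tensor, then take the local copy.
        replicated = value.redistribute(
            device_mesh=value.device_mesh,
            placements=[Replicate()] * value.device_mesh.ndim,
        )
        return replicated.to_local()
    return value
```

Note that the redistribute is itself a collective, so the same "must run on all ranks" caveat discussed below applies here too.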
yeah, using full_tensor will be better I think.
@bursteratom and I found that full_tensor would hang here, not 100% sure why, but we could investigate more if manually redistributing doesn't work.
@SalmanMohammadi I wonder if it's related: pytorch/pytorch#115310
@muellerzr Should this be in transformers or is the preference that this sort of unsharding is in accelerate?
@winglian We have (will have) similar stuff in Accelerate for FSDP2, so if we want to support both TP + FSDP2 on the Accelerate side it'd possibly need to be in both places. Though I remember full_tensor() working for me there, I might take a look at this too.
value.to_local().cpu()
This would only return the rank-local shard of the tensor if the DTensor has a Shard placement, which is highly likely for TP. Wouldn't that mean the state dicts are now different on each rank? Isn't that a problem?
Yes, this is correct. .to_local() only returns the local part of the tensor if it was sharded (which it most likely was, since we're talking about TP), so each process ends up with only its own part. A possible reason why this hangs is that, IIRC, full_tensor() requires communication, and here only the main process is running it.
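To make that concrete, here is a minimal sketch (not the code in this PR) of the calling pattern that avoids the hang: the collective runs on every rank, and only the main process writes the result. `model` and `save_path` are placeholders, and an initialized process group is assumed.

```python
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor

# `model` and `save_path` are placeholders defined elsewhere.
state_dict = model.state_dict()

full_state_dict = {}
for key, value in state_dict.items():
    if isinstance(value, DTensor):
        # full_tensor() is a collective: every rank must reach this call,
        # otherwise the ranks that do call it wait forever.
        value = value.full_tensor()
    full_state_dict[key] = value.cpu()

# Only one process writes the gathered weights to disk.
if dist.get_rank() == 0:
    torch.save(full_state_dict, save_path)
```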
cc @muellerzr @SunMarc for accelerate as well
Thanks! Please add a test also.
Would using full_tensor be a better approach?
@machinelearningprodigy I initially used full_tensor, but it would hang during saving (see the discussion above).
cc @kwen2501 if you have any idea
I will merge this and let you do a follow-up PR when you try to see whether TP + FSDPv2 works in transformers @S1ro1!
Thank you so much for taking a look at this @SunMarc!!!
IMO we should be careful here and make sure we can save without exploding memory.
gathered_state_dict = {}
for key, value in state_dict.items():
    if hasattr(value, "_local_tensor"):
        gathered_state_dict[key] = value.to_local().cpu()
Memory will explode, no? This should happen in the function that writes the files, to make sure you save bit by bit.
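One way to read that suggestion (a hedged sketch, not what this PR or transformers actually implements): gather each DTensor lazily so only one fully materialized tensor is alive at a time, and let the file writer consume the entries one by one instead of building a second full copy of the model in memory.

```python
from typing import Dict, Iterator, Tuple

import torch
from torch.distributed.tensor import DTensor


def iter_gathered_tensors(
    state_dict: Dict[str, torch.Tensor],
) -> Iterator[Tuple[str, torch.Tensor]]:
    """Yield (name, full CPU tensor) pairs one at a time.

    Because gathering happens lazily, at most one fully materialized tensor
    exists at any moment, rather than a duplicate of the whole state dict.
    """
    for key, value in state_dict.items():
        if isinstance(value, DTensor):
            value = value.full_tensor()  # collective: must run on every rank
        yield key, value.cpu()
```

The function that writes the files could then iterate over this generator and serialize each tensor as it arrives.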
re @S1ro1 might be good to fix this properly somehow
Oh, this should actually be fixed by #37919 already. Should probably close then.
SG!
What does this PR do?
Currently, attempting to save a model after training with tensor parallel in Accelerate gives RuntimeError: Attempted to access the data pointer on an invalid python storage. This is because the state dict is not properly gathered from the sharded tensors beforehand. This PR fixes the error, allowing the model to be saved successfully. Big thank you to @SalmanMohammadi for the discussion!
Fixes #34194 (comment), #36436
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.