Skip to content

Conversation

bursteratom
Copy link
Contributor

@bursteratom bursteratom commented Feb 26, 2025

What does this PR do?

Currently, attempting to save model after training with tensor parallel in Accelerate gives the RuntimeError: Attempted to access the data pointer on an invalid python storage, this is due to the state dict not properly gathered from the sharded tensors beforehand. This PR fixes the error, allowing for successful model saving.

Big thank you to @SalmanMohammadi for the discussion!

tp_save_error

Fixes # (issue)
#34194 (comment)
#36436

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@bursteratom
Copy link
Contributor Author

@kmehant Wondering what your thoughts are?

@Rocketknight1
Copy link
Member

cc @ArthurZucker who's also doing a big TP refactor right now!

@bursteratom bursteratom force-pushed the tp-model_saving-fix branch 3 times, most recently from d4e4907 to 4460137 Compare February 28, 2025 17:31
@bursteratom
Copy link
Contributor Author

bursteratom commented Feb 28, 2025

@ArthurZucker @kmehant Seems like I'm failing a couple of tests, but I'm struggling to find the root cause. Wondering if you two can kindly take a look?

@bursteratom bursteratom changed the title Fix model saving bug post training with tensor parallel Fix model saving bug post training with tensor parallel in Accelerate Feb 28, 2025
@ShaohonChen
Copy link
Contributor

Same problem with me. T_T #36433

gathered_state_dict = {}
for key, value in state_dict.items():
if hasattr(value, "_local_tensor"):
gathered_state_dict[key] = value.to_local().cpu()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah using full_tensors will be better I think.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bursteratom and I found that full_tensor would hang here, not 100% sure why, but we could investigate more if manually redistributing doesn't work.

Copy link
Contributor Author

@bursteratom bursteratom Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SalmanMohammadi I wonder if it's related: pytorch/pytorch#115310

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@muellerzr Should this be in transformers or is the preference that this sort of unsharding is in accelerate?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@winglian We have (will have) similar stuff in Accelerate for FSDP2, so possibly if we want to support both TP + FSDP2 on Accelerate side it'd need to be on both places. Though I remember full_tensor() working for me there, I might take a look at this too.

Copy link
Contributor

@kmehant kmehant Apr 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

value.to_local().cpu()

This would only return local to the rank shard of the tensor if the DTensor has a Shard placement which is highly likely for TP. Would not that mean the state dicts would be now different on each rank, isn't that a problem?

Copy link
Contributor

@S1ro1 S1ro1 Apr 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is correct. .to_local() only returns the local part of the tensor if it was sharded (most likely was as we're talking about TP), therefore this results for each process to have its own part. Possibility for why this hangs is because iirc full_tensor() requires communication and here only main process is running iirc.

@Rocketknight1
Copy link
Member

cc @muellerzr @SunMarc for accelerate as well

@bursteratom bursteratom force-pushed the tp-model_saving-fix branch from 45866d4 to 809275b Compare March 3, 2025 21:52
Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks ! Please add a test also

@bursteratom bursteratom force-pushed the tp-model_saving-fix branch 2 times, most recently from 3b345fa to 24a6c33 Compare March 4, 2025 14:29
@machinelearningprodigy
Copy link

Would using full_tensors be a better approach?

@bursteratom
Copy link
Contributor Author

bursteratom commented Mar 4, 2025

@machinelearningprodigy I initially used full_tensor() but for some reason it was hanging/incredibly slow, I can do some code tests on my end to figure out why that is the case

@bursteratom bursteratom force-pushed the tp-model_saving-fix branch 2 times, most recently from dedaa12 to 9708c36 Compare March 4, 2025 16:45
@SunMarc
Copy link
Member

SunMarc commented Mar 4, 2025

cc @kwen2501 if you have any idea

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will merge this and I will let you do a follow-up PR when you will try to see with TP + FSDPv2 works in transformers @S1ro1 !

@bursteratom
Copy link
Contributor Author

Thank you so much for taking a look at this @SunMarc !!!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bursteratom bursteratom force-pushed the tp-model_saving-fix branch 5 times, most recently from e569f9a to 9c31402 Compare April 7, 2025 15:14
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should be careful and can save without exploding memory

gathered_state_dict = {}
for key, value in state_dict.items():
if hasattr(value, "_local_tensor"):
gathered_state_dict[key] = value.to_local().cpu()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

memory will explode no? this should happen in the function that write the files to make sure you save bits by bits

@bursteratom bursteratom force-pushed the tp-model_saving-fix branch from 2217e31 to ee271a0 Compare June 19, 2025 17:39
@SunMarc
Copy link
Member

SunMarc commented Jun 20, 2025

re @S1ro1 might be good to fix this properly somehow

@S1ro1
Copy link
Contributor

S1ro1 commented Jun 20, 2025

Oh, this should actually be fixed by #37919 already. Should probably close then.

@SunMarc
Copy link
Member

SunMarc commented Jun 20, 2025

SG !

@SunMarc SunMarc closed this Jun 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.