11 changes: 11 additions & 0 deletions src/transformers/modeling_utils.py
@@ -3662,6 +3662,17 @@ def save_pretrained(
        if self._tp_size is not None:
            state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

            # if using tensor parallel, we need to gather the tensors in the state dict
            gathered_state_dict = {}
            for key, value in state_dict.items():
                if hasattr(value, "_local_tensor"):
                    gathered_state_dict[key] = value.to_local().cpu()
Member:
Yeah, using full_tensor will be better, I think.

Contributor:
@bursteratom and I found that full_tensor would hang here; we're not 100% sure why, but we could investigate more if manually redistributing doesn't work.

bursteratom (Contributor, author), Mar 4, 2025:
@SalmanMohammadi I wonder if it's related: pytorch/pytorch#115310

Contributor:
@muellerzr Should this be in transformers, or is the preference that this sort of unsharding lives in accelerate?

Contributor:
@winglian We have (or will have) similar logic in Accelerate for FSDP2, so if we want to support both TP + FSDP2 on the Accelerate side it would probably need to live in both places. That said, I remember full_tensor() working for me there; I might take a look at this too.

kmehant (Contributor), Apr 2, 2025:
value.to_local().cpu()

This only returns the rank-local shard of the tensor if the DTensor has a Shard placement, which is highly likely for TP. Wouldn't that mean the state dicts now differ across ranks? Isn't that a problem?
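
For illustration only (not part of the PR), a minimal sketch of the behaviour described above, assuming the public torch.distributed.tensor API (torch >= 2.5) and a process group launched with two ranks: with a Shard(0) placement, to_local() returns a different slice on each rank, while full_tensor() materializes the same complete tensor everywhere.

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Shard

mesh = init_device_mesh("cuda", (2,))                  # 2-way tensor-parallel mesh
weight = torch.arange(8.0).reshape(4, 2)
dweight = distribute_tensor(weight, mesh, [Shard(0)])  # row-sharded across the 2 ranks

local = dweight.to_local()    # (2, 2): rank 0 holds rows 0-1, rank 1 holds rows 2-3
full = dweight.full_tensor()  # (4, 2): the identical, fully gathered tensor on every rank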

S1ro1 (Contributor), Apr 8, 2025:
Yes, this is correct. .to_local() only returns the local part of the tensor if it was sharded (which it most likely was, since we're talking about TP), so each process ends up with only its own shard. A possible reason full_tensor() hangs is that, IIRC, it requires collective communication, and here only the main process runs the save.
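
To make the hang concrete, here is a hedged sketch of the two call patterns, assuming the values are DTensors as in the diff: full_tensor() issues an all-gather under the hood, so every rank in the mesh has to reach it; guarding the save with a rank check leaves the other ranks outside the collective and the main process blocks.

import torch.distributed as dist

def gather_state_dict(state_dict):
    # Works: every rank runs this loop, so all of them join the all-gather
    # that full_tensor() performs for each sharded value.
    gathered = {}
    for key, value in state_dict.items():
        if hasattr(value, "_local_tensor"):  # same DTensor duck-typing as the diff
            value = value.full_tensor()      # collective across the TP mesh
        gathered[key] = value.cpu()
    return gathered

# Hangs: only rank 0 enters the collective and the other ranks never show up.
# if dist.get_rank() == 0:
#     state_dict = gather_state_dict(state_dict)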

Collaborator:
Won't memory explode? This should happen in the function that writes the files, so that we save bit by bit.
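
A rough sketch of the "save bit by bit" idea: gather one tensor at a time inside the writer instead of building a second, fully materialized copy of the whole state dict up front. write_shard is a hypothetical stand-in for whatever routine serializes a single tensor to disk.

def save_streaming(state_dict, write_shard):
    for key, value in state_dict.items():
        if hasattr(value, "_local_tensor"):
            value = value.full_tensor()   # gather just this entry (collective)
        write_shard(key, value.cpu())     # hypothetical per-tensor writer
        del value                         # free the full tensor before the next key
    # Peak memory stays around one fully gathered tensor at a time.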

                else:
                    gathered_state_dict[key] = value.cpu()

            del state_dict
            state_dict = gathered_state_dict

        if safe_serialization:
            # TODO: fix safe_serialization for tied weights
            # Safetensors does not allow tensor aliasing.