System Info
Currently, attempting to save the model after training with tensor parallelism raises RuntimeError: Attempted to access the data pointer on an invalid python storage. This happens because the state dict is not gathered from the sharded tensors before saving.
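To illustrate what "gathering the state dict" means here, below is a minimal sketch of the idea. It is not the code from #36434. The helper gather_full_state_dict and the FakeShardedTensor stand-in are hypothetical. The sketch assumes that sharded entries expose a full_tensor() method, as torch.distributed.tensor.DTensor does, and that plain tensors can be saved unchanged.

```python
from typing import Any, Dict


def gather_full_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of state_dict with each sharded entry replaced by its full tensor.

    Hypothetical helper: assumes sharded entries expose full_tensor(), like
    torch.distributed.tensor.DTensor; any other value is passed through as-is.
    """
    gathered: Dict[str, Any] = {}
    for name, value in state_dict.items():
        if hasattr(value, "full_tensor"):
            # With real DTensors this is a collective all-gather, so every rank
            # in the device mesh must execute it or the job will hang.
            gathered[name] = value.full_tensor()
        else:
            gathered[name] = value
    return gathered


class FakeShardedTensor:
    """Hypothetical stand-in for a DTensor, used only to run the helper without torch."""

    def __init__(self, shards):
        self.shards = shards

    def full_tensor(self):
        # Concatenate the per-rank shards, mimicking a gather along dim 0.
        return [x for shard in self.shards for x in shard]


if __name__ == "__main__":
    sd = {"w": FakeShardedTensor([[1, 2], [3, 4]]), "b": [0.5]}
    print(gather_full_state_dict(sd))  # {'w': [1, 2, 3, 4], 'b': [0.5]}
```

With real tensor-parallel weights, the gather would presumably run on model.state_dict() on every rank. Only the main process would then write the result, for example via save_pretrained(output_dir, state_dict=...).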
Fix here: #36434

Reproduction
Train a model with tensor parallelism by passing tp_size >= 2 to the Trainer, and set output_dir to the directory where the model should be saved.
Expected behavior
Model is saved upon completion of training.