
Commit 0208d41

Yard1 authored and elusenji committed
Fix distributed_concat with scalar tensor (huggingface#16963)
* Fix `distributed_concat` with scalar tensor
* Update trainer_pt_utils.py
1 parent a031015 commit 0208d41

File tree

1 file changed: +2, -1

src/transformers/trainer_pt_utils.py

Lines changed: 2 additions & 1 deletion
@@ -159,8 +159,9 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) ->
     try:
         if isinstance(tensor, (tuple, list)):
             return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
+        if len(tensor.shape) <= 0:
+            tensor = tensor[None]
         output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
-        output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
         dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)
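
For context, the failure mode is easy to reproduce outside of torch.distributed: torch.cat rejects zero-dimensional tensors, so a gathered scalar (for example, a reduced loss) has to be promoted to shape (1,) before concatenation. Below is a minimal sketch of the pattern the patch adopts; the variable names are illustrative, not from the repository.

import torch

loss = torch.tensor(3.14)    # 0-d (scalar) tensor, as returned by e.g. loss.mean()
assert len(loss.shape) == 0  # shape is torch.Size([]), i.e. no dimensions

# Promote the scalar to a 1-d tensor of shape (1,), mirroring the patch:
if len(loss.shape) <= 0:
    loss = loss[None]

# Concatenation across (simulated) ranks now works; with a 0-d tensor,
# torch.cat raises a "zero-dimensional tensor cannot be concatenated" error.
gathered = torch.cat([loss.clone(), loss.clone()], dim=0)
print(gathered.shape)  # torch.Size([2])

Moving the unsqueeze before the clone() calls also keeps the input passed to dist.all_gather shape-consistent with its output buffers, whereas the removed line reshaped only output_tensors and left the 0-d input as-is, which appears to be the motivation for the change.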
