
Commit 5896b3e

Fix distributed_concat with scalar tensor (#16963)
* Fix `distributed_concat` with scalar tensor
* Update trainer_pt_utils.py
1 parent 084c38c

1 file changed (+2, -1)


src/transformers/trainer_pt_utils.py

Lines changed: 2 additions & 1 deletion
@@ -159,8 +159,9 @@ def distributed_concat(tensor: Any, num_total_examples: Optional[int] = None) -> Any:
     try:
         if isinstance(tensor, (tuple, list)):
            return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
+        if len(tensor.shape) <= 0:
+            tensor = tensor[None]
         output_tensors = [tensor.clone() for _ in range(dist.get_world_size())]
-        output_tensors = [t if len(t.shape) > 0 else t[None] for t in output_tensors]
         dist.all_gather(output_tensors, tensor)
         concat = torch.cat(output_tensors, dim=0)

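Why the fix works: `dist.all_gather` expects the local input tensor and every pre-allocated output buffer to have matching shapes, and `torch.cat` rejects zero-dimensional tensors outright. The removed line promoted only the output buffers to 1-D, leaving the 0-dim input passed to `all_gather` mismatched; the new check promotes the input itself before the buffers are cloned from it, so both sides agree. Below is a minimal single-process sketch of the scalar problem and the promotion; `world_size` is a stand-in for `dist.get_world_size()`, not part of the patch:

```python
import torch

# A 0-dim (scalar) tensor, e.g. an eval loss reduced on one process.
loss = torch.tensor(3.14)
print(loss.shape)  # torch.Size([]), so len(loss.shape) == 0

# torch.cat cannot concatenate zero-dimensional tensors:
#   torch.cat([loss, loss], dim=0)  # RuntimeError: zero-dimensional tensor ...

# The patched logic: promote the scalar to shape (1,) *before* the gather
# buffers are cloned from it, so input and buffers share the same shape.
if len(loss.shape) <= 0:
    loss = loss[None]  # now shape (1,)

world_size = 2  # stand-in for dist.get_world_size() in this sketch
output_tensors = [loss.clone() for _ in range(world_size)]
print(torch.cat(output_tensors, dim=0))  # tensor([3.1400, 3.1400])
```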