Skip to content

Commit e1c6079

Browse files
authored
[Feature] Add collect_results support for Ascend NPU (#1309)
1 parent 19ab172 commit e1c6079

File tree

1 file changed

+31
-25
lines changed

1 file changed

+31
-25
lines changed

mmengine/dist/dist.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -898,10 +898,11 @@ def collect_results(results: list,
898898
object.
899899
size (int): Size of the results, commonly equal to length of
900900
the results.
901-
device (str): Device name. Optional values are 'cpu' and 'gpu'.
901+
device (str): Device name. Optional values are 'cpu', 'gpu' or 'npu'.
902902
tmpdir (str | None): Temporal directory for collected results to
903903
store. If set to None, it will create a temporal directory for it.
904-
``tmpdir`` should be None when device is 'gpu'. Defaults to None.
904+
``tmpdir`` should be None when device is 'gpu' or 'npu'.
905+
Defaults to None.
905906
906907
Returns:
907908
list or None: The collected results.
@@ -920,13 +921,13 @@ def collect_results(results: list,
920921
['foo', 24, {1: 2}, {'a': 'b'}] # rank 0
921922
None # rank 1
922923
"""
923-
if device not in ['gpu', 'cpu']:
924+
if device not in ['gpu', 'cpu', 'npu']:
924925
raise NotImplementedError(
925-
f"device must be 'cpu' or 'gpu', but got {device}")
926+
f"device must be 'cpu', 'gpu' or 'npu', but got {device}")
926927

927-
if device == 'gpu':
928-
assert tmpdir is None, 'tmpdir should be None when device is "gpu"'
929-
return collect_results_gpu(results, size)
928+
if device == 'gpu' or device == 'npu':
929+
assert tmpdir is None, f'tmpdir should be None when device is {device}'
930+
return _collect_results_device(results, size)
930931
else:
931932
return collect_results_cpu(results, size, tmpdir)
932933

@@ -1018,6 +1019,28 @@ def collect_results_cpu(result_part: list,
10181019
return ordered_results
10191020

10201021

1022+
def _collect_results_device(result_part: list, size: int) -> Optional[list]:
1023+
"""Collect results under gpu or npu mode."""
1024+
rank, world_size = get_dist_info()
1025+
if world_size == 1:
1026+
return result_part[:size]
1027+
1028+
# gather all result part. Note that NCCL does not support gather so use
1029+
# all_gather_object instead.
1030+
part_list = all_gather_object(result_part)
1031+
1032+
if rank == 0:
1033+
# sort the results
1034+
ordered_results = []
1035+
for res in zip(*part_list):
1036+
ordered_results.extend(list(res))
1037+
# the dataloader may pad some samples
1038+
ordered_results = ordered_results[:size]
1039+
return ordered_results
1040+
else:
1041+
return None
1042+
1043+
10211044
def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
10221045
"""Collect results under gpu mode.
10231046
@@ -1048,24 +1071,7 @@ def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
10481071
['foo', 24, {1: 2}, {'a': 'b'}] # rank 0
10491072
None # rank 1
10501073
"""
1051-
rank, world_size = get_dist_info()
1052-
if world_size == 1:
1053-
return result_part[:size]
1054-
1055-
# gather all result part. Note that NCCL does not support gather so use
1056-
# all_gather_object instead.
1057-
part_list = all_gather_object(result_part)
1058-
1059-
if rank == 0:
1060-
# sort the results
1061-
ordered_results = []
1062-
for res in zip(*part_list):
1063-
ordered_results.extend(list(res))
1064-
# the dataloader may pad some samples
1065-
ordered_results = ordered_results[:size]
1066-
return ordered_results
1067-
else:
1068-
return None
1074+
return _collect_results_device(result_part, size)
10691075

10701076

10711077
def _all_reduce_coalesced(tensors: List[torch.Tensor],

0 commit comments

Comments
 (0)