Skip to content

Commit 2e69c78

Browse files
committed
[Feature] Add collect_results for Ascend NPU
1 parent 488fddc commit 2e69c78

File tree

2 files changed

+66
-10
lines changed

2 files changed

+66
-10
lines changed

mmengine/dist/__init__.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from .dist import (all_gather_object, all_reduce, all_gather, all_reduce_dict,
33
collect_results, gather, broadcast, gather_object,
44
sync_random_seed, broadcast_object_list,
5-
collect_results_cpu, collect_results_gpu, all_reduce_params)
5+
collect_results_cpu, collect_results_gpu,
6+
collect_results_npu, all_reduce_params)
67
from .utils import (get_dist_info, init_dist, init_local_group, get_backend,
78
get_world_size, get_rank, get_local_size, get_local_rank,
89
is_main_process, master_only, barrier, get_local_group,
@@ -11,11 +12,12 @@
1112

1213
__all__ = [
1314
'all_gather_object', 'all_reduce', 'all_gather', 'all_reduce_dict',
14-
'collect_results', 'collect_results_cpu', 'collect_results_gpu', 'gather',
15-
'broadcast', 'gather_object', 'sync_random_seed', 'broadcast_object_list',
16-
'get_dist_info', 'init_dist', 'init_local_group', 'get_backend',
17-
'get_world_size', 'get_rank', 'get_local_size', 'get_local_group',
18-
'get_local_rank', 'is_main_process', 'master_only', 'barrier',
19-
'is_distributed', 'get_default_group', 'all_reduce_params',
20-
'get_data_device', 'get_comm_device', 'cast_data_device', 'infer_launcher'
15+
'collect_results', 'collect_results_cpu', 'collect_results_gpu',
16+
'collect_results_npu', 'gather', 'broadcast', 'gather_object',
17+
'sync_random_seed', 'broadcast_object_list', 'get_dist_info', 'init_dist',
18+
'init_local_group', 'get_backend', 'get_world_size', 'get_rank',
19+
'get_local_size', 'get_local_group', 'get_local_rank', 'is_main_process',
20+
'master_only', 'barrier', 'is_distributed', 'get_default_group',
21+
'all_reduce_params', 'get_data_device', 'get_comm_device',
22+
'cast_data_device', 'infer_launcher'
2123
]

mmengine/dist/dist.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -898,10 +898,11 @@ def collect_results(results: list,
898898
object.
899899
size (int): Size of the results, commonly equal to length of
900900
the results.
901-
device (str): Device name. Optional values are 'cpu' and 'gpu'.
901+
device (str): Device name. Optional values are 'cpu', 'gpu' or 'npu'.
902902
tmpdir (str | None): Temporary directory for collected results to
903903
store. If set to None, it will create a temporal directory for it.
904-
``tmpdir`` should be None when device is 'gpu'. Defaults to None.
904+
``tmpdir`` should be None when device is 'gpu' or 'npu'.
905+
Defaults to None.
905906
906907
Returns:
907908
list or None: The collected results.
@@ -927,6 +928,9 @@ def collect_results(results: list,
927928
if device == 'gpu':
928929
assert tmpdir is None, 'tmpdir should be None when device is "gpu"'
929930
return collect_results_gpu(results, size)
931+
elif device == 'npu':
932+
assert tmpdir is None, 'tmpdir should be None when device is "npu"'
933+
return collect_results_npu(results, size)
930934
else:
931935
return collect_results_cpu(results, size, tmpdir)
932936

@@ -1068,6 +1072,56 @@ def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
10681072
return None
10691073

10701074

1075+
def collect_results_npu(result_part: list, size: int) -> Optional[list]:
1076+
"""Collect results under npu mode.
1077+
1078+
On npu mode, this function will encode results to npu tensors and use npu
1079+
communication for results collection.
1080+
1081+
Args:
1082+
result_part (list[object]): Result list containing result parts
1083+
to be collected. Each item of ``result_part`` should be a picklable
1084+
object.
1085+
size (int): Size of the results, commonly equal to length of
1086+
the results.
1087+
1088+
Returns:
1089+
list or None: The collected results.
1090+
1091+
Examples:
1092+
>>> # distributed environment
1093+
>>> # We have 2 process groups, 2 ranks.
1094+
>>> import mmengine.dist as dist
1095+
>>> if dist.get_rank() == 0:
1096+
data = ['foo', {1: 2}]
1097+
else:
1098+
data = [24, {'a': 'b'}]
1099+
>>> size = 4
1100+
>>> output = dist.collect_results_npu(data, size)
1101+
>>> output
1102+
['foo', 24, {1: 2}, {'a': 'b'}] # rank 0
1103+
None # rank 1
1104+
"""
1105+
rank, world_size = get_dist_info()
1106+
if world_size == 1:
1107+
return result_part[:size]
1108+
1109+
# gather all result part. Note that NCCL does not support gather so use
1110+
# all_gather_object instead.
1111+
part_list = all_gather_object(result_part)
1112+
1113+
if rank == 0:
1114+
# sort the results
1115+
ordered_results = []
1116+
for res in zip(*part_list):
1117+
ordered_results.extend(list(res))
1118+
# the dataloader may pad some samples
1119+
ordered_results = ordered_results[:size]
1120+
return ordered_results
1121+
else:
1122+
return None
1123+
1124+
10711125
def _all_reduce_coalesced(tensors: List[torch.Tensor],
10721126
bucket_size_mb: int = -1,
10731127
op: str = 'sum',

0 commit comments

Comments
 (0)