@@ -898,10 +898,11 @@ def collect_results(results: list,
898
898
object.
899
899
size (int): Size of the results, commonly equal to length of
900
900
the results.
901
- device (str): Device name. Optional values are 'cpu' and 'gpu'.
901
+ device (str): Device name. Optional values are 'cpu', 'gpu' or 'npu '.
902
902
tmpdir (str | None): Temporal directory for collected results to
903
903
store. If set to None, it will create a temporal directory for it.
904
- ``tmpdir`` should be None when device is 'gpu'. Defaults to None.
904
+ ``tmpdir`` should be None when device is 'gpu' or 'npu'.
905
+ Defaults to None.
905
906
906
907
Returns:
907
908
list or None: The collected results.
@@ -920,13 +921,13 @@ def collect_results(results: list,
920
921
['foo', 24, {1: 2}, {'a': 'b'}] # rank 0
921
922
None # rank 1
922
923
"""
923
- if device not in ['gpu' , 'cpu' ]:
924
+ if device not in ['gpu' , 'cpu' , 'npu' ]:
924
925
raise NotImplementedError (
925
- f"device must be 'cpu' or 'gpu', but got { device } " )
926
+ f"device must be 'cpu' , 'gpu' or 'npu ', but got { device } " )
926
927
927
- if device == 'gpu' :
928
- assert tmpdir is None , 'tmpdir should be None when device is "gpu" '
929
- return collect_results_gpu (results , size )
928
+ if device == 'gpu' or device == 'npu' :
929
+ assert tmpdir is None , f 'tmpdir should be None when device is { device } '
930
+ return _collect_results_device (results , size )
930
931
else :
931
932
return collect_results_cpu (results , size , tmpdir )
932
933
@@ -1018,6 +1019,28 @@ def collect_results_cpu(result_part: list,
1018
1019
return ordered_results
1019
1020
1020
1021
1022
def _collect_results_device(result_part: list, size: int) -> Optional[list]:
    """Collect results under gpu or npu mode.

    Args:
        result_part (list): The partial results held by the current rank.
        size (int): Total number of results; gathered results are truncated
            to this length because the dataloader may pad some samples.

    Returns:
        Optional[list]: The ordered, collected results on rank 0; ``None``
        on every other rank.
    """
    rank, world_size = get_dist_info()
    # Single-process run: nothing to gather, just trim any padding.
    if world_size == 1:
        return result_part[:size]

    # NCCL does not support ``gather``, so ``all_gather_object`` is used to
    # collect every rank's partial results on all ranks.
    gathered = all_gather_object(result_part)

    # Only rank 0 assembles and returns the full result list.
    if rank != 0:
        return None

    # Interleave the per-rank parts to restore the original sample order.
    merged: list = []
    for same_index_items in zip(*gathered):
        merged.extend(same_index_items)
    # Drop samples the dataloader padded at the end.
    return merged[:size]
1043
+
1021
1044
def collect_results_gpu (result_part : list , size : int ) -> Optional [list ]:
1022
1045
"""Collect results under gpu mode.
1023
1046
@@ -1048,24 +1071,7 @@ def collect_results_gpu(result_part: list, size: int) -> Optional[list]:
1048
1071
['foo', 24, {1: 2}, {'a': 'b'}] # rank 0
1049
1072
None # rank 1
1050
1073
"""
1051
- rank , world_size = get_dist_info ()
1052
- if world_size == 1 :
1053
- return result_part [:size ]
1054
-
1055
- # gather all result part. Note that NCCL does not support gather so use
1056
- # all_gather_object instead.
1057
- part_list = all_gather_object (result_part )
1058
-
1059
- if rank == 0 :
1060
- # sort the results
1061
- ordered_results = []
1062
- for res in zip (* part_list ):
1063
- ordered_results .extend (list (res ))
1064
- # the dataloader may pad some samples
1065
- ordered_results = ordered_results [:size ]
1066
- return ordered_results
1067
- else :
1068
- return None
1074
+ return _collect_results_device (result_part , size )
1069
1075
1070
1076
1071
1077
def _all_reduce_coalesced (tensors : List [torch .Tensor ],
0 commit comments