# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from functools import cache
-from typing import Optional, Union
+from typing import Callable, Optional, Union

import torch
from torch.distributed import ProcessGroup

logger = init_logger(__name__)


-@cache
-def is_rocm_aiter_custom_allreduce_enabled() -> bool:
-    """Check if aiter custom allreduce is enabled for ROCm platform."""
-    return current_platform.is_rocm() \
-        and envs.VLLM_ROCM_USE_AITER \
-        and envs.VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE
+# ROCm allreduce dispatcher: selects the most performant allreduce
+# implementation based on the implementations available and the
+# payload size of the input tensor.
+# Only supported on the AMD ROCm platform.
+class ROCmAllreduceDispatcher:
+    def __init__(self,
+                 group: ProcessGroup,
+                 device: Union[int, str, torch.device],
+                 ca_comm=None,
+                 pynccl_comm=None):
+        self.process_group = group
+        self.device = device
+        self.cur_device_arch = self._get_current_device_arch()
+
+        self.tp_size = torch.distributed.get_world_size(
+            group=self.process_group)
+
+        # Covers MI300/MI308 ("MI30X") and MI350/MI355 ("MI35X").
+        self.supported_device_archs = ["MI30X", "MI35X"]
+
+        # Dispatch thresholds keyed by tp_size: the payload size (KB) at
+        # or below which the custom allreduce (aiter on MI35X, vLLM on
+        # MI30X) is preferred; larger payloads try quick allreduce first.
+        self.MI35x_thresholds = {
+            2: 1024,
+            4: 4096,
+            8: 8192,
+        }
+        self.MI30x_thresholds = {
+            2: 1024,
+            4: 2048,
+            8: 2048,
+        }
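+        # Illustrative example of the resulting policy: with tp_size=4
+        # on MI30X, a 1024 KB payload is at or below the 2048 KB
+        # threshold, so the vLLM custom allreduce is preferred; a
+        # 16384 KB payload exceeds it, so quick allreduce is tried
+        # before falling back to pynccl.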
+
+        # allreduce name -> (allreduce impl, optional payload check impl)
+        self.available_allreduce_impls: dict[
+            str, tuple[Callable, Optional[Callable]]] = {}
+
+        # pynccl is the always-available fallback; it has no payload
+        # check, so its check slot is None.
+        assert pynccl_comm is not None
+        self.available_allreduce_impls["pynccl"] = (pynccl_comm.all_reduce,
+                                                    None)
+        self.fallback_impl = pynccl_comm.all_reduce
+
+        if ca_comm is not None and not ca_comm.disabled:
+            self.available_allreduce_impls["vllm_ca"] = \
+                (ca_comm.custom_all_reduce, ca_comm.should_custom_ar)
+
+        # Initialize a custom quick all-reduce implementation for AMD,
+        # based on quickreduce (https://github.com/mk1-project/quickreduce).
+        # Quick reduce is designed as a complement to custom allreduce.
+        from vllm.distributed.device_communicators.quick_all_reduce import (
+            QuickAllReduce)
+        self.qr_comm = QuickAllReduce(group=self.process_group,
+                                      device=self.device)
+        if self.qr_comm is not None and not self.qr_comm.disabled:
+            self.available_allreduce_impls["vllm_qr"] = \
+                (self.qr_comm.quick_all_reduce,
+                 self.qr_comm.should_quick_allreduce)
+
+        # Initialize the custom all-reduce implementation from aiter.
+        if self._is_aiter_custom_allreduce_available():
+            from aiter.dist.custom_all_reduce import CustomAllreduce \
+                as AiterCustomAllreduce
+            self.aiter_ca_comm = AiterCustomAllreduce(
+                group=self.process_group,
+                device=self.device,
+            )
+            if self.aiter_ca_comm is not None:
+                self.available_allreduce_impls["aiter_ca"] = \
+                    (self.aiter_ca_comm.custom_all_reduce,
+                     self.aiter_ca_comm.should_custom_ar)
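+
+        # Illustrative registry contents on a node with every backend
+        # available (names match the registrations above):
+        #   {"pynccl":   (pynccl_comm.all_reduce, None),
+        #    "vllm_ca":  (ca_comm.custom_all_reduce,
+        #                 ca_comm.should_custom_ar),
+        #    "vllm_qr":  (qr_comm.quick_all_reduce,
+        #                 qr_comm.should_quick_allreduce),
+        #    "aiter_ca": (aiter_ca_comm.custom_all_reduce,
+        #                 aiter_ca_comm.should_custom_ar)}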
+
+    def _is_aiter_custom_allreduce_available(self) -> bool:
+        """Check whether the aiter custom allreduce can be used, i.e.
+        the env flag is set and aiter is importable."""
+        if not envs.VLLM_ROCM_USE_AITER:
+            return False
+
+        try:
+            from aiter.dist.custom_all_reduce import (  # noqa: F401
+                CustomAllreduce)
+            return True
+        except ImportError:
+            return False
+
+    def _get_current_device_arch(self) -> str:
+        """Get the microarchitecture family name of the current device."""
+        # TODO(zejun): Add more device architectures
+        device_arch = torch.cuda.get_device_properties("cuda").gcnArchName
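+        # Reference (public gfx naming, listed as an aid and not
+        # exhaustive): gfx940/941/942 are the MI300 series, gfx950 is
+        # the MI350 series, and gfx11xx are RDNA3 consumer ("RX") GPUs.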
+        if "gfx95" in device_arch:
+            return "MI35X"
+        elif "gfx94" in device_arch:
+            return "MI30X"
+        elif "gfx11" in device_arch:
+            return "RX"
+        else:
+            return device_arch
+
+    def _should_allreduce(self, input_: torch.Tensor,
+                          impl_name: str) -> bool:
+        if impl_name not in self.available_allreduce_impls:
+            return False
+        # A None check impl (pynccl) accepts any payload.
+        check_impl = self.available_allreduce_impls[impl_name][1]
+        return check_impl is None or check_impl(input_)
+
+    def _dispatch_mi30x(self,
+                        input_: torch.Tensor,
+                        payload_size_KB: int,
+                        tp_size: int) -> Callable:
+        """Dispatch implementation for MI30X architecture."""
+        if tp_size not in self.MI30x_thresholds:
+            return self.fallback_impl
+
+        threshold = self.MI30x_thresholds[tp_size]
+
+        if payload_size_KB <= threshold and \
+                self._should_allreduce(input_, "vllm_ca"):
+            return self.available_allreduce_impls["vllm_ca"][0]
+
+        if self._should_allreduce(input_, "vllm_qr"):
+            return self.available_allreduce_impls["vllm_qr"][0]
+
+        return self.fallback_impl
+
+    def _dispatch_mi35x(self,
+                        input_: torch.Tensor,
+                        payload_size_KB: int,
+                        tp_size: int) -> Callable:
+        """Dispatch implementation for MI35X architecture."""
+        if tp_size not in self.MI35x_thresholds:
+            return self.fallback_impl
+
+        threshold = self.MI35x_thresholds[tp_size]
+
+        if payload_size_KB <= threshold and \
+                self._should_allreduce(input_, "aiter_ca"):
+            return self.available_allreduce_impls["aiter_ca"][0]
+
+        if self._should_allreduce(input_, "vllm_qr"):
+            return self.available_allreduce_impls["vllm_qr"][0]
+
+        return self.fallback_impl
+
+    def _dispatch_impl(self,
+                       input_: torch.Tensor,
+                       payload_size_KB: int,
+                       device_arch: str,
+                       tp_size: int) -> Callable:
+        if device_arch not in self.supported_device_archs:
+            logger.debug(
+                "Device architecture %s not supported, using pynccl",
+                device_arch)
+            return self.fallback_impl
+
+        if device_arch == "MI35X":
+            return self._dispatch_mi35x(input_, payload_size_KB, tp_size)
+        elif device_arch == "MI30X":
+            return self._dispatch_mi30x(input_, payload_size_KB, tp_size)
+        else:
+            # For any other device, fall back to pynccl.
+            return self.fallback_impl
+
+    def dispatch(self, input_: torch.Tensor) -> Callable:
+        """Dispatch to an allreduce implementation for this input."""
+        # Payload size in KB.
+        payload_size = int(input_.numel() * input_.element_size() / 1024.0)
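+        # Worked example: a bf16 tensor of shape [8192, 4096] carries
+        # 8192 * 4096 elements * 2 bytes = 65536 KB (64 MiB), which
+        # exceeds every threshold above, so quick allreduce (or the
+        # pynccl fallback) is selected rather than a custom allreduce.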
+        op = self._dispatch_impl(input_,
+                                 payload_size,
+                                 self.cur_device_arch,
+                                 self.tp_size)
+        return op
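+
+    # Usage sketch (illustrative; it mirrors CudaCommunicator.all_reduce
+    # below, and the variable names here are assumptions):
+    #   dispatcher = ROCmAllreduceDispatcher(group, device,
+    #                                        ca_comm=ca_comm,
+    #                                        pynccl_comm=pynccl_comm)
+    #   op = dispatcher.dispatch(tensor)  # pick an impl for this payload
+    #   out = op(tensor)                  # may return None; the caller
+    #                                     # then falls back to PyTorch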


class CudaCommunicator(DeviceCommunicatorBase):

@@ -55,8 +208,8 @@ def __init__(self,
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)
-        from vllm.distributed.device_communicators.quick_all_reduce import (
-            QuickAllReduce)
+        from vllm.distributed.device_communicators.symm_mem import (
+            SymmMemCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
@@ -66,7 +219,11 @@ def __init__(self,
        )

        self.ca_comm: Optional[CustomAllreduce] = None
-        self.qr_comm: Optional[QuickAllReduce] = None
+        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+
+        # Allreduce dispatcher for the ROCm platform; built below once
+        # the member communicators exist.
+        self.rocm_allreduce_dispatcher: Optional[ROCmAllreduceDispatcher] = None
+
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
@@ -75,13 +232,19 @@ def __init__(self,
        )

        if current_platform.is_rocm():
-            # Initialize a custom quick all-reduce implementation for AMD.
-            # Quick reduce is designed as a complement to custom allreduce.
-            # Based on quickreduce (https://github.com/mk1-project/quickreduce).
-            # If it's a rocm, 'use_custom_allreduce==True' means it must
-            # currently be an MI300 series.
-            self.qr_comm = QuickAllReduce(group=self.cpu_group,
-                                          device=self.device)
+            logger.info("Initializing ROCm allreduce dispatcher.")
+            self.rocm_allreduce_dispatcher = \
+                ROCmAllreduceDispatcher(group=self.cpu_group,
+                                        device=self.device,
+                                        ca_comm=self.ca_comm,
+                                        pynccl_comm=self.pynccl_comm)
+
+        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
        if self.use_all2all:
            all2all_backend = envs.VLLM_ALL2ALL_BACKEND
            if all2all_backend == "naive":
@@ -104,23 +267,29 @@ def __init__(self,
            raise ValueError(f"Unknown all2all backend: {all2all_backend}")

    def all_reduce(self, input_):
-        # always try quick reduce first, then custom allreduce,
-        # and then pynccl. (quick reduce just for ROCM MI3*)
-        qr_comm = self.qr_comm
-        if qr_comm is not None and not qr_comm.disabled and \
-            qr_comm.should_quick_allreduce(input_):
-            out = qr_comm.quick_all_reduce(input_)
-            assert out is not None
-            return out
-        ca_comm = self.ca_comm
-        if ca_comm is not None and not ca_comm.disabled and \
-            ca_comm.should_custom_ar(input_):
-            out = ca_comm.custom_all_reduce(input_)
-            assert out is not None
-            return out
-        pynccl_comm = self.pynccl_comm
-        assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+        if current_platform.is_rocm() and \
+                self.rocm_allreduce_dispatcher is not None:
+            # On ROCm, pick the implementation for this payload size.
+            op = self.rocm_allreduce_dispatcher.dispatch(input_)
+            logger.debug("ROCm allreduce dispatcher dispatched: %s", op)
+            out = op(input_)
+        else:
+            ca_comm = self.ca_comm
+            if ca_comm is not None and not ca_comm.disabled and \
+                    ca_comm.should_custom_ar(input_):
+                out = ca_comm.custom_all_reduce(input_)
+                assert out is not None
+                return out
+            symm_mem_comm = self.symm_mem_comm
+            if symm_mem_comm is not None and \
+                    symm_mem_comm.should_use_symm_mem(input_):
+                out = symm_mem_comm.all_reduce(input_)
+                assert out is not None
+                return out
+
+            pynccl_comm = self.pynccl_comm
+            assert pynccl_comm is not None
+            out = pynccl_comm.all_reduce(input_)
+
        if out is None:
            # fall back to the default all-reduce using PyTorch.
            # this usually happens during testing.