
Commit 948b848

[ROCm Allreduce dispatcher] Add allreduce dispatcher for ROCm device
Signed-off-by: zejunchen-zejun <[email protected]>
1 parent 294874d commit 948b848

File tree: 1 file changed (+204 -35 lines changed)

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 204 additions & 35 deletions
@@ -1,8 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from functools import cache
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -16,13 +15,167 @@
 logger = init_logger(__name__)
 
 
-@cache
-def is_rocm_aiter_custom_allreduce_enabled() -> bool:
-    """Check if aiter custom allreduce is enabled for ROCm platform."""
-    return current_platform.is_rocm() \
-        and envs.VLLM_ROCM_USE_AITER \
-        and envs.VLLM_ROCM_USE_AITER_CUSTOM_ALL_REDUCE
+# ROCm allreduce dispatcher that selects a performant allreduce
+# implementation based on the available implementations and the payload
+# size of the input tensor. It only supports the AMD ROCm platform.
+class ROCmAllreduceDispatcher:
+    def __init__(self,
+                 group: ProcessGroup,
+                 device: Union[int, str, torch.device],
+                 ca_comm=None,
+                 pynccl_comm=None):
+        self.process_group = group
+        self.device = device
+        self.cur_device_arch = self._get_current_device_arch()
+
+        self.tp_size = torch.distributed.get_world_size(
+            group=self.process_group)
+
+        # covers MI300, MI308, MI350, MI355
+        self.supported_device_archs = ["MI30X", "MI35X"]
+
+        # dispatch thresholds by tp_size, in KB: payloads at or below the
+        # threshold use the custom allreduce (aiter_ca on MI35X, vllm_ca
+        # on MI30X); larger payloads fall through to quick reduce, then
+        # pynccl.
+        self.MI35x_thresholds = {
+            2: 1024,
+            4: 4096,
+            8: 8192,
+        }
+        self.MI30x_thresholds = {
+            2: 1024,
+            4: 2048,
+            8: 2048,
+        }
+
+        # allreduce name -> (allreduce impl, optional eligibility check)
+        self.available_allreduce_impls: dict[
+            str, tuple[Callable, Optional[Callable]]] = {}
+
+        self.available_allreduce_impls["pynccl"] = \
+            (pynccl_comm.all_reduce, None)
+        self.fallback_impl = pynccl_comm.all_reduce
+
+        if ca_comm is not None:
+            self.available_allreduce_impls["vllm_ca"] = \
+                (ca_comm.custom_all_reduce, ca_comm.should_custom_ar)
+
+        # Initialize a custom quick all-reduce implementation for AMD.
+        # Quick reduce is designed as a complement to custom allreduce.
+        # Based on quickreduce (https://github.com/mk1-project/quickreduce).
+        # On ROCm, 'use_custom_allreduce == True' currently implies an
+        # MI300-series device.
+        from vllm.distributed.device_communicators.quick_all_reduce import (
+            QuickAllReduce)
+        self.qr_comm = QuickAllReduce(group=self.process_group,
+                                      device=self.device)
+        if self.qr_comm is not None:
+            self.available_allreduce_impls["vllm_qr"] = \
+                (self.qr_comm.quick_all_reduce,
+                 self.qr_comm.should_quick_allreduce)
+
+        # Initialize a custom all-reduce implementation from aiter.
+        if self._is_aiter_custom_allreduce_available():
+            from aiter.dist.custom_all_reduce import (
+                CustomAllreduce as AiterCustomAllreduce)
+            self.aiter_ca_comm = AiterCustomAllreduce(
+                group=self.process_group,
+                device=self.device,
+            )
+            if self.aiter_ca_comm is not None:
+                self.available_allreduce_impls["aiter_ca"] = \
+                    (self.aiter_ca_comm.custom_all_reduce,
+                     self.aiter_ca_comm.should_custom_ar)
+
+    def _is_aiter_custom_allreduce_available(self) -> bool:
+        """Check if the aiter custom allreduce is available on this
+        ROCm platform."""
+        if not envs.VLLM_ROCM_USE_AITER:
+            return False
+
+        try:
+            from aiter.dist.custom_all_reduce import CustomAllreduce  # noqa: F401
+            return True
+        except ImportError:
+            return False
+
+    def _get_current_device_arch(self) -> str:
+        """Get the device microarchitecture of the current device."""
+        # TODO(zejun): Add more device architectures
+        device_arch = torch.cuda.get_device_properties("cuda").gcnArchName
+        if "gfx95" in device_arch:
+            return "MI35X"
+        elif "gfx94" in device_arch:
+            return "MI30X"
+        elif "gfx11" in device_arch:
+            return "RX"
+        else:
+            return device_arch
+
+    def _should_allreduce(self, input_: torch.Tensor, impl_name: str) -> bool:
+        if impl_name not in self.available_allreduce_impls:
+            return False
+        check = self.available_allreduce_impls[impl_name][1]
+        return check is not None and check(input_)
+
+    def _dispatch_mi30x(self,
+                        input_: torch.Tensor,
+                        payload_size_KB: int,
+                        tp_size: int) -> Callable:
+        """Dispatch implementation for the MI30X architecture."""
+        if tp_size not in self.MI30x_thresholds:
+            return self.fallback_impl
+
+        threshold = self.MI30x_thresholds[tp_size]
+
+        if payload_size_KB <= threshold and \
+                self._should_allreduce(input_, "vllm_ca"):
+            return self.available_allreduce_impls["vllm_ca"][0]
+
+        if self._should_allreduce(input_, "vllm_qr"):
+            return self.available_allreduce_impls["vllm_qr"][0]
 
+        return self.fallback_impl
+
+    def _dispatch_mi35x(self,
+                        input_: torch.Tensor,
+                        payload_size_KB: int,
+                        tp_size: int) -> Callable:
+        """Dispatch implementation for the MI35X architecture."""
+        if tp_size not in self.MI35x_thresholds:
+            return self.fallback_impl
+
+        threshold = self.MI35x_thresholds[tp_size]
+
+        if payload_size_KB <= threshold and \
+                self._should_allreduce(input_, "aiter_ca"):
+            return self.available_allreduce_impls["aiter_ca"][0]
+
+        if self._should_allreduce(input_, "vllm_qr"):
+            return self.available_allreduce_impls["vllm_qr"][0]
+
+        return self.fallback_impl
+
+    def _dispatch_impl(self,
+                       input_: torch.Tensor,
+                       payload_size_KB: int,
+                       device_arch: str,
+                       tp_size: int) -> Callable:
+        if device_arch not in self.supported_device_archs:
+            logger.debug(
+                f"Device architecture {device_arch} not supported, "
+                "using pynccl")
+            return self.fallback_impl
+
+        if device_arch == "MI35X":
+            return self._dispatch_mi35x(input_, payload_size_KB, tp_size)
+        elif device_arch == "MI30X":
+            return self._dispatch_mi30x(input_, payload_size_KB, tp_size)
+        else:
+            # for other devices, fall back to pynccl
+            return self.fallback_impl
+
+    def dispatch(self, input_: torch.Tensor) -> Callable:
+        """Dispatch the allreduce implementation for the given tensor."""
+        # payload size in KB
+        payload_size = int(input_.numel() * input_.element_size() / 1024.0)
+        op = self._dispatch_impl(input_,
+                                 payload_size,
+                                 self.cur_device_arch,
+                                 self.tp_size)
+        return op
 
 class CudaCommunicator(DeviceCommunicatorBase):
 
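
For readers skimming the hunk above: the dispatch rule boils down to a per-tp_size payload threshold. Below is a minimal, self-contained sketch of that rule, with plain callables standing in for the real communicator objects and the should_custom_ar / should_quick_allreduce eligibility checks reduced to simple None checks; the names pick_allreduce and MI30X_THRESHOLDS are illustrative and are not vLLM identifiers.

# Illustrative sketch only -- not part of the commit.
from typing import Callable, Optional

# tp_size -> custom-allreduce threshold in KB (values copied from the
# MI30X table in the diff above)
MI30X_THRESHOLDS = {2: 1024, 4: 2048, 8: 2048}


def pick_allreduce(payload_size_kb: int,
                   tp_size: int,
                   custom_ar: Optional[Callable],
                   quick_ar: Optional[Callable],
                   pynccl_ar: Callable) -> Callable:
    """Pick an allreduce callable the way _dispatch_mi30x does:
    custom allreduce for small payloads, quick reduce for larger ones,
    pynccl as the final fallback."""
    threshold = MI30X_THRESHOLDS.get(tp_size)
    if threshold is None:
        return pynccl_ar
    if payload_size_kb <= threshold and custom_ar is not None:
        return custom_ar
    if quick_ar is not None:
        return quick_ar
    return pynccl_ar


if __name__ == "__main__":
    # A [4096, 4096] fp16 tensor is 4096 * 4096 * 2 / 1024 = 32768 KB,
    # above the 2048 KB threshold for tp_size == 8, so quick reduce
    # (or pynccl, if quick reduce is unavailable) is chosen.
    impl = pick_allreduce(payload_size_kb=32768,
                          tp_size=8,
                          custom_ar=lambda t: t,
                          quick_ar=lambda t: t,
                          pynccl_ar=lambda t: t)
    print(impl)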

@@ -55,8 +208,8 @@ def __init__(self,
             CustomAllreduce)
         from vllm.distributed.device_communicators.pynccl import (
             PyNcclCommunicator)
-        from vllm.distributed.device_communicators.quick_all_reduce import (
-            QuickAllReduce)
+        from vllm.distributed.device_communicators.symm_mem import (
+            SymmMemCommunicator)
 
         self.pynccl_comm: Optional[PyNcclCommunicator] = None
         if use_pynccl and self.world_size > 1:
@@ -66,7 +219,11 @@ def __init__(self,
             )
 
         self.ca_comm: Optional[CustomAllreduce] = None
-        self.qr_comm: Optional[QuickAllReduce] = None
+        self.symm_mem_comm: Optional[SymmMemCommunicator] = None
+
+        # Initialize a custom all-reduce dispatcher for the ROCm platform
+        self.rocm_allreduce_dispatcher: Optional[ROCmAllreduceDispatcher] = None
+
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
             self.ca_comm = CustomAllreduce(
@@ -75,13 +232,19 @@ def __init__(self,
             )
 
         if current_platform.is_rocm():
-            # Initialize a custom quick all-reduce implementation for AMD.
-            # Quick reduce is designed as a complement to custom allreduce.
-            # Based on quickreduce (https://github.com/mk1-project/quickreduce).
-            # If it's a rocm, 'use_custom_allreduce==True' means it must
-            # currently be an MI300 series.
-            self.qr_comm = QuickAllReduce(group=self.cpu_group,
-                                          device=self.device)
+            self.rocm_allreduce_dispatcher = \
+                ROCmAllreduceDispatcher(group=self.cpu_group,
+                                        device=self.device,
+                                        ca_comm=self.ca_comm,
+                                        pynccl_comm=self.pynccl_comm)
+            logger.info("Initialized ROCm allreduce dispatcher.")
+
+        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+            self.symm_mem_comm = SymmMemCommunicator(
+                group=self.cpu_group,
+                device=self.device,
+            )
+
         if self.use_all2all:
             all2all_backend = envs.VLLM_ALL2ALL_BACKEND
             if all2all_backend == "naive":
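
The hunk above moves QuickAllReduce construction out of CudaCommunicator (it now lives inside the dispatcher) and adds a symmetric-memory communicator on CUDA behind VLLM_ALLREDUCE_USE_SYMM_MEM. The sketch below shows only the gated-construction pattern, using made-up stand-in classes; it is not vLLM code, and the flag is read from os.environ purely for illustration (vLLM reads it through vllm.envs).

# Standalone sketch of the gated-construction pattern in __init__:
# optional communicators start as None and are built only when their
# platform / env-flag gate passes, so the hot path can branch on
# `is not None`. Class names are stand-ins, not vLLM identifiers.
import os
from typing import Optional


class FakeDispatcher:      # stands in for ROCmAllreduceDispatcher
    pass


class FakeSymmMemComm:     # stands in for SymmMemCommunicator
    pass


class FakeCommunicator:
    def __init__(self, is_rocm: bool, is_cuda: bool):
        self.rocm_allreduce_dispatcher: Optional[FakeDispatcher] = None
        self.symm_mem_comm: Optional[FakeSymmMemComm] = None

        if is_rocm:
            self.rocm_allreduce_dispatcher = FakeDispatcher()
        # Simplified flag check for illustration only.
        if is_cuda and os.environ.get("VLLM_ALLREDUCE_USE_SYMM_MEM") == "1":
            self.symm_mem_comm = FakeSymmMemComm()


if __name__ == "__main__":
    comm = FakeCommunicator(is_rocm=True, is_cuda=False)
    print(comm.rocm_allreduce_dispatcher is not None)  # True on the ROCm path
    print(comm.symm_mem_comm is not None)              # False here
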
@@ -104,23 +267,29 @@ def __init__(self,
             raise ValueError(f"Unknown all2all backend: {all2all_backend}")
 
     def all_reduce(self, input_):
-        # always try quick reduce first, then custom allreduce,
-        # and then pynccl. (quick reduce just for ROCM MI3*)
-        qr_comm = self.qr_comm
-        if qr_comm is not None and not qr_comm.disabled and \
-                qr_comm.should_quick_allreduce(input_):
-            out = qr_comm.quick_all_reduce(input_)
-            assert out is not None
-            return out
-        ca_comm = self.ca_comm
-        if ca_comm is not None and not ca_comm.disabled and \
-                ca_comm.should_custom_ar(input_):
-            out = ca_comm.custom_all_reduce(input_)
-            assert out is not None
-            return out
-        pynccl_comm = self.pynccl_comm
-        assert pynccl_comm is not None
-        out = pynccl_comm.all_reduce(input_)
+        if current_platform.is_rocm() and \
+                self.rocm_allreduce_dispatcher is not None:
+            op = self.rocm_allreduce_dispatcher.dispatch(input_)
+            logger.debug(f"ROCm allreduce dispatcher dispatched: {op}")
+            out = op(input_)
+        else:
+            ca_comm = self.ca_comm
+            if ca_comm is not None and not ca_comm.disabled and \
+                    ca_comm.should_custom_ar(input_):
+                out = ca_comm.custom_all_reduce(input_)
+                assert out is not None
+                return out
+            symm_mem_comm = self.symm_mem_comm
+            if symm_mem_comm is not None and \
+                    symm_mem_comm.should_use_symm_mem(input_):
+                out = symm_mem_comm.all_reduce(input_)
+                assert out is not None
+                return out
+
+            pynccl_comm = self.pynccl_comm
+            assert pynccl_comm is not None
+            out = pynccl_comm.all_reduce(input_)
+
+        # fall back to the default all-reduce using PyTorch.
         if out is None:
             # fall back to the default all-reduce using PyTorch.
             # this usually happens during testing.
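
As a quick sanity check on the new all_reduce path, the snippet below reproduces the payload-size arithmetic that ROCmAllreduceDispatcher.dispatch performs (numel * element_size, in KB) and compares it against the MI30X threshold from the table above. It is an illustrative check, not part of the commit; the helper name payload_size_kb is an assumption, and a CPU tensor is enough for the arithmetic.

# Illustrative check (not in the commit): reproduce the payload-size
# computation used by dispatch() to see which branch a tensor would take.
import torch


def payload_size_kb(t: torch.Tensor) -> int:
    # same formula as dispatch(): numel * element_size, in KB
    return int(t.numel() * t.element_size() / 1024.0)


if __name__ == "__main__":
    hidden = torch.empty(128, 8192, dtype=torch.bfloat16)  # CPU tensor is fine for sizing
    print(payload_size_kb(hidden))  # 128 * 8192 * 2 / 1024 = 2048 KB
    # At tp_size == 8 on MI30X the threshold is 2048 KB, so this payload
    # would still be eligible for the vllm_ca path, provided
    # should_custom_ar also accepts the tensor.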
