Skip to content

Commit d37b9e1

Browse files
jikunshang authored and youkaichao committed
[MISC] add support custom_op check (vllm-project#8557)
Co-authored-by: youkaichao <[email protected]>
1 parent 31ce761 commit d37b9e1

File tree

2 files changed

+33
-22
lines changed

2 files changed

+33
-22
lines changed

vllm/distributed/parallel_state.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import vllm.envs as envs
4242
from vllm.logger import init_logger
4343
from vllm.platforms import current_platform
44+
from vllm.utils import supports_custom_op
4445

4546

4647
@dataclass
@@ -100,32 +101,33 @@ def _register_group(group: "GroupCoordinator") -> None:
100101
_groups[group.unique_name] = weakref.ref(group) # type: ignore
101102

102103

103-
@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
104-
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
105-
assert group_name in _groups, f"Group {group_name} is not found."
106-
group = _groups[group_name]()
107-
if group is None:
108-
raise ValueError(f"Group {group_name} is destroyed.")
109-
group._all_reduce(tensor)
104+
if supports_custom_op():
110105

106+
@torch.library.custom_op("vllm::inplace_all_reduce",
107+
mutates_args=["tensor"])
108+
def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
109+
assert group_name in _groups, f"Group {group_name} is not found."
110+
group = _groups[group_name]()
111+
if group is None:
112+
raise ValueError(f"Group {group_name} is destroyed.")
113+
group._all_reduce(tensor)
111114

112-
@inplace_all_reduce.register_fake
113-
def _(tensor: torch.Tensor, group_name: str) -> None:
114-
return
115-
116-
117-
@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
118-
def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
119-
assert group_name in _groups, f"Group {group_name} is not found."
120-
group = _groups[group_name]()
121-
if group is None:
122-
raise ValueError(f"Group {group_name} is destroyed.")
123-
return group._all_reduce(tensor)
115+
@inplace_all_reduce.register_fake
116+
def _(tensor: torch.Tensor, group_name: str) -> None:
117+
return
124118

119+
@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
120+
def outplace_all_reduce(tensor: torch.Tensor,
121+
group_name: str) -> torch.Tensor:
122+
assert group_name in _groups, f"Group {group_name} is not found."
123+
group = _groups[group_name]()
124+
if group is None:
125+
raise ValueError(f"Group {group_name} is destroyed.")
126+
return group._all_reduce(tensor)
125127

126-
@outplace_all_reduce.register_fake
127-
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
128-
return torch.empty_like(tensor)
128+
@outplace_all_reduce.register_fake
129+
def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
130+
return torch.empty_like(tensor)
129131

130132

131133
class GroupCoordinator:
@@ -340,6 +342,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
340342
if self.world_size == 1:
341343
return input_
342344

345+
if not supports_custom_op():
346+
return self._all_reduce(input_)
347+
343348
if self.tpu_communicator is not None and \
344349
not self.tpu_communicator.disabled:
345350
# TPU handles Dynamo with its own logic.

vllm/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,12 @@ def supports_dynamo() -> bool:
12521252
return base_torch_version >= Version("2.4.0")
12531253

12541254

1255+
# Some backends use pytorch version < 2.4.0 which doesn't
1256+
# support `torch.library.custom_op`.
1257+
def supports_custom_op() -> bool:
1258+
return hasattr(torch.library, "custom_op")
1259+
1260+
12551261
class AtomicCounter:
12561262
"""An atomic, thread-safe counter"""
12571263

0 commit comments

Comments (0)