@@ -41,6 +41,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import supports_custom_op
 
 
 @dataclass
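Note: `supports_custom_op` is the capability probe imported above. A minimal sketch of the idea, assuming all it needs to detect is whether the running PyTorch exposes the `torch.library.custom_op` API (absent before PyTorch 2.4):

import torch


def supports_custom_op() -> bool:
    # torch.library.custom_op first appeared in PyTorch 2.4; on older
    # releases the attribute is missing, so op registration must be skipped.
    return hasattr(torch.library, "custom_op")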
@@ -100,32 +101,33 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)  # type: ignore
 
 
-@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
-def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    group._all_reduce(tensor)
+if supports_custom_op():
 
+    @torch.library.custom_op("vllm::inplace_all_reduce",
+                             mutates_args=["tensor"])
+    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        group._all_reduce(tensor)
 
-@inplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> None:
-    return
-
-
-@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
-def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    return group._all_reduce(tensor)
+    @inplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> None:
+        return
 
+    @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+    def outplace_all_reduce(tensor: torch.Tensor,
+                            group_name: str) -> torch.Tensor:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        return group._all_reduce(tensor)
 
-@outplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    return torch.empty_like(tensor)
+    @outplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+        return torch.empty_like(tensor)
 
 
 class GroupCoordinator:
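The `register_fake` implementations give the dispatcher a shape-only (meta) kernel, so `torch.compile` can trace through these collectives without executing them. Once registered, callers reach the ops through the dispatcher rather than by direct function call; a hedged sketch of such a call site, where `prefer_inplace` is a hypothetical flag but the op names and the `vllm` namespace come from the decorators above:

import torch


def dispatch_all_reduce(tensor: torch.Tensor, group_name: str,
                        prefer_inplace: bool) -> torch.Tensor:
    # `prefer_inplace` stands in for whatever backend check the real
    # caller performs. The in-place variant mutates `tensor` and returns
    # nothing; the out-of-place variant returns a fresh tensor.
    if prefer_inplace:
        torch.ops.vllm.inplace_all_reduce(tensor, group_name=group_name)
        return tensor
    return torch.ops.vllm.outplace_all_reduce(tensor, group_name=group_name)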
@@ -340,6 +342,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.world_size == 1:
             return input_
 
+        if not supports_custom_op():
+            return self._all_reduce(input_)
+
         if self.tpu_communicator is not None and \
            not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
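On a PyTorch build without custom-op support, `all_reduce` now short-circuits to the private `_all_reduce` helper instead of going through `torch.ops.vllm.*`, keeping eager mode working at the cost of the dispatcher integration. A rough sketch of what that eager path amounts to, assuming `_all_reduce` wraps the standard collective:

import torch
import torch.distributed as dist


def _all_reduce_eager(tensor: torch.Tensor,
                      group: dist.ProcessGroup) -> torch.Tensor:
    # Illustrative stand-in for GroupCoordinator._all_reduce: the backend
    # (NCCL/Gloo) reduces `tensor` in place across all ranks of `group`.
    dist.all_reduce(tensor, group=group)
    return tensor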