
Commit a00fe6c

lsy323, CAROLZXYZXY, jianzs, and hfan authored and committed
[Hardware][TPU] Optionally import for TPU backend (vllm-project#18269)
Signed-off-by: Siyuan Liu <[email protected]>
Signed-off-by: Jade Zheng <[email protected]>
Co-authored-by: Carol Zheng <[email protected]>
Co-authored-by: Jade Zheng <[email protected]>
Co-authored-by: Hongmin Fan <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
1 parent 2135bef commit a00fe6c
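
All three files apply the same optional-import fallback: at module import time, vLLM tries to pull a drop-in replacement class from the external tpu_commons package and rebinds the public name to it; if the package is not installed, the except ImportError branch keeps vLLM's own implementation. Below is a minimal, self-contained sketch of the pattern — fast_backend, FastThing, and TpuThing are hypothetical names for illustration, not the actual vLLM or tpu_commons APIs:

    import logging

    logger = logging.getLogger(__name__)

    class TpuThing:
        """Built-in fallback implementation (stand-in for vLLM's class)."""

    try:
        # If the optional package is importable, rebind the public name
        # to its drop-in replacement class.
        from fast_backend import FastThing  # hypothetical optional dependency
        TpuThing = FastThing  # type: ignore
    except ImportError:
        # Optional dependency absent: keep the built-in class.
        logger.info("fast_backend not found, using the built-in TpuThing")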

File tree

3 files changed: +25 additions, −0 deletions


vllm/distributed/device_communicators/tpu_communicator.py

Lines changed: 9 additions & 0 deletions
@@ -91,3 +91,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
         return xm.all_gather(input_, dim=dim)
+
+
+try:
+    from tpu_commons.distributed.device_communicators import (
+        TpuCommunicator as TpuCommonsCommunicator)
+    TpuCommunicator = TpuCommonsCommunicator  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
+    pass
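
Because the rebinding happens when the module is first imported, every downstream import of the name transparently resolves to whichever class won:

    # Resolves to tpu_commons' communicator when tpu_commons is installed,
    # otherwise to vLLM's own implementation defined above.
    from vllm.distributed.device_communicators.tpu_communicator import (
        TpuCommunicator)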

vllm/platforms/tpu.py

Lines changed: 8 additions & 0 deletions
@@ -194,3 +194,11 @@ def validate_request(
             if params.sampling_type == SamplingType.RANDOM_SEED:
                 raise ValueError(
                     "Torch XLA does not support per-request seed.")
+
+
+try:
+    from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
+    TpuPlatform = TpuCommonsPlatform  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TpuPlatform")
+    pass
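
The same module-level override applies to the platform class; code that reaches it through the module sees whichever binding won. A sketch, using only standard attribute access:

    from vllm.platforms import tpu
    # tpu.TpuPlatform is tpu_commons' TpuPlatform when that package is
    # installed, and vLLM's own TpuPlatform otherwise.
    platform_cls = tpu.TpuPlatform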

vllm/v1/worker/tpu_worker.py

Lines changed: 8 additions & 0 deletions
@@ -267,3 +267,11 @@ def init_tpu_worker_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size,
                                       parallel_config.enable_expert_parallel)
+
+
+try:
+    from tpu_commons.worker import TPUWorker as TPUCommonsWorker
+    TPUWorker = TPUCommonsWorker  # type: ignore
+except ImportError:
+    logger.info("tpu_commons not found, using vLLM's TPUWorker.")
+    pass
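
A quick way to check which implementation is active at runtime — plain Python introspection, nothing vLLM-specific:

    from vllm.v1.worker.tpu_worker import TPUWorker
    # Prints "tpu_commons.worker" when the optional package is installed,
    # otherwise vLLM's own module path.
    print(TPUWorker.__module__)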

0 commit comments
