Skip to content

[TPU] Align worker index with node boundary #7932

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 2, 2024
Merged

[TPU] Align worker index with node boundary #7932

merged 4 commits into from
Sep 2, 2024

Conversation

WoosukKwon
Copy link
Collaborator

Fixes #7485

@WoosukKwon WoosukKwon added the tpu Related to Google TPUs label Aug 28, 2024
@WoosukKwon WoosukKwon requested a review from youkaichao August 28, 2024 00:25
Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao
Copy link
Member

is it enough to remove the following lines?

with patch(
"vllm.model_executor.layers.vocab_parallel_embedding."
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):

@WoosukKwon
Copy link
Collaborator Author

@youkaichao I tried it, but still got gibberish results without the patch. I think this is because the rank IDs used in all gather are assigned by XLA runtime, regardless of IPs.

@youkaichao
Copy link
Member

I will not block this pr for that reason, but it would be better to know in the future, what's the rank used by XLA.

Copy link
Member

@youkaichao youkaichao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, please fix the format

@Beomi
Copy link

Beomi commented Aug 28, 2024

Hi, I still get this error due to the mismatched global world size even on this branch(tpu-rank) version.

I use these commands below to build new docker image based on Dockerfile.tpu and use the code of this branch.

# on main node, main vm IP:PORT is 10.130.0.60:6379
sudo docker run -t -d -e HF_TOKEN=INPUT_YOUR_TOKEN --privileged --net host --shm-size=16G -it vllm ray start --head --block
# on other nodes
sudo docker run -t -d -e VLLM_HOST_IP=10.130.0.60 -e HF_TOKEN=YOUR_TOKEN -e  --privileged --net host --shm-size=16G -it vllm ray start --address 10.130.0.60:6379 --block

after checking this ray status -- to confirm all the tpu cores are gathered in ray(32 cores since I use TPUv4-64 for this exp)

======== Autoscaler status: 2024-08-28 04:31:11.442038 ========
Node status
---------------------------------------------------------------
Active:
 1 node_7f47f048ccc1c5455a203b92ef6d2bdd5d0332b18fa6f092d6c6425a
 1 node_be42b0ac8682eacb5da8cee5918de34e28dc90b1041b6522405c1c62
 1 node_13290553e86634224a524ec68c5b3e3f0478e491669681f92723ec49
 1 node_2ee5e9791b60dfaf727341ae8a67910996f0572ba36260df07507668
 1 node_3ad7ae73804a43f3a6acbc4e9b48f92693c8cf9e09e275e49863c056
 1 node_195573b8d1e45ef8e2e27d20cf530f48884e0be3e8ead3bf9c716338
 1 node_a1804b4b12788b7a33e02f0f9c38d18512d6ee4d27a9d9775fce7f5a
 1 node_13d0f73aeba95e5f03f42494a57a9ac1204e273106e7ad109dde11f5
Pending:
 (no pending nodes)
Recent failures:
 (no failures)

Resources
---------------------------------------------------------------
Usage:
 0.0/1920.0 CPU
 0.0/32.0 TPU
 0.0/1.0 TPU-v4-64-head
 0B/3.00TiB memory
 0B/121.60GiB object_store_memory
 0.0/8.0 v4-64

However when I run this command to run vllm server --

vllm serve NousResearch/Meta-Llama-3-8B-Instruct --dtype auto --api-key token-abc123 -tp 4 -pp 8 --device tpu --distributed_executor_backend ray

this error comes.

(pid=614, ip=10.130.15.194) INFO 08-28 04:16:28 importing.py:10] Triton not installed; certain GPU-related functions will not be available. [repeated 2x across cluster]
(RayWorkerWrapper pid=376, ip=10.130.0.62) WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
INFO 08-28 04:16:35 selector.py:198] Cannot use _Backend.FLASH_ATTN backend on TPU.
INFO 08-28 04:16:35 selector.py:146] Using Pallas backend.
(RayWorkerWrapper pid=376, ip=10.130.15.193) INFO 08-28 04:16:35 selector.py:198] Cannot use _Backend.FLASH_ATTN backend on TPU.
(RayWorkerWrapper pid=376, ip=10.130.15.193) INFO 08-28 04:16:35 selector.py:146] Using Pallas backend.
(pid=615, ip=10.130.0.59) INFO 08-28 04:16:32 importing.py:10] Triton not installed; certain GPU-related functions will not be available.
ERROR 08-28 04:16:35 worker_base.py:465] Error executing method init_device. This might cause deadlock in distributed execution.
ERROR 08-28 04:16:35 worker_base.py:465] Traceback (most recent call last):
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/worker/worker_base.py", line 457, in execute_method
ERROR 08-28 04:16:35 worker_base.py:465]     return executor(*args, **kwargs)
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/worker/tpu_worker.py", line 83, in init_device
ERROR 08-28 04:16:35 worker_base.py:465]     ensure_model_parallel_initialized(
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/parallel_state.py", line 965, in ensure_model_parallel_initialized
ERROR 08-28 04:16:35 worker_base.py:465]     initialize_model_parallel(tensor_model_parallel_size,
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/parallel_state.py", line 931, in initialize_model_parallel
ERROR 08-28 04:16:35 worker_base.py:465]     _TP = init_model_parallel_group(group_ranks,
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/parallel_state.py", line 773, in init_model_parallel_group
ERROR 08-28 04:16:35 worker_base.py:465]     return GroupCoordinator(
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/parallel_state.py", line 175, in __init__
ERROR 08-28 04:16:35 worker_base.py:465]     self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
ERROR 08-28 04:16:35 worker_base.py:465]   File "/workspace/vllm/vllm/distributed/device_communicators/tpu_communicator.py", line 29, in __init__
ERROR 08-28 04:16:35 worker_base.py:465]     local_rank = global_rank % local_world_size
ERROR 08-28 04:16:35 worker_base.py:465] ZeroDivisionError: integer division or modulo by zero
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/local/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/workspace/vllm/vllm/entrypoints/openai/rpc/server.py", line 230, in run_rpc_server
    server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
  File "/workspace/vllm/vllm/entrypoints/openai/rpc/server.py", line 31, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 709, in from_engine_args
    engine = cls(
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 600, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 809, in _init_engine
    return engine_class(*args, **kwargs)
  File "/workspace/vllm/vllm/engine/async_llm_engine.py", line 261, in __init__
    super().__init__(*args, **kwargs)
  File "/workspace/vllm/vllm/engine/llm_engine.py", line 288, in __init__
    self.model_executor = executor_class(
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 304, in __init__
    super().__init__(*args, **kwargs)
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 36, in __init__
    super().__init__(*args, **kwargs)
  File "/workspace/vllm/vllm/executor/executor_base.py", line 46, in __init__
    self._init_executor()
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 48, in _init_executor
    self._init_workers_ray(placement_group)
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 175, in _init_workers_ray
    self._run_workers("init_device")
  File "/workspace/vllm/vllm/executor/ray_tpu_executor.py", line 242, in _run_workers
    driver_worker_output = self.driver_worker.execute_method(
  File "/workspace/vllm/vllm/worker/worker_base.py", line 466, in execute_method
    raise e
  File "/workspace/vllm/vllm/worker/worker_base.py", line 457, in execute_method
    return executor(*args, **kwargs)
  File "/workspace/vllm/vllm/worker/tpu_worker.py", line 83, in init_device
    ensure_model_parallel_initialized(
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 965, in ensure_model_parallel_initialized
    initialize_model_parallel(tensor_model_parallel_size,
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 931, in initialize_model_parallel
    _TP = init_model_parallel_group(group_ranks,
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 773, in init_model_parallel_group
    return GroupCoordinator(
  File "/workspace/vllm/vllm/distributed/parallel_state.py", line 175, in __init__
    self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
  File "/workspace/vllm/vllm/distributed/device_communicators/tpu_communicator.py", line 29, in __init__
    local_rank = global_rank % local_world_size
ZeroDivisionError: integer division or modulo by zero

Because of the code here: https://github.com/vllm-project/vllm/blob/tpu-rank/vllm/distributed/device_communicators/tpu_communicator.py#L28

it causes

class TpuCommunicator:
        ....
        # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
        # must be used together. Therefore, the local rank and world size can
        # be simply calculated as follows.
        global_rank = dist.get_rank(group) # <- this global rank wil be 0 on main node
        global_world_size = dist.get_world_size(group) # <- this size should be 32 but it is 4(single node's TPU cores)
        num_nodes = len(ray.nodes()) # <- correct value, 8 for TPUv4-64
        local_world_size = global_world_size // num_nodes # this line cases err since 4 // 8 == 0, causing later line ZeroDivision Error.
        local_rank = global_rank % local_world_size
        pjrt.initialize_multiprocess(local_rank, local_world_size)
        xr._init_world_size_ordinal()

I think pytorch distributed dist.get_world_size(group) can't gather all the TPU cores information from ray cluster, rather it uses local TPU xla device only.

@Beomi
Copy link

Beomi commented Aug 28, 2024

@WoosukKwon I think this problem comes from these lines:

for ranks in group_ranks:
device_group = torch.distributed.new_group(
ranks, backend=torch_distributed_backend)
# a group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
self.device_group = device_group
self.cpu_group = cpu_group

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

I think this codes are supposed to create local cpu group for optmized rank (such as TP in 1 node) but I think this cases cpu_group to have only 4 process(=1 node TPU cores), making Zero Division err to the code above I mentioned.

is there suggested way to workaround this issue?

@DarkLight1337
Copy link
Member

LGTM, please fix the format

#7929 fixes the CI failure related to mypy, please merge from main to resolve it.

@WoosukKwon WoosukKwon merged commit e2b2aa5 into main Sep 2, 2024
19 of 21 checks passed
@WoosukKwon WoosukKwon deleted the tpu-rank branch September 2, 2024 06:09
@WoosukKwon
Copy link
Collaborator Author

@Beomi Thanks for trying out this PR. Currently, vLLM's TPU backend does not support PP. Could you please use smaller TPU pod and retry with the updated main branch?

@Beomi
Copy link

Beomi commented Sep 2, 2024

@WoosukKwon Thanks for clarification! it seems like TP over all nodes works :)
Hope PP comes to TPU as well.
Thank you 😄

@Beomi
Copy link

Beomi commented Sep 2, 2024

@WoosukKwon BTW, there is weird situation related with Multinode situation -- the code works well with TPUv4-8/v4-16/v4-32 but it fails to launch worker on one node if I run on v4-64. I'll open it on the other issue :)

Manikandan-Thangaraj-ZS0321 added a commit to Manikandan-Thangaraj-ZS0321/vllm that referenced this pull request Sep 2, 2024
[TPU] Align worker index with node boundary (vllm-project#7932)
@youkaichao
Copy link
Member

Hope PP comes to TPU as well.

TPU has very good cross-node inter connect. I don't think it needs PP.

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tpu Related to Google TPUs
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[TPU] Make sure worker index aligns with node boundary
4 participants