 
 from vllm.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
+from vllm.executor.gpu_executor import create_worker
 from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                   ResultHandler, WorkerMonitor)
 from vllm.logger import init_logger
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.triton_utils import maybe_set_triton_cache_manager
-from vllm.utils import (cuda_device_count_stateless,
+from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
                         error_on_invalid_device_count_status,
                         get_distributed_init_method, get_open_port,
                         get_vllm_instance_id, make_async,
@@ -26,7 +27,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
 
     def _init_executor(self) -> None:
         # Create the parallel GPU workers.
-        world_size = self.parallel_config.tensor_parallel_size
+        world_size = self.parallel_config.world_size
+        tensor_parallel_size = self.parallel_config.tensor_parallel_size
 
         # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
         if "CUDA_VISIBLE_DEVICES" not in os.environ:
@@ -49,8 +51,15 @@ def _init_executor(self) -> None:
         if world_size > 1:
             maybe_set_triton_cache_manager()
 
-        assert world_size <= cuda_device_count_stateless(), (
-            "please set tensor_parallel_size to less than max local gpu count")
+        cuda_device_count = cuda_device_count_stateless()
+        # Use confusing message for more common TP-only case.
+        assert tensor_parallel_size <= cuda_device_count, (
+            f"please set tensor_parallel_size ({tensor_parallel_size}) "
+            f"to less than max local gpu count ({cuda_device_count})")
+
+        assert world_size <= cuda_device_count, (
+            f"please ensure that world_size ({world_size}) "
+            f"is less than max local gpu count ({cuda_device_count})")
 
         error_on_invalid_device_count_status()
 
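Splitting the old single check into two keeps the error message specific: the first assertion covers the common TP-only misconfiguration, while the second catches setups where each stage fits locally but the combined TP x PP world does not. A hedged numeric illustration on a hypothetical 4-GPU node:

```python
# Hedged illustration of the two failure modes on a hypothetical 4-GPU node.
cuda_device_count = 4

# TP-only misconfiguration: the first assertion fires.
tensor_parallel_size = 8
assert not (tensor_parallel_size <= cuda_device_count)

# TP fits, but TP * PP does not: only the second assertion fires.
tensor_parallel_size, pipeline_parallel_size = 4, 2
world_size = tensor_parallel_size * pipeline_parallel_size  # 8
assert tensor_parallel_size <= cuda_device_count
assert not (world_size <= cuda_device_count)
```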
@@ -60,21 +69,35 @@ def _init_executor(self) -> None:
         distributed_init_method = get_distributed_init_method(
             "127.0.0.1", get_open_port())
 
+        self.workers: List[ProcessWorkerWrapper] = []
+        # This is the list of workers that are rank 0 of each TP group EXCEPT
+        # global rank 0. These are the workers that will broadcast to the
+        # rest of the workers.
+        self.tp_driver_workers: List[ProcessWorkerWrapper] = []
+        # This is the list of workers that are not drivers and not the first
+        # worker in a TP group. These are the workers that will be
+        # broadcasted to.
+        self.non_driver_workers: List[ProcessWorkerWrapper] = []
+
         if world_size == 1:
-            self.workers = []
             self.worker_monitor = None
         else:
             result_handler = ResultHandler()
-            self.workers = [
-                ProcessWorkerWrapper(
+            for rank in range(1, world_size):
+                worker = ProcessWorkerWrapper(
                     result_handler,
                     partial(
-                        self._create_worker,
-                        rank=rank,
-                        local_rank=rank,
-                        distributed_init_method=distributed_init_method,
-                    )) for rank in range(1, world_size)
-            ]
+                        create_worker,
+                        **self._get_create_worker_kwargs(
+                            rank=rank,
+                            local_rank=rank,
+                            distributed_init_method=distributed_init_method,
+                        )))
+                self.workers.append(worker)
+                if rank % tensor_parallel_size == 0:
+                    self.tp_driver_workers.append(worker)
+                else:
+                    self.non_driver_workers.append(worker)
 
             self.worker_monitor = WorkerMonitor(self.workers, result_handler)
             result_handler.start()
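To make the grouping concrete, the sketch below replays the classification above outside the executor for an assumed tensor_parallel_size=2, pipeline_parallel_size=2 layout (global rank 0 stays in-process as the driver worker and appears in neither list):

```python
# Hedged sketch: replays the rank classification above for TP=2, PP=2.
tensor_parallel_size = 2
pipeline_parallel_size = 2
world_size = tensor_parallel_size * pipeline_parallel_size  # 4 ranks total

tp_driver_ranks, non_driver_ranks = [], []
for rank in range(1, world_size):        # rank 0 is the in-process driver
    if rank % tensor_parallel_size == 0:
        tp_driver_ranks.append(rank)     # first rank of a later TP group
    else:
        non_driver_ranks.append(rank)

print(tp_driver_ranks)   # [2]  -> drives the TP group of the second PP stage
print(non_driver_ranks)  # [1, 3]
```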
@@ -136,16 +159,19 @@ def _run_workers(
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        # Start the workers first.
+        if async_run_tensor_parallel_workers_only:
+            # Run only non-driver workers and just return futures.
+            return [
+                worker.execute_method(method, *args, **kwargs)
+                for worker in self.non_driver_workers
+            ]
+
+        # Start all remote workers first.
         worker_outputs = [
             worker.execute_method(method, *args, **kwargs)
             for worker in self.workers
         ]
 
-        if async_run_tensor_parallel_workers_only:
-            # Just return futures
-            return worker_outputs
-
         driver_worker_method = getattr(self.driver_worker, method)
         driver_worker_output = driver_worker_method(*args, **kwargs)
 
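With the early return, `async_run_tensor_parallel_workers_only=True` now launches only the non-driver workers and hands back their futures; the driver worker and the per-stage TP drivers are invoked separately through the async execute-model path below. A hedged sketch of the kind of call site this serves (the real caller lives in `DistributedGPUExecutor` and is not shown in this diff; the wrapper name is hypothetical):

```python
# Hedged sketch of an assumed call site, not the actual vLLM caller.
def start_worker_loops(executor):
    # Kick off the blocking execution loop on non-driver workers only and
    # return their futures; driver-side execution is handled elsewhere.
    return executor._run_workers(
        "start_worker_execution_loop",
        async_run_tensor_parallel_workers_only=True)
```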
@@ -172,16 +198,45 @@ class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.driver_exec_model = make_async(self.driver_worker.execute_model)
+        self.pp_locks: Optional[List[asyncio.Lock]] = None
 
     async def _driver_execute_model_async(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[SamplerOutput]:
-        return await self.driver_exec_model(execute_model_req)
+        if not self.tp_driver_workers:
+            return await self.driver_exec_model(execute_model_req)
+
+        if self.pp_locks is None:
+            # This locks each pipeline parallel stage so multiple virtual
+            # engines can't execute on the same stage at the same time
+            # We create the locks here to avoid creating them in the constructor
+            # which uses a different asyncio loop.
+            self.pp_locks = [
+                asyncio.Lock()
+                for _ in range(self.parallel_config.pipeline_parallel_size)
+            ]
+
+        tasks = [
+            asyncio.create_task(
+                _run_task_with_lock(self.driver_exec_model, self.pp_locks[0],
+                                    execute_model_req))
+        ]
+        for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
+                                                start=1):
+            tasks.append(
+                asyncio.create_task(
+                    _run_task_with_lock(driver_worker.execute_method_async,
+                                        self.pp_locks[pp_rank],
+                                        "execute_model", execute_model_req)))
+        results = await asyncio.gather(*tasks)
+
+        # Only the last PP stage has the final results.
+        return results[-1]
 
     async def _start_worker_execution_loop(self):
         coros = [
             worker.execute_method_async("start_worker_execution_loop")
-            for worker in self.workers
+            for worker in self.non_driver_workers
         ]
         return await asyncio.gather(*coros)
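The newly imported `_run_task_with_lock` helper is not shown in this diff. For the pattern above to work, it only needs to await the given callable while holding the per-stage lock, so multiple virtual engines cannot execute on the same pipeline stage concurrently while different stages still overlap. A minimal sketch of such a helper (an assumption about `vllm.utils`, not the actual implementation):

```python
import asyncio


# Hedged sketch of the assumed helper; the real one lives in vllm.utils.
async def _run_task_with_lock(task, lock: asyncio.Lock, *args, **kwargs):
    # Serialize execution per pipeline stage: hold the stage's lock while
    # awaiting the coroutine returned by `task`.
    async with lock:
        return await task(*args, **kwargs)
```

Each pipeline stage gets its own entry in `pp_locks`, so requests from different virtual engines can overlap across stages but never race on the same stage.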