Skip to content

Commit bb21f76

Browse files
njhillpaulpak58
authored andcommitted
[BugFix] Fix KVConnector TP worker aggregation (vllm-project#21473)
Signed-off-by: Nick Hill <[email protected]> Signed-off-by: Paul Pak <[email protected]>
1 parent 4727865 commit bb21f76

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

vllm/v1/worker/gpu_worker.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
from vllm.distributed import (ensure_model_parallel_initialized,
1717
init_distributed_environment,
1818
set_custom_all_reduce)
19-
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
19+
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
20+
has_kv_transfer_group)
2021
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
2122
from vllm.logger import init_logger
2223
from vllm.lora.request import LoRARequest
@@ -342,19 +343,20 @@ def execute_model(
342343
assert isinstance(output, IntermediateTensors)
343344
get_pp_group().send_tensor_dict(output.tensors,
344345
all_gather_group=get_tp_group())
346+
if not has_kv_transfer_group():
347+
return None
345348

346349
# In case of PP with kv transfer, we need to pass through the
347350
# finished_sending and finished_recving buffers.
348-
empty_output = EMPTY_MODEL_RUNNER_OUTPUT
351+
new_output = EMPTY_MODEL_RUNNER_OUTPUT
349352
if output.finished_sending or output.finished_recving:
350-
empty_output = copy.copy(empty_output)
351-
empty_output.finished_sending = output.finished_sending
352-
empty_output.finished_recving = output.finished_recving
353-
output = empty_output
353+
new_output = copy.copy(new_output)
354+
new_output.finished_sending = output.finished_sending
355+
new_output.finished_recving = output.finished_recving
356+
output = new_output
354357

355358
assert isinstance(output, ModelRunnerOutput)
356-
# return output only from the driver worker
357-
return output if self.is_driver_worker else None
359+
return output
358360

359361
def profile(self, is_start: bool = True):
360362
if self.profiler is None:

0 commit comments

Comments
 (0)