Skip to content

Commit 61a889a

Browse files
pyc96 and shreyankg
authored and committed
[Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due to shape mismatch (vllm-project#14237)
Signed-off-by: pyc96 <[email protected]>
1 parent c261021 commit 61a889a

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

vllm/spec_decode/draft_model_runner.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
133133
def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
134134
"""Determines if draft_model_runner GPU multi-step can be used.
135135
Currently required conditions are:
136-
1. Only decodes
136+
1. Only decodes
137137
2. Only flash-attn
138138
3. No LORA
139139
4. No prompt_adapter_config
@@ -171,12 +171,12 @@ def execute_model(
171171
num_steps: int = 1,
172172
**kwargs,
173173
) -> Optional[List[SamplerOutput]]:
174-
"""Executes num_steps forward passes with advacement of input tensors
174+
"""Executes num_steps forward passes with advacement of input tensors
175175
on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
176176
177177
Optimizations used:
178178
1. Input tensors are updated on the GPU directly
179-
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
179+
2. Skips GPU=>CPU serialization of sampler outputs (we don't need
180180
them since we do batch expansion later that uses GPU outputs)
181181
3. Reuses sampling tensors (since we run only decodes and they have
182182
a repeating sampling logic)
@@ -302,7 +302,12 @@ def execute_model(
302302
outputs.append(output)
303303

304304
if self.return_hidden_states and is_fallback:
305-
output.hidden_states = hidden_states
305+
if use_cuda_graph:
306+
indices = model_input.sampling_metadata\
307+
.selected_token_indices
308+
output.hidden_states = hidden_states[:len(indices)]
309+
else:
310+
output.hidden_states = hidden_states
306311

307312
if model_input.attn_metadata.num_prefills == 0 \
308313
and self.indices_of_seq_with_bonus_tokens is not None:

0 commit comments

Comments (0)