@@ -133,7 +133,7 @@ def _gpu_advance_step(self, model_input: ModelRunnerInputBase,
     def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
         """Determines if draft_model_runner GPU multi-step can be used.
         Currently required conditions are:
-            1. Only decodes
+            1. Only decodes
             2. Only flash-attn
             3. No LORA
             4. No prompt_adapter_config
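For intuition, here is a minimal standalone sketch of the gating logic this docstring describes — not vLLM's actual implementation. The `FakeSeqGroup`/`FakeRequest` stand-ins and the `backend_name`, `lora_config`, and `prompt_adapter_config` parameters are illustrative assumptions; only the four conditions come from the docstring above.

```python
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class FakeSeqGroup:
    is_prompt: bool  # True while the sequence is still in its prompt phase

@dataclass
class FakeRequest:
    seq_group_metadata_list: List[FakeSeqGroup] = field(default_factory=list)

def supports_gpu_multi_step(req: FakeRequest, backend_name: str,
                            lora_config: Optional[object],
                            prompt_adapter_config: Optional[object]) -> bool:
    # 1. Only decodes: no sequence group may still be processing its prompt.
    if any(sg.is_prompt for sg in req.seq_group_metadata_list):
        return False
    # 2. Only flash-attn: GPU-side input advancement is backend-specific.
    if backend_name != "FLASH_ATTN":
        return False
    # 3. No LORA and 4. no prompt_adapter_config.
    return lora_config is None and prompt_adapter_config is None

# A pure-decode flash-attn batch with no adapters qualifies.
assert supports_gpu_multi_step(
    FakeRequest([FakeSeqGroup(is_prompt=False)]), "FLASH_ATTN", None, None)
```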
@@ -171,12 +171,12 @@ def execute_model(
         num_steps: int = 1,
         **kwargs,
     ) -> Optional[List[SamplerOutput]]:
-        """Executes num_steps forward passes with advancement of input tensors
+        """Executes num_steps forward passes with advancement of input tensors
         on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
 
         Optimizations used:
             1. Input tensors are updated on the GPU directly
-            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
+            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
                them since we do batch expansion later that uses GPU outputs)
             3. Reuses sampling tensors (since we run only decodes and they have
                a repeating sampling logic)
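To make optimizations (1) and (2) concrete, below is a toy sketch of a multi-step decode loop that keeps all tensors on the device between steps — a simplified illustration, not the real `execute_model`. The `model` callable, the greedy argmax in place of real sampling, and the tensor shapes are all assumptions; optimization (3), reusing sampling tensors, is elided because the sketch does not build a sampler.

```python
import torch

def multi_step_decode(model, input_tokens: torch.Tensor,
                      positions: torch.Tensor, num_steps: int = 1):
    """Run num_steps forward passes, advancing inputs without touching the CPU."""
    outputs = []
    for _ in range(num_steps):
        logits = model(input_tokens, positions)  # forward pass
        next_tokens = logits.argmax(dim=-1)      # stand-in for real sampling
        outputs.append(next_tokens)              # (2) stays on device: no .cpu()
        input_tokens = next_tokens               # (1) inputs updated directly
        positions = positions + 1                #     on the device, step to step
    return outputs

# Toy usage: a "model" whose logits always favor token id 7.
toy = lambda toks, pos: torch.nn.functional.one_hot(
    torch.full_like(toks, 7), num_classes=16).float()
steps = multi_step_decode(toy, torch.zeros(2, dtype=torch.long),
                          torch.zeros(2, dtype=torch.long), num_steps=3)
assert all(int(t[0]) == 7 for t in steps)
```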
@@ -302,7 +302,12 @@ def execute_model(
                 outputs.append(output)
 
                 if self.return_hidden_states and is_fallback:
-                    output.hidden_states = hidden_states
+                    if use_cuda_graph:
+                        indices = model_input.sampling_metadata \
+                            .selected_token_indices
+                        output.hidden_states = hidden_states[:len(indices)]
+                    else:
+                        output.hidden_states = hidden_states
 
             if model_input.attn_metadata.num_prefills == 0 \
                     and self.indices_of_seq_with_bonus_tokens is not None:
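The reason for the new slice, as the diff suggests: with CUDA graphs the forward pass runs at a fixed captured batch size, so `hidden_states` can carry padded rows, while `sampling_metadata.selected_token_indices` has one entry per real token. A toy illustration of that truncation (the sizes here are made up):

```python
import torch

# Graph captured for a padded batch of 8, but only 3 sequences are real.
padded_batch, hidden_dim = 8, 4
hidden_states = torch.randn(padded_batch, hidden_dim)
selected_token_indices = torch.tensor([0, 1, 2])  # one index per real token

# Truncate to the number of real tokens, dropping the padded rows.
real_hidden = hidden_states[:len(selected_token_indices)]
assert real_hidden.shape == (3, hidden_dim)
```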