@@ -342,9 +342,13 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
         )
         return self._cached_decode_metadata
 
-    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
                      sampled_token_ids: Optional[torch.Tensor],
-                     block_size: int, num_seqs: int, num_queries: int):
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
         """
         Update metadata in-place to advance one decode step.
         """
@@ -355,6 +359,23 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
         assert num_seqs > num_queries
         assert self.use_cuda_graph
 
+        if turn_prefills_into_decodes:
+            # When Multi-Step is enabled with Chunked-Prefill, prefills and
+            # decodes are scheduled together. In the first step, all the
+            # prefills turn into decodes. This update reflects that
+            # conversion.
+            assert self.num_decode_tokens + self.num_prefills == num_seqs
+            self.num_decode_tokens += self.num_prefills
+            self.num_prefills = 0
+            self.num_prefill_tokens = 0
+            self.max_prefill_seq_len = 0
+            self.max_query_len = 1
+
+            self.slot_mapping = self.slot_mapping[:num_seqs]
+        else:
+            assert self.seq_lens is not None
+            assert self.max_decode_seq_len == max(self.seq_lens)
+
         assert self.num_prefills == 0
         assert self.num_prefill_tokens == 0
         assert self.num_decode_tokens == num_seqs
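
To make the bookkeeping above concrete, here is a minimal, self-contained sketch with a stand-in dataclass (_Meta is illustrative, not the real FlashAttentionMetadata, and the slot_mapping truncation is omitted). After the first step, every prefill has produced one sampled token and continues as a query-length-1 decode, so the prefill counters are zeroed and those sequences fold into the decode count:

from dataclasses import dataclass

@dataclass
class _Meta:  # illustrative stand-in for FlashAttentionMetadata
    num_prefills: int
    num_prefill_tokens: int
    num_decode_tokens: int
    max_prefill_seq_len: int
    max_query_len: int

def turn_prefills_into_decodes(m: _Meta, num_seqs: int) -> None:
    # From here on, every sequence contributes exactly one decode token
    # per step, so prefills + decode tokens must already equal num_seqs.
    assert m.num_decode_tokens + m.num_prefills == num_seqs
    m.num_decode_tokens += m.num_prefills
    m.num_prefills = 0
    m.num_prefill_tokens = 0
    m.max_prefill_seq_len = 0
    m.max_query_len = 1  # decodes process a single query token per step

m = _Meta(num_prefills=2, num_prefill_tokens=37, num_decode_tokens=3,
          max_prefill_seq_len=20, max_query_len=20)
turn_prefills_into_decodes(m, num_seqs=5)
assert m.num_decode_tokens == 5 and m.num_prefills == 0

This is also why the unconditional max_decode_seq_len check below moves into the else branch: after converting prefills, seq_lens still reflects the mixed batch and the equality no longer holds.
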
@@ -366,7 +387,6 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
         assert self.seq_lens_tensor.shape == (num_seqs, )
         assert self.max_query_len == 1
         assert self.max_prefill_seq_len == 0
-        assert self.max_decode_seq_len == max(self.seq_lens)
 
         assert self.query_start_loc is not None
         assert self.query_start_loc.shape == (num_queries + 1, )
@@ -706,8 +726,10 @@ def forward(
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
         num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+        assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
+            f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}"  # noqa
+        assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \
+            f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}"  # noqa
 
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
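
With the messages attached, a shape mismatch now reports the offending tensor shape and the expected token split instead of a bare AssertionError. A small sketch of the failure mode (the shapes here are made up for illustration):

import torch

num_prefill_tokens, num_decode_tokens = 16, 4
key = torch.zeros(19, 8, 64)  # deliberately one token short of 20
try:
    assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \
        f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}"  # noqa
except AssertionError as e:
    print(e)
    # key : torch.Size([19, 8, 64]) : #prefill tokens 16 : #decode tokens 4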