@@ -1009,6 +1009,45 @@ def _process_sequence_group_outputs(
 
         return
 
+    def _update_num_computed_tokens_for_multi_step_prefill(
+            self, seq_group: SequenceGroup,
+            seq_group_meta: SequenceGroupMetadata,
+            is_first_step_output: Optional[bool]):
+        """
+        This function updates num_computed_tokens for prompt sequences
+        when Multi-Step is enabled.
+
+        seq_group: SequenceGroup to update the num_computed_tokens for.
+        seq_group_meta: Metadata of the given SequenceGroup.
+        is_first_step_output: Optional[bool] -
+            When available, is_first_step_output indicates if the appended
+            output token is the output of the first step in multi-step.
+            A value of None indicates that outputs from all steps in
+            multi-step are submitted in a single burst.
+        """
+
+        assert self.scheduler_config.is_multi_step
+
+        if not seq_group_meta.is_prompt:
+            # num_computed_token updates for multi-step decodes happen after
+            # the tokens are appended to the sequence.
+            return
+
+        do_update: bool = False
+        if self.scheduler_config.chunked_prefill_enabled:
+            # In the multi-step + chunked-prefill case, the prompt sequences
+            # that are scheduled are fully processed in the first step.
+            do_update = is_first_step_output is None or is_first_step_output
+        else:
+            # Normal multi-step decoding case. In this case prompt sequences
+            # are actually single-stepped. Always update in this case.
+            assert seq_group.state.num_steps == 1
+            do_update = True
+
+        if do_update:
+            seq_group.update_num_computed_tokens(
+                seq_group_meta.token_chunk_size)
+
     def _process_model_outputs(self,
                                ctx: SchedulerContext,
                                request_id: Optional[str] = None) -> None:
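The helper added above centralizes one decision: whether the scheduled prompt chunk should be credited to num_computed_tokens. A minimal standalone sketch of that rule (a simplified stand-in, not the vLLM method itself; multi-step is assumed enabled and the SequenceGroup plumbing is omitted):

```python
from typing import Optional

def should_credit_prompt_chunk(is_prompt: bool,
                               chunked_prefill_enabled: bool,
                               is_first_step_output: Optional[bool]) -> bool:
    """Sketch of the do_update decision in
    _update_num_computed_tokens_for_multi_step_prefill."""
    if not is_prompt:
        # Decode sequences are credited later, at token-append time.
        return False
    if chunked_prefill_enabled:
        # The scheduled prompt chunk is fully processed in the first step;
        # None means outputs of all steps arrived in a single burst.
        return is_first_step_output is None or is_first_step_output
    # Plain multi-step: prompts are single-stepped, so always credit.
    return True
```

With chunked prefill enabled, for instance, a second-step output (`is_first_step_output=False`) leaves the count untouched, because the chunk was already credited on the first step.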
@@ -1019,64 +1058,6 @@ def _process_model_outputs(self,
         request_id: If provided, then only this request is going to be processed
         """
 
-        def update_prefill_num_computed_tokens(
-                seq_group: SequenceGroup,
-                seq_group_meta: SequenceGroupMetadata, num_outputs: int,
-                is_first_step_output: Optional[bool]) -> None:
-            """
-            When multi-step and chunked-prefill are enabled together, the
-            prefill sequence scheduled for multi-step execution turn into
-            decodes in the first step itself. This function accounts
-            for that conversion.
-
-            seq_group: SequenceGroup - A prefill seq_group
-            seq_group_meta: SequenceGroupMetadata - Metadata of the given
-                prefill seq_group
-            num_outputs: int - number of output tokens being processed for the
-                given seq_group
-            is_first_step_output: Optional[bool] -
-                If multi-step is enabled and num_outputs is 1, this value
-                indicates if this outputs belongs to the first step in the
-                multi-step.
-                If multi-step is enabled and num_outputs > 1, this value
-                must be None, as num_outputs > 1 indicates that outputs from
-                all the steps in multi-step are submitted in a single burst.
-                When multi-step is disabled, this value is always True.
-            """
-
-            assert seq_group_meta.is_prompt
-
-            token_chunk_size = seq_group_meta.token_chunk_size
-
-            if num_outputs == 1:
-                assert is_first_step_output is not None
-
-                if seq_group_meta.state.num_steps == 1:
-                    assert is_first_step_output is True
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                    return
-
-                # multi-step prefill is only supported when multi-step is
-                # enabled with chunked prefill
-                assert self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled
-                if is_first_step_output is True:
-                    # This sequence is a prompt during the first step only.
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                return
-
-            assert is_first_step_output is None
-
-            # multi-step prefill is only supported when multi-step is
-            # enabled with chunked prefill. Outputs from all the steps are
-            # submitted in a single burst.
-            assert self.scheduler_config.is_multi_step and \
-                self.scheduler_config.chunked_prefill_enabled
-            assert num_outputs == seq_group_meta.state.num_steps, \
-                f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}"  # noqa
-            # This sequence is a prompt during the first step only.
-            seq_group.update_num_computed_tokens(token_chunk_size)
-
         now = time.time()
 
         if len(ctx.output_queue) == 0:
@@ -1137,7 +1118,7 @@ def update_prefill_num_computed_tokens(
             seq_group_meta = seq_group_metadata_list[i]
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
 
-            seq_group = scheduled_seq_group.seq_group
+            seq_group: SequenceGroup = scheduled_seq_group.seq_group
 
             if seq_group.is_finished():
                 finished_before.append(i)
@@ -1148,14 +1129,14 @@ def update_prefill_num_computed_tokens(
             else:
                 output = [outputs_by_sequence_group[0][i]]
 
-            if not is_async and seq_group_meta.is_prompt:
-                # Updates for all decodes happen when we actually append the
-                # token ids to the seq in process_outputs.
-                update_prefill_num_computed_tokens(seq_group, seq_group_meta,
-                                                   len(output),
-                                                   is_first_step_output)
-            elif not is_async:
-                seq_group.update_num_computed_tokens(1)
+            if not is_async:
+                if self.scheduler_config.is_multi_step:
+                    # Updates happen only if the sequence is prefill
+                    self._update_num_computed_tokens_for_multi_step_prefill(
+                        seq_group, seq_group_meta, is_first_step_output)
+                else:
+                    seq_group.update_num_computed_tokens(
+                        seq_group_meta.token_chunk_size)
 
             if outputs:
                 for o in outputs:
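In the non-multi-step branch, the per-iteration credit is simply the scheduled chunk size. A worked example of the accounting (illustrative numbers only, not taken from the PR):

```python
# Illustrative: a 10-token prompt prefilled in chunks of 4, 4, and 2.
# Each scheduler iteration credits token_chunk_size, so the computed-token
# count reaches the prompt length exactly when prefill finishes.
prompt_len = 10
num_computed_tokens = 0
for token_chunk_size in (4, 4, 2):      # three prefill iterations
    num_computed_tokens += token_chunk_size
assert num_computed_tokens == prompt_len  # decode iterations then add 1 each
```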
@@ -1179,16 +1160,8 @@ def update_prefill_num_computed_tokens(
             else:
                 self.output_processor.process_prompt_logprob(seq_group, output)
                 if seq_group_meta.do_sample:
-                    output_token_num = self.output_processor.process_outputs(
+                    self.output_processor.process_outputs(
                         seq_group, output, is_async)
-                    if self.speculative_config:
-                        # We -1 here because we always
-                        # (w/o speculative decoding) add the number of
-                        # computed tokens by one in the decoding phase.
-                        # Therefore, we remove that one token that
-                        # is already added.
-                        seq_group.update_num_computed_tokens(output_token_num -
-                                                             1)
 
             if seq_group.is_finished():
                 finished_now.append(i)
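For context on the deleted branch: per its own comment, a speculative-decoding step could produce several output tokens while the regular decode-phase update credits exactly one, so the old code topped up by output_token_num - 1. The arithmetic in miniature (hypothetical numbers):

```python
# If a speculative step emitted 3 accepted tokens, the decode-phase update
# already credited 1, so the removed branch added the remaining 2.
output_token_num = 3       # hypothetical tokens produced this step
already_credited = 1       # added by the normal decode-phase update
assert output_token_num - already_credited == 2
```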
@@ -1297,20 +1270,15 @@ def _advance_to_next_step(
             if seq_group.is_finished():
                 continue
 
-            if seq_group_metadata.is_prompt:
-                if self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled:
-                    # Prompts are scheduled in multi-step only when
-                    # chunking is enabled. These prompts turn into
-                    # decodes after the very first step. Therefore,
-                    # we skip the update to the num_computed_tokens
-                    # here.
-                    seq_group.update_num_computed_tokens(1)
-                else:
-                    seq_group.update_num_computed_tokens(
-                        seq_group_metadata.token_chunk_size)
+            if self.scheduler_config.is_multi_step:
+                # Updates happen only if the sequence is prefill
+                self._update_num_computed_tokens_for_multi_step_prefill(
+                    seq_group, seq_group_metadata,
+                    seq_group.state.num_steps == 1)
             else:
-                seq_group.update_num_computed_tokens(1)
+                seq_group.update_num_computed_tokens(
+                    seq_group_metadata.token_chunk_size)
+
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
@@ -1320,7 +1288,15 @@ def _advance_to_next_step(
 
                 assert len(seq_group.seqs) == 1
                 seq = seq_group.seqs[0]
-                seq.append_token_id(sample.output_token, sample.logprobs)
+
+                if self.scheduler_config.is_multi_step:
+                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
+                    ) == 0
+                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    if not is_prefill_append:
+                        seq_group.update_num_computed_tokens(1)
+                else:
+                    seq.append_token_id(sample.output_token, sample.logprobs)
 
     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.