
Commit 9db54b6

Varun Sundar Rabindranath authored and committed
[Bugfix] Fix incorrect updates to num_computed_tokens in multi-step scheduling (vllm-project#9038)
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Alvant <[email protected]>
1 parent 9c90c1f commit 9db54b6

File tree

6 files changed, +179 -110 lines changed
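
For context: num_computed_tokens tracks how many of a sequence's tokens have had their KV cache entries computed. The invariant exercised by the new test below is that the most recently sampled token is not counted until a later step computes it, so once decoding starts the counter lags one behind prompt_len + num_output_tokens. A minimal standalone sketch of that bookkeeping (illustrative only, not vLLM code; ToySequence and its method names are made up):

# Illustrative model only (not vLLM code): how num_computed_tokens should
# evolve for a single sequence across prompt and decode steps.
class ToySequence:
    def __init__(self, prompt_len: int):
        self.prompt_len = prompt_len
        self.num_computed_tokens = 0
        self.num_output_tokens = 0

    def prompt_step(self) -> None:
        # The whole prompt is computed and one output token is sampled;
        # the sampled token's KV is not computed until the next step.
        self.num_computed_tokens += self.prompt_len
        self.num_output_tokens += 1

    def decode_step(self) -> None:
        # The previously sampled token is computed now; a new one is sampled.
        self.num_computed_tokens += 1
        self.num_output_tokens += 1


seq = ToySequence(prompt_len=10)
seq.prompt_step()
assert seq.num_computed_tokens == 10           # == prompt_len
while seq.num_output_tokens < 5:               # generate 5 tokens in total
    seq.decode_step()
assert seq.num_computed_tokens == 10 + 5 - 1   # prompt_len + num_output_tokens - 1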
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+import pytest
+
+from tests.conftest import VllmRunner
+from tests.core.utils import create_dummy_prompt
+from vllm.engine.llm_engine import LLMEngine
+from vllm.platforms import current_platform
+from vllm.sequence import SequenceGroup
+
+MODEL = "JackFram/llama-160m"
+
+
+def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup):
+    scheduler = engine.scheduler[0]
+    scheduler.add_seq_group(seq_group)
+
+
+@pytest.mark.parametrize("num_scheduler_steps", [1, 8])
+@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+def test_num_computed_tokens_update(num_scheduler_steps: int,
+                                    enable_chunked_prefill: bool,
+                                    enforce_eager: bool):
+
+    is_multi_step = num_scheduler_steps > 1
+    is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill
+
+    if is_multi_step_chunked_prefill and current_platform.is_rocm():
+        pytest.skip("Multi-step with Chunked-Prefill does not support "
+                    "rocm_flash_attn backend")
+
+    # Make a vllm engine
+    runner = VllmRunner(model_name=MODEL,
+                        gpu_memory_utilization=0.7,
+                        use_v2_block_manager=True,
+                        num_scheduler_steps=num_scheduler_steps,
+                        enable_chunked_prefill=enable_chunked_prefill,
+                        enforce_eager=enforce_eager)
+    engine: LLMEngine = runner.model.llm_engine
+
+    # In multi-step + chunked-prefill there is no separate single prompt step.
+    # What is scheduled will run for num_scheduler_steps always.
+    num_prompt_steps = num_scheduler_steps \
+        if is_multi_step_chunked_prefill else 1
+
+    num_output_tokens_list = [4, 8, 12, 15, 16, 17]
+
+    # Create sequence and add to engine
+    prompt_len = 10
+
+    for req_idx, num_output_tokens in enumerate(num_output_tokens_list):
+        seq, seq_group = create_dummy_prompt(request_id=str(req_idx),
+                                             prompt_length=prompt_len,
+                                             min_tokens=num_output_tokens,
+                                             max_tokens=num_output_tokens)
+        add_seq_group_to_engine(engine, seq_group)
+
+        assert seq.data.get_num_computed_tokens() == 0
+
+        for _ in range(num_prompt_steps):
+            # prompt steps
+            engine.step()
+
+        if not seq.is_finished():
+            prompt_num_computed_tokens = seq.data.get_num_computed_tokens()
+            # Test correctness of num_computed_tokens after the prompt steps
+            assert prompt_num_computed_tokens == \
+                prompt_len + num_prompt_steps - 1
+
+            decode_step_counter = 0
+            while not seq.is_finished():
+                # Test correctness of num_computed_tokens after the decode steps
+                assert seq.data.get_num_computed_tokens(
+                ) == prompt_num_computed_tokens + decode_step_counter
+                for _ in range(num_scheduler_steps):
+                    # decode step
+                    engine.step()
+                decode_step_counter += 1
+
+        # Test correctness of num_computed_tokens after the sequence finish.
+        assert seq.data.get_num_computed_tokens(
+        ) == prompt_len + num_output_tokens - 1
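
To make the two assertions concrete, here is the arithmetic for one of the parameter combinations above (prompt_len=10, num_output_tokens=12, num_scheduler_steps=8 with chunked prefill enabled, so num_prompt_steps == 8); the trailing "- 1" reflects that the most recently sampled token has not had its KV computed yet:

# Worked example of the expected values checked by the test above.
prompt_len = 10
num_prompt_steps = 8      # multi-step + chunked-prefill: all 8 steps run
num_output_tokens = 12

after_prompt_steps = prompt_len + num_prompt_steps - 1   # 10 + 8 - 1 == 17
after_finish = prompt_len + num_output_tokens - 1        # 10 + 12 - 1 == 21
assert (after_prompt_steps, after_finish) == (17, 21)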

tests/core/utils.py

Lines changed: 5 additions & 1 deletion
@@ -16,6 +16,8 @@ def create_dummy_prompt(
     use_beam_search: bool = False,
     best_of: int = 1,
     prompt_tokens: Optional[List[int]] = None,
+    min_tokens: int = 0,
+    max_tokens: int = 16,
 ) -> Tuple[Sequence, SequenceGroup]:
     if not block_size:
         block_size = prompt_length
@@ -36,7 +38,9 @@ def create_dummy_prompt(
                               arrival_time=time.time(),
                               sampling_params=SamplingParams(
                                   use_beam_search=use_beam_search,
-                                  best_of=best_of),
+                                  best_of=best_of,
+                                  max_tokens=max_tokens,
+                                  min_tokens=min_tokens),
                               lora_request=lora_request)
 
     return prompt, seq_group
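
Setting min_tokens equal to max_tokens pins the number of generated tokens exactly, which is how the test above controls sequence length. A minimal usage sketch of the extended helper:

from tests.core.utils import create_dummy_prompt

# Force exactly 8 output tokens by making min_tokens == max_tokens.
seq, seq_group = create_dummy_prompt(request_id="0",
                                     prompt_length=10,
                                     min_tokens=8,
                                     max_tokens=8)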

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 12 additions & 2 deletions
@@ -191,12 +191,22 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
         )
         return self._cached_decode_metadata
 
-    def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
+    def advance_step(self,
+                     model_input: "ModelInputForGPUWithSamplingMetadata",
                      sampled_token_ids: Optional[torch.Tensor],
-                     block_size: int, num_seqs: int, num_queries: int):
+                     block_size: int,
+                     num_seqs: int,
+                     num_queries: int,
+                     turn_prefills_into_decodes: bool = False):
         """
         Update metadata in-place to advance one decode step.
         """
+
+        assert not turn_prefills_into_decodes, \
+            ("Chunked prefill is not supported with rocm_flash_attn yet. "
+             "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
+             "specific parameter.")
+
         # When using cudagraph, the num_seqs is padded to the next captured
         # batch sized, but num_queries tracks the actual number of requests in
         # the batch. For --enforce-eager mode, num_seqs == num_queries
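
The ROCm backend accepts the new turn_prefills_into_decodes keyword, presumably to keep advance_step signatures uniform across attention backends, but asserts it is unset because Multi-Step + Chunked-Prefill is not supported there yet. A standalone sketch of that guard pattern (advance_step_stub is a made-up name, not vLLM code):

# Standalone sketch: accept the keyword for signature compatibility, but
# fail fast until the feature is actually supported by this backend.
def advance_step_stub(num_seqs: int,
                      num_queries: int,
                      turn_prefills_into_decodes: bool = False) -> None:
    assert not turn_prefills_into_decodes, (
        "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill "
        "specific parameter and is not supported by this backend yet.")
    # ... advance metadata for a single decode step here ...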

vllm/engine/llm_engine.py

Lines changed: 66 additions & 90 deletions
@@ -1009,6 +1009,45 @@ def _process_sequence_group_outputs(
 
         return
 
+    def _update_num_computed_tokens_for_multi_step_prefill(
+            self, seq_group: SequenceGroup,
+            seq_group_meta: SequenceGroupMetadata,
+            is_first_step_output: Optional[bool]):
+        """
+        This function updates num_computed_tokens for prompt sequences
+        when Multi-Step is enabled.
+
+        seq_group: SequenceGroup to update the num_computed_tokens for.
+        seq_group_meta: Metadata of the given SequenceGroup.
+        is_first_step_output: Optional[bool] -
+            When available, is_first_step_output indicates if the appended
+            output token is the output of the first-step in multi-step.
+            A value of None indicates that outputs from all steps in
+            multi-step are submitted in a single burst.
+        """
+
+        assert self.scheduler_config.is_multi_step
+
+        if not seq_group_meta.is_prompt:
+            # num_computed_token updates for multi-step decodes happen after
+            # the tokens are appended to the sequence.
+            return
+
+        do_update: bool = False
+        if self.scheduler_config.chunked_prefill_enabled:
+            # In multi-step + chunked-prefill case, the prompt sequences
+            # that are scheduled are fully processed in the first step.
+            do_update = is_first_step_output is None or is_first_step_output
+        else:
+            # Normal multi-step decoding case. In this case prompt-sequences
+            # are actually single-stepped. Always update in this case.
+            assert seq_group.state.num_steps == 1
+            do_update = True
+
+        if do_update:
+            seq_group.update_num_computed_tokens(
+                seq_group_meta.token_chunk_size)
+
     def _process_model_outputs(self,
                                ctx: SchedulerContext,
                                request_id: Optional[str] = None) -> None:
@@ -1019,64 +1058,6 @@ def _process_model_outputs(self,
         request_id: If provided, then only this request is going to be processed
         """
 
-        def update_prefill_num_computed_tokens(
-                seq_group: SequenceGroup,
-                seq_group_meta: SequenceGroupMetadata, num_outputs: int,
-                is_first_step_output: Optional[bool]) -> None:
-            """
-            When multi-step and chunked-prefill are enabled together, the
-            prefill sequence scheduled for multi-step execution turn into
-            decodes in the first step itself. This function accounts
-            for that conversion.
-
-            seq_group: SequenceGroup - A prefill seq_group
-            seq_group_meta: SequenceGroupMetadata - Metadata of the given
-                prefill seq_group
-            num_outputs: int - number of output tokens being processed for the
-                given seq_group
-            is_first_step_output: Optional[bool] -
-                If multi-step is enabled and num_outputs is 1, this value
-                indicates if this outputs belongs to the first step in the
-                multi-step.
-                If multi-step is enabled and num_outputs > 1, this value
-                must be None, as num_outputs > 1 indicates that outputs from
-                all the steps in multi-step are submitted in a single burst.
-                When multi-step is disabled, this value is always True.
-            """
-
-            assert seq_group_meta.is_prompt
-
-            token_chunk_size = seq_group_meta.token_chunk_size
-
-            if num_outputs == 1:
-                assert is_first_step_output is not None
-
-                if seq_group_meta.state.num_steps == 1:
-                    assert is_first_step_output is True
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                    return
-
-                # multi-step prefill is only supported when multi-step is
-                # enabled with chunked prefill
-                assert self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled
-                if is_first_step_output is True:
-                    # This sequence is a prompt during the first step only.
-                    seq_group.update_num_computed_tokens(token_chunk_size)
-                return
-
-            assert is_first_step_output is None
-
-            # multi-step prefill is only supported when multi-step is
-            # enabled with chunked prefill. Outputs from all the steps are
-            # submitted in a single burst.
-            assert self.scheduler_config.is_multi_step and \
-                self.scheduler_config.chunked_prefill_enabled
-            assert num_outputs == seq_group_meta.state.num_steps, \
-                f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}"  # noqa
-            # This sequence is a prompt during the first step only.
-            seq_group.update_num_computed_tokens(token_chunk_size)
-
         now = time.time()
 
         if len(ctx.output_queue) == 0:
@@ -1137,7 +1118,7 @@ def update_prefill_num_computed_tokens(
             seq_group_meta = seq_group_metadata_list[i]
             scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]
 
-            seq_group = scheduled_seq_group.seq_group
+            seq_group: SequenceGroup = scheduled_seq_group.seq_group
 
             if seq_group.is_finished():
                 finished_before.append(i)
@@ -1148,14 +1129,14 @@ def update_prefill_num_computed_tokens(
             else:
                 output = [outputs_by_sequence_group[0][i]]
 
-            if not is_async and seq_group_meta.is_prompt:
-                # Updates for all decodes happen when we actually append the
-                # token ids to the seq in process_outputs.
-                update_prefill_num_computed_tokens(seq_group, seq_group_meta,
-                                                   len(output),
-                                                   is_first_step_output)
-            elif not is_async:
-                seq_group.update_num_computed_tokens(1)
+            if not is_async:
+                if self.scheduler_config.is_multi_step:
+                    # Updates happen only if the sequence is prefill
+                    self._update_num_computed_tokens_for_multi_step_prefill(
+                        seq_group, seq_group_meta, is_first_step_output)
+                else:
+                    seq_group.update_num_computed_tokens(
+                        seq_group_meta.token_chunk_size)
 
             if outputs:
                 for o in outputs:
@@ -1179,16 +1160,8 @@ def update_prefill_num_computed_tokens(
             else:
                 self.output_processor.process_prompt_logprob(seq_group, output)
                 if seq_group_meta.do_sample:
-                    output_token_num = self.output_processor.process_outputs(
+                    self.output_processor.process_outputs(
                         seq_group, output, is_async)
-                    if self.speculative_config:
-                        # We -1 here because we always
-                        # (w/o speculative decoding) add the number of
-                        # computed tokens by one in the decoding phase.
-                        # Therefore, we remove that one token that
-                        # is already added.
-                        seq_group.update_num_computed_tokens(output_token_num -
-                                                             1)
 
             if seq_group.is_finished():
                 finished_now.append(i)
@@ -1297,20 +1270,15 @@ def _advance_to_next_step(
             if seq_group.is_finished():
                 continue
 
-            if seq_group_metadata.is_prompt:
-                if self.scheduler_config.is_multi_step and \
-                    self.scheduler_config.chunked_prefill_enabled:
-                    # Prompts are scheduled in multi-step only when
-                    # chunking is enabled. These prompts turn into
-                    # decodes after the very first step. Therefore,
-                    # we skip the update to the num_computed_tokens
-                    # here.
-                    seq_group.update_num_computed_tokens(1)
-                else:
-                    seq_group.update_num_computed_tokens(
-                        seq_group_metadata.token_chunk_size)
+            if self.scheduler_config.is_multi_step:
+                # Updates happen only if the sequence is prefill
+                self._update_num_computed_tokens_for_multi_step_prefill(
+                    seq_group, seq_group_metadata,
+                    seq_group.state.num_steps == 1)
             else:
-                seq_group.update_num_computed_tokens(1)
+                seq_group.update_num_computed_tokens(
+                    seq_group_metadata.token_chunk_size)
+
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
@@ -1320,7 +1288,15 @@ def _advance_to_next_step(
 
                 assert len(seq_group.seqs) == 1
                 seq = seq_group.seqs[0]
-                seq.append_token_id(sample.output_token, sample.logprobs)
+
+                if self.scheduler_config.is_multi_step:
+                    is_prefill_append = seq.data.get_num_uncomputed_tokens(
+                    ) == 0
+                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    if not is_prefill_append:
+                        seq_group.update_num_computed_tokens(1)
+                else:
+                    seq.append_token_id(sample.output_token, sample.logprobs)
 
     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.

vllm/engine/output_processor/interfaces.py

Lines changed: 2 additions & 6 deletions
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, List, Optional
+from typing import Callable, List
 
 from vllm.config import SchedulerConfig
 from vllm.core.scheduler import Scheduler
@@ -58,14 +58,10 @@ def create_output_processor(
     @abstractmethod
     def process_outputs(self, sequence_group: SequenceGroup,
                         outputs: List[SequenceGroupOutput],
-                        is_async: bool) -> Optional[int]:
+                        is_async: bool) -> None:
         """Process new token ids for the sequence group. Handles logic such as
         detokenization, stop checking, and freeing/forking sequences in the
         scheduler.
-
-        Return the number of new tokens generated in the sequence group.
-        The returned value is optional because it is only used for
-        speculative decoding mqa scorer.
         """
         pass