Skip to content

Commit 843e7ca

Browse files
committed
Revert "collect output"
This reverts commit 0868f1e. Signed-off-by: Sophie du Couédic <[email protected]>
1 parent 0868f1e commit 843e7ca

File tree

1 file changed

+8
-28
lines changed

1 file changed

+8
-28
lines changed

tests/scheduling_utils.py

Lines changed: 8 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import copy
2-
from collections import defaultdict, deque
2+
from collections import deque
33
from typing import Any
44

55
import pytest
@@ -44,7 +44,6 @@ def check_scheduler_inference_steps(
4444
max_model_len: int,
4545
available_blocks: int,
4646
use_cb: bool = True,
47-
collect_outputs: bool = False,
4847
):
4948
"""
5049
Test the scheduler execution by comparing the scheduler attributes at each
@@ -87,8 +86,6 @@ def check_scheduler_inference_steps(
8786
"List of checked steps needs to be of increasing order of step")
8887
# ------
8988

90-
collected_outputs = defaultdict(lambda: {"tokens_ids": [], "logprobs": []})
91-
9289
# Setup the engine
9390
engine_args = EngineArgs(model=model,
9491
tokenizer=model,
@@ -111,7 +108,6 @@ def check_scheduler_inference_steps(
111108
# after max_tokens exactly
112109
sampling_params = SamplingParams(max_tokens=max_tokens,
113110
temperature=0.0,
114-
logprobs=0,
115111
ignore_eos=True)
116112
request = create_random_request(request_id=i,
117113
num_tokens=prompt_length,
@@ -186,26 +182,10 @@ def check_scheduler_inference_steps(
186182

187183
# Perform next step
188184
step_output = engine_core.step()
189-
engine_core_output = step_output[0].get(0)
190-
request_outputs = (engine_core_output.outputs
191-
if engine_core_output is not None else [])
192-
193-
if collect_outputs:
194-
for output in request_outputs:
195-
new_token_ids = output.new_token_ids
196-
new_logprobs = output.new_logprobs.logprobs
197-
assert len(new_token_ids) == 1 and len(new_logprobs) == 1
198-
199-
collected_outputs[output.request_id]["tokens_ids"].append(
200-
new_token_ids[0])
201-
collected_outputs[output.request_id]["logprobs"].append(
202-
new_logprobs[0][0])
203-
204-
# Return collected outputs as list
205-
if not collected_outputs:
206-
return []
207-
else:
208-
output_keys = sorted(int(k) for k in collected_outputs)
209-
assert output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1
210-
collected_outputs = [collected_outputs[str(k)] for k in output_keys]
211-
return collected_outputs
185+
# backward compatibility
186+
if isinstance(step_output, tuple):
187+
engine_core_output = step_output[0].get(0)
188+
request_outputs = (engine_core_output.outputs
189+
if engine_core_output is not None else [])
190+
else:
191+
request_outputs = step_output.outputs

0 commit comments

Comments
 (0)