Skip to content

Commit 0868f1e

Browse files
committed
collect output
Signed-off-by: Sophie du Couédic <[email protected]>
1 parent 6e1f33a commit 0868f1e

File tree

1 file changed

+28
-8
lines changed

1 file changed

+28
-8
lines changed

tests/scheduling_utils.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import copy
2-
from collections import deque
2+
from collections import defaultdict, deque
33
from typing import Any
44

55
import pytest
@@ -44,6 +44,7 @@ def check_scheduler_inference_steps(
4444
max_model_len: int,
4545
available_blocks: int,
4646
use_cb: bool = True,
47+
collect_outputs: bool = False,
4748
):
4849
"""
4950
Test the scheduler execution by comparing the scheduler attributes at each
@@ -86,6 +87,8 @@ def check_scheduler_inference_steps(
8687
"List of checked steps needs to be of increasing order of step")
8788
# ------
8889

90+
collected_outputs = defaultdict(lambda: {"tokens_ids": [], "logprobs": []})
91+
8992
# Setup the engine
9093
engine_args = EngineArgs(model=model,
9194
tokenizer=model,
@@ -108,6 +111,7 @@ def check_scheduler_inference_steps(
108111
# after max_tokens exactly
109112
sampling_params = SamplingParams(max_tokens=max_tokens,
110113
temperature=0.0,
114+
logprobs=0,
111115
ignore_eos=True)
112116
request = create_random_request(request_id=i,
113117
num_tokens=prompt_length,
@@ -182,10 +186,26 @@ def check_scheduler_inference_steps(
182186

183187
# Perform next step
184188
step_output = engine_core.step()
185-
# backward compatibility
186-
if isinstance(step_output, tuple):
187-
engine_core_output = step_output[0].get(0)
188-
request_outputs = (engine_core_output.outputs
189-
if engine_core_output is not None else [])
190-
else:
191-
request_outputs = step_output.outputs
189+
engine_core_output = step_output[0].get(0)
190+
request_outputs = (engine_core_output.outputs
191+
if engine_core_output is not None else [])
192+
193+
if collect_outputs:
194+
for output in request_outputs:
195+
new_token_ids = output.new_token_ids
196+
new_logprobs = output.new_logprobs.logprobs
197+
assert len(new_token_ids) == 1 and len(new_logprobs) == 1
198+
199+
collected_outputs[output.request_id]["tokens_ids"].append(
200+
new_token_ids[0])
201+
collected_outputs[output.request_id]["logprobs"].append(
202+
new_logprobs[0][0])
203+
204+
# Return collected outputs as list
205+
if not collected_outputs:
206+
return []
207+
else:
208+
output_keys = sorted(int(k) for k in collected_outputs)
209+
assert output_keys[0] == 0 and output_keys[-1] == len(output_keys) - 1
210+
collected_outputs = [collected_outputs[str(k)] for k in output_keys]
211+
return collected_outputs

0 commit comments

Comments
 (0)