1
1
import copy
2
- from collections import defaultdict , deque
2
+ from collections import deque
3
3
from typing import Any
4
4
5
5
import pytest
@@ -44,7 +44,6 @@ def check_scheduler_inference_steps(
44
44
max_model_len : int ,
45
45
available_blocks : int ,
46
46
use_cb : bool = True ,
47
- collect_outputs : bool = False ,
48
47
):
49
48
"""
50
49
Test the scheduler execution by comparing the scheduler attributes at each
@@ -87,8 +86,6 @@ def check_scheduler_inference_steps(
87
86
"List of checked steps needs to be of increasing order of step" )
88
87
# ------
89
88
90
- collected_outputs = defaultdict (lambda : {"tokens_ids" : [], "logprobs" : []})
91
-
92
89
# Setup the engine
93
90
engine_args = EngineArgs (model = model ,
94
91
tokenizer = model ,
@@ -111,7 +108,6 @@ def check_scheduler_inference_steps(
111
108
# after max_tokens exactly
112
109
sampling_params = SamplingParams (max_tokens = max_tokens ,
113
110
temperature = 0.0 ,
114
- logprobs = 0 ,
115
111
ignore_eos = True )
116
112
request = create_random_request (request_id = i ,
117
113
num_tokens = prompt_length ,
@@ -186,26 +182,10 @@ def check_scheduler_inference_steps(
186
182
187
183
# Perform next step
188
184
step_output = engine_core .step ()
189
- engine_core_output = step_output [0 ].get (0 )
190
- request_outputs = (engine_core_output .outputs
191
- if engine_core_output is not None else [])
192
-
193
- if collect_outputs :
194
- for output in request_outputs :
195
- new_token_ids = output .new_token_ids
196
- new_logprobs = output .new_logprobs .logprobs
197
- assert len (new_token_ids ) == 1 and len (new_logprobs ) == 1
198
-
199
- collected_outputs [output .request_id ]["tokens_ids" ].append (
200
- new_token_ids [0 ])
201
- collected_outputs [output .request_id ]["logprobs" ].append (
202
- new_logprobs [0 ][0 ])
203
-
204
- # Return collected outputs as list
205
- if not collected_outputs :
206
- return []
207
- else :
208
- output_keys = sorted (int (k ) for k in collected_outputs )
209
- assert output_keys [0 ] == 0 and output_keys [- 1 ] == len (output_keys ) - 1
210
- collected_outputs = [collected_outputs [str (k )] for k in output_keys ]
211
- return collected_outputs
185
+ # backward compatibility: engine_core.step() may return either a tuple or an object exposing .outputs
186
+ if isinstance (step_output , tuple ):
187
+ engine_core_output = step_output [0 ].get (0 )
188
+ request_outputs = (engine_core_output .outputs
189
+ if engine_core_output is not None else [])
190
+ else :
191
+ request_outputs = step_output .outputs
0 commit comments