1
1
import copy
2
- from collections import deque
2
+ from collections import defaultdict , deque
3
3
from typing import Any
4
4
5
5
import pytest
@@ -44,6 +44,7 @@ def check_scheduler_inference_steps(
44
44
max_model_len : int ,
45
45
available_blocks : int ,
46
46
use_cb : bool = True ,
47
+ collect_outputs : bool = False ,
47
48
):
48
49
"""
49
50
Test the scheduler execution by comparing the scheduler attributes at each
@@ -86,6 +87,8 @@ def check_scheduler_inference_steps(
86
87
"List of checked steps needs to be of increasing order of step" )
87
88
# ------
88
89
90
+ collected_outputs = defaultdict (lambda : {"tokens_ids" : [], "logprobs" : []})
91
+
89
92
# Setup the engine
90
93
engine_args = EngineArgs (model = model ,
91
94
tokenizer = model ,
@@ -108,6 +111,7 @@ def check_scheduler_inference_steps(
108
111
# after max_tokens exactly
109
112
sampling_params = SamplingParams (max_tokens = max_tokens ,
110
113
temperature = 0.0 ,
114
+ logprobs = 0 ,
111
115
ignore_eos = True )
112
116
request = create_random_request (request_id = i ,
113
117
num_tokens = prompt_length ,
@@ -182,10 +186,26 @@ def check_scheduler_inference_steps(
182
186
183
187
# Perform next step
184
188
step_output = engine_core .step ()
185
- # backward compatibility
186
- if isinstance (step_output , tuple ):
187
- engine_core_output = step_output [0 ].get (0 )
188
- request_outputs = (engine_core_output .outputs
189
- if engine_core_output is not None else [])
190
- else :
191
- request_outputs = step_output .outputs
189
+ engine_core_output = step_output [0 ].get (0 )
190
+ request_outputs = (engine_core_output .outputs
191
+ if engine_core_output is not None else [])
192
+
193
+ if collect_outputs :
194
+ for output in request_outputs :
195
+ new_token_ids = output .new_token_ids
196
+ new_logprobs = output .new_logprobs .logprobs
197
+ assert len (new_token_ids ) == 1 and len (new_logprobs ) == 1
198
+
199
+ collected_outputs [output .request_id ]["tokens_ids" ].append (
200
+ new_token_ids [0 ])
201
+ collected_outputs [output .request_id ]["logprobs" ].append (
202
+ new_logprobs [0 ][0 ])
203
+
204
+ # Return collected outputs as list
205
+ if not collected_outputs :
206
+ return []
207
+ else :
208
+ output_keys = sorted (int (k ) for k in collected_outputs )
209
+ assert output_keys [0 ] == 0 and output_keys [- 1 ] == len (output_keys ) - 1
210
+ collected_outputs = [collected_outputs [str (k )] for k in output_keys ]
211
+ return collected_outputs
0 commit comments