
Commit f0b3ead

[CB][Tests] Check output of scheduling tests on Spyre (#337)
This change adds logic to check the output of the scheduling-steps tests. The output checking is done only on Spyre, to save compute for the CPU tests.

[Important note] The output comparison checks for the scheduling tests currently pass neither on CPU nor on Spyre. The exact reasons are unclear; on CPU it may be due to the high entropy of the randomly generated prompt tokens.

---------

Signed-off-by: Sophie du Couédic <[email protected]>
Signed-off-by: Prashant Gupta <[email protected]>
Co-authored-by: Prashant Gupta <[email protected]>
1 parent aecd732 commit f0b3ead

File tree

7 files changed, +275 -146 lines changed

tests/conftest.py

Lines changed: 11 additions & 0 deletions
@@ -13,6 +13,9 @@
     del os.environ["VLLM_USE_V1"]
 # 🌶️🌶️🌶️ end hack
 
+import hashlib
+import random
+
 import pytest
 import torch
 from spyre_util import RemoteOpenAIServer, skip_unsupported_tp_size
@@ -139,3 +142,11 @@ def remote_openai_server(request):
         yield server
     except Exception as e:
         pytest.fail(f"Failed to setup server: {e}")
+
+
+@pytest.fixture
+def set_random_seed(request):
+    func_hash = hashlib.sha256(request.node.originalname.encode('utf-8'))
+    seed = int(func_hash.hexdigest(), 16)
+    random.seed(seed)
+    yield
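
The fixture hashes the test's original name to derive its seed, so every test gets a stable but distinct random stream across runs; Python's random.seed() accepts arbitrarily large integers, so the full 256-bit digest can be passed directly. A minimal sketch of how a test opts in (the test below is illustrative, not part of this commit):

import random

# Illustrative test, not from this diff: naming the fixture as a parameter
# is enough to activate it, so the body runs against a seeded global RNG.
def test_uses_deterministic_randomness(set_random_seed):
    # Reproducible across runs: the fixture seeded random() with
    # sha256("test_uses_deterministic_randomness") before this body ran.
    token_ids = [random.randrange(32000) for _ in range(8)]
    assert len(token_ids) == 8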

tests/e2e/test_spyre_basic.py

Lines changed: 16 additions & 39 deletions
@@ -4,10 +4,10 @@
 """
 
 import pytest
-from spyre_util import (compare_results, create_random_request,
-                        generate_hf_output, generate_spyre_vllm_output,
-                        get_chicken_soup_prompts, get_spyre_backend_list,
-                        get_spyre_model_list, skip_unsupported_tp_size)
+from spyre_util import (check_output_against_hf, create_random_request,
+                        generate_spyre_vllm_output, get_chicken_soup_prompts,
+                        get_spyre_backend_list, get_spyre_model_list,
+                        skip_unsupported_tp_size)
 from vllm import EngineArgs, SamplingParams
 from vllm.v1.engine.core import EngineCore
 from vllm.v1.executor.abstract import Executor
@@ -85,17 +85,8 @@ def test_output(
         backend=backend,
         monkeypatch=monkeypatch,
         **kwargs)
-
-    hf_results = generate_hf_output(model=model,
-                                    prompts=prompts,
-                                    max_new_tokens=max_new_tokens)
-
-    compare_results(model=model,
-                    prompts=prompts,
-                    tensor_parallel_size=tp_size,
-                    backend=backend,
-                    vllm_results=vllm_results,
-                    hf_results=hf_results)
+    check_output_against_hf(model, backend, max_new_tokens, vllm_results,
+                            prompts)
 
 
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -137,16 +128,8 @@ def test_output_sendnn_decoder(
         backend=backend,
         monkeypatch=monkeypatch)
 
-    hf_results = generate_hf_output(model=model,
-                                    prompts=prompts,
-                                    max_new_tokens=max_new_tokens)
-
-    compare_results(model=model,
-                    prompts=prompts,
-                    tensor_parallel_size=1,
-                    backend=backend,
-                    vllm_results=vllm_results,
-                    hf_results=hf_results)
+    check_output_against_hf(model, backend, max_new_tokens, vllm_results,
+                            prompts)
 
 
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -194,18 +177,9 @@ def test_batch_handling(model: str, backend: str, cb: int,
         backend=backend,
         monkeypatch=monkeypatch,
         **kwargs)
-    hf_results = generate_hf_output(model=model,
-                                    prompts=prompts,
-                                    max_new_tokens=max_new_tokens)
 
-    compare_results(
-        model=model,
-        prompts=prompts,
-        tensor_parallel_size=1,
-        backend=backend,
-        vllm_results=vllm_results,
-        hf_results=hf_results,
-    )
+    check_output_against_hf(model, backend, max_new_tokens, vllm_results,
+                            prompts)
 
 
 @pytest.mark.parametrize("model", get_spyre_model_list())
@@ -251,9 +225,12 @@ def test_full_batch_scheduling(model: str, backend: str, monkeypatch):
         logprobs=0)
     for i in range(batch_size):
         engine_core.add_request(
-            create_random_request(request_id=i,
-                                  num_tokens=max_batched_tokens,
-                                  sampling_params=vllm_sampling_params))
+            create_random_request(
+                request_id=i,
+                num_tokens=max_batched_tokens,
+                sampling_params=vllm_sampling_params,
+                model=model,
+            ))
     schedule = scheduler.schedule()
 
     assert len(schedule.scheduled_new_reqs) == batch_size
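
Each pair of generate_hf_output/compare_results calls above collapses into one check_output_against_hf call. The helper itself lives in spyre_util and is not shown in this diff; judging from the call sites it replaces, a plausible sketch looks like the following, where the backend gating and the tensor_parallel_size default are assumptions (the commit message only states that the comparison runs on Spyre alone):

from spyre_util import compare_results, generate_hf_output

# Hypothetical reconstruction of the new helper, inferred from the call
# sites it replaces; the real implementation in spyre_util may differ.
def check_output_against_hf(model, backend, max_new_tokens, vllm_results,
                            prompts):
    # Assumption: skip the costly HF comparison on non-Spyre backends,
    # per the commit message; the exact backend check may differ.
    if "sendnn" not in backend:
        return
    hf_results = generate_hf_output(model=model,
                                    prompts=prompts,
                                    max_new_tokens=max_new_tokens)
    compare_results(model=model,
                    prompts=prompts,
                    tensor_parallel_size=1,  # assumed default; no tp argument
                    backend=backend,
                    vllm_results=vllm_results,
                    hf_results=hf_results)

Note that the scheduler-steps tests below pass seqs_max_tokens (a list) where these tests pass a scalar max_new_tokens, so the real helper presumably accepts both.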

tests/e2e/test_spyre_cb_scheduler_steps.py

Lines changed: 52 additions & 22 deletions
@@ -8,14 +8,16 @@
 
 import pytest
 from scheduling_utils import check_scheduler_inference_steps
-from spyre_util import get_spyre_backend_list, get_spyre_model_list
+from spyre_util import (check_output_against_hf, get_spyre_backend_list,
+                        get_spyre_model_list)
 
 
 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
-                                             monkeypatch: pytest.MonkeyPatch):
+                                             monkeypatch: pytest.MonkeyPatch,
+                                             set_random_seed: None):
     """ Scenario where it happens that all the sequences get scheduled in a
     fashion where they are aligned with the block boundaries (i.e. tkv multiple
     of 64 at the time of prefilling).
@@ -162,7 +164,7 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
         },
     ]
 
-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -176,12 +178,16 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
         use_cb=True,
     )
 
+    check_output_against_hf(model, backend, seqs_max_tokens, cb_outputs,
+                            prompts)
+
 
 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 def test_prompts_misaligned_with_tkv_boundaries(
-        model: str, backend: str, monkeypatch: pytest.MonkeyPatch):
+        model: str, backend: str, monkeypatch: pytest.MonkeyPatch,
+        set_random_seed: None):
     """ Scenario where it happens that some sequence gets scheduled in a way
     that it is misaligned with the block boundary (i.e. tkv is not a multiple
     of 64 at the time of prefilling).
@@ -193,7 +199,6 @@ def test_prompts_misaligned_with_tkv_boundaries(
     * 2: len = 41, max tokens = 67, step joining = 0
     * 3: len = 47, max tokens = 9, step joining = 0
     """
-
     seqs_max_tokens = [57, 67, 9]
     prompts_lengths = [49, 41, 47]
     steps_add_reqs = [0, 0, 0]  # add all requests in the beginning
@@ -326,7 +331,7 @@ def test_prompts_misaligned_with_tkv_boundaries(
         },
     ]
 
-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -340,12 +345,16 @@ def test_prompts_misaligned_with_tkv_boundaries(
         use_cb=True,
     )
 
+    check_output_against_hf(model, backend, seqs_max_tokens, cb_outputs,
+                            prompts)
+
 
 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 def test_two_sequences_finish_same_time_as_new_arrive(
-        model: str, backend: str, monkeypatch: pytest.MonkeyPatch):
+        model: str, backend: str, monkeypatch: pytest.MonkeyPatch,
+        set_random_seed):
     """ 2-cases-in-1: (1) Two sequences finish at the same time and (2) a new
     request arrives when another finishes.
 
@@ -356,7 +365,6 @@ def test_two_sequences_finish_same_time_as_new_arrive(
     * 2: len = 30, max tokens = 30, step joining = 0
     * 3: len = 20, max tokens = 10, step joining = 31
     """
-
     seqs_max_tokens = [30, 30, 10]
     prompts_lengths = [49, 30, 20]
     steps_add_reqs = [0, 0, 31]
@@ -466,7 +474,7 @@ def test_two_sequences_finish_same_time_as_new_arrive(
         },
     ]
 
-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -480,12 +488,16 @@ def test_two_sequences_finish_same_time_as_new_arrive(
         use_cb=True,
     )
 
+    check_output_against_hf(model, backend, seqs_max_tokens, cb_outputs,
+                            prompts)
+
 
 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 def test_new_sequence_joins_during_decode(model: str, backend: str,
-                                          monkeypatch: pytest.MonkeyPatch):
+                                          monkeypatch: pytest.MonkeyPatch,
+                                          set_random_seed):
     """ Scenario where a new sequence joins while decoding other sequences
 
     Configuration:
@@ -731,7 +743,7 @@ def test_new_sequence_joins_during_decode(model: str, backend: str,
         # },
     ]
 
-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -745,12 +757,16 @@ def test_new_sequence_joins_during_decode(model: str, backend: str,
         use_cb=True,
     )
 
+    check_output_against_hf(model, backend, seqs_max_tokens, cb_outputs,
+                            prompts)
+
 
 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 def test_prompt_too_long_for_current_tkv(model: str, backend: str,
-                                         monkeypatch: pytest.MonkeyPatch):
+                                         monkeypatch: pytest.MonkeyPatch,
+                                         set_random_seed):
     """ Scenario where the requested prompt is too long for current tkv value
 
     Configuration:
@@ -880,7 +896,7 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
         },
     ]
 
-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -894,13 +910,18 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
         use_cb=True,
     )
 
+    check_output_against_hf(model, backend, seqs_max_tokens, cb_outputs,
+                            prompts)
+
 
 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 def test_requested_tokens_not_fitting_remaining_space(
-        model: str, backend: str, monkeypatch: pytest.MonkeyPatch):
-    """ Scenario where the request goes beyond max_model_len
+        model: str, backend: str, monkeypatch: pytest.MonkeyPatch,
+        set_random_seed):
+    """ Scenario where the request goes beyond max_model_len and needs to wait
+    for a new batch.
 
     Configuration:
     * max_num_seqs: 2
@@ -909,7 +930,6 @@ def test_requested_tokens_not_fitting_remaining_space(
     * 2: len = 49, max tokens = 57, step joining = 0
     * 3: len = 41, max tokens = 80, step joining = 0
     """
-
     seqs_max_tokens = [67, 57, 80]
     prompts_lengths = [70, 49, 41]
     steps_add_reqs = [0, 0, 0]
@@ -1067,7 +1087,7 @@ def test_requested_tokens_not_fitting_remaining_space(
         },
     ]
 
-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -1081,12 +1101,16 @@ def test_requested_tokens_not_fitting_remaining_space(
         use_cb=True,
     )
 
+    check_output_against_hf(model, backend, seqs_max_tokens, cb_outputs,
+                            prompts)
+
 
 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 def test_requests_use_all_available_blocks(model: str, backend: str,
-                                           monkeypatch: pytest.MonkeyPatch):
+                                           monkeypatch: pytest.MonkeyPatch,
+                                           set_random_seed):
     """ Scenario where the requests use all of the available blocks
 
     Configuration:
@@ -1098,7 +1122,6 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
     * 4: len = 10, max tokens = 3, step joining = 0
     * available_blocks: 8
     """
-
     seqs_max_tokens = [3, 3, 3, 3]  # 2 decodes into a new block per sequence
     prompts_lengths = [10, 10, 10, 10]  # 1 block for prefil per sequence
     steps_add_reqs = [0, 0, 0, 0]
@@ -1201,7 +1224,7 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
         },
     ]
 
-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -1215,12 +1238,16 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
         use_cb=True,
     )
 
+    check_output_against_hf(model, backend, seqs_max_tokens, cb_outputs,
+                            prompts)
+
 
 @pytest.mark.cb
 @pytest.mark.parametrize("model", get_spyre_model_list())
 @pytest.mark.parametrize("backend", get_spyre_backend_list())
 def test_requests_use_more_than_available_blocks(
-        model: str, backend: str, monkeypatch: pytest.MonkeyPatch):
+        model: str, backend: str, monkeypatch: pytest.MonkeyPatch,
+        set_random_seed):
     """ Scenario where some request need to wait because of the number of
     available blocks.
 
@@ -1361,7 +1388,7 @@ def test_requests_use_more_than_available_blocks(
         },
     ]
 
-    check_scheduler_inference_steps(
+    cb_outputs, prompts = check_scheduler_inference_steps(
         model=model,
         backend=backend,
         monkeypatch=monkeypatch,
@@ -1374,3 +1401,6 @@ def test_requests_use_more_than_available_blocks(
         available_blocks=available_blocks,
         use_cb=True,
     )
+
+    check_output_against_hf(model, backend, seqs_max_tokens, cb_outputs,
+                            prompts)
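
All eight tests above now close with the same pattern: the set_random_seed fixture makes the randomly generated prompts reproducible, check_scheduler_inference_steps returns the continuous-batching outputs together with the prompts, and the outputs are compared against HF. A self-contained sketch of why the seeding matters; make_prompt_token_ids and the vocabulary size are hypothetical, the real prompt generation happens inside scheduling_utils/spyre_util:

import hashlib
import random

# Illustration only: the deterministic prompt generation that the
# set_random_seed fixture enables. Helper name and vocab size are made up.
def make_prompt_token_ids(test_name, prompts_lengths, vocab_size=32000):
    seed = int(hashlib.sha256(test_name.encode('utf-8')).hexdigest(), 16)
    random.seed(seed)
    return [[random.randrange(vocab_size) for _ in range(n)]
            for n in prompts_lengths]

# Same test name -> same prompt tokens on every run, so the hard-coded
# per-step expectations in checked_steps (tkv values, block usage) stay
# valid across reruns.
first = make_prompt_token_ids("test_prompts_misaligned_with_tkv_boundaries",
                              [49, 41, 47])
second = make_prompt_token_ids("test_prompts_misaligned_with_tkv_boundaries",
                               [49, 41, 47])
assert first == second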
