8
8
9
9
import pytest
10
10
from scheduling_utils import check_scheduler_inference_steps
11
- from spyre_util import get_spyre_backend_list , get_spyre_model_list
11
+ from spyre_util import (check_output_against_hf , get_spyre_backend_list ,
12
+ get_spyre_model_list )
12
13
13
14
14
15
@pytest .mark .cb
15
16
@pytest .mark .parametrize ("model" , get_spyre_model_list ())
16
17
@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
17
18
def test_prompts_aligned_with_tkv_boundaries (model : str , backend : str ,
18
- monkeypatch : pytest .MonkeyPatch ):
19
+ monkeypatch : pytest .MonkeyPatch ,
20
+ set_random_seed : None ):
19
21
""" Scenario where it happens that all the sequences get scheduled in a
20
22
fashion where they are aligned with the block boundaries (i.e. tkv multiple
21
23
of 64 at the time of prefilling).
@@ -162,7 +164,7 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
162
164
},
163
165
]
164
166
165
- check_scheduler_inference_steps (
167
+ cb_outputs , prompts = check_scheduler_inference_steps (
166
168
model = model ,
167
169
backend = backend ,
168
170
monkeypatch = monkeypatch ,
@@ -176,12 +178,16 @@ def test_prompts_aligned_with_tkv_boundaries(model: str, backend: str,
176
178
use_cb = True ,
177
179
)
178
180
181
+ check_output_against_hf (model , backend , seqs_max_tokens , cb_outputs ,
182
+ prompts )
183
+
179
184
180
185
@pytest .mark .cb
181
186
@pytest .mark .parametrize ("model" , get_spyre_model_list ())
182
187
@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
183
188
def test_prompts_misaligned_with_tkv_boundaries (
184
- model : str , backend : str , monkeypatch : pytest .MonkeyPatch ):
189
+ model : str , backend : str , monkeypatch : pytest .MonkeyPatch ,
190
+ set_random_seed : None ):
185
191
""" Scenario where it happens that some sequence gets scheduled in a way
186
192
that it is misaligned with the block boundary (i.e. tkv is not a multiple
187
193
of 64 at the time of prefilling).
@@ -193,7 +199,6 @@ def test_prompts_misaligned_with_tkv_boundaries(
193
199
* 2: len = 41, max tokens = 67, step joining = 0
194
200
* 3: len = 47, max tokens = 9, step joining = 0
195
201
"""
196
-
197
202
seqs_max_tokens = [57 , 67 , 9 ]
198
203
prompts_lengths = [49 , 41 , 47 ]
199
204
steps_add_reqs = [0 , 0 , 0 ] # add all requests in the beginning
@@ -326,7 +331,7 @@ def test_prompts_misaligned_with_tkv_boundaries(
326
331
},
327
332
]
328
333
329
- check_scheduler_inference_steps (
334
+ cb_outputs , prompts = check_scheduler_inference_steps (
330
335
model = model ,
331
336
backend = backend ,
332
337
monkeypatch = monkeypatch ,
@@ -340,12 +345,16 @@ def test_prompts_misaligned_with_tkv_boundaries(
340
345
use_cb = True ,
341
346
)
342
347
348
+ check_output_against_hf (model , backend , seqs_max_tokens , cb_outputs ,
349
+ prompts )
350
+
343
351
344
352
@pytest .mark .cb
345
353
@pytest .mark .parametrize ("model" , get_spyre_model_list ())
346
354
@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
347
355
def test_two_sequences_finish_same_time_as_new_arrive (
348
- model : str , backend : str , monkeypatch : pytest .MonkeyPatch ):
356
+ model : str , backend : str , monkeypatch : pytest .MonkeyPatch ,
357
+ set_random_seed ):
349
358
""" 2-cases-in-1: (1) Two sequences finish at the same time and (2) a new
350
359
request arrives when another finishes.
351
360
@@ -356,7 +365,6 @@ def test_two_sequences_finish_same_time_as_new_arrive(
356
365
* 2: len = 30, max tokens = 30, step joining = 0
357
366
* 3: len = 20, max tokens = 10, step joining = 31
358
367
"""
359
-
360
368
seqs_max_tokens = [30 , 30 , 10 ]
361
369
prompts_lengths = [49 , 30 , 20 ]
362
370
steps_add_reqs = [0 , 0 , 31 ]
@@ -466,7 +474,7 @@ def test_two_sequences_finish_same_time_as_new_arrive(
466
474
},
467
475
]
468
476
469
- check_scheduler_inference_steps (
477
+ cb_outputs , prompts = check_scheduler_inference_steps (
470
478
model = model ,
471
479
backend = backend ,
472
480
monkeypatch = monkeypatch ,
@@ -480,12 +488,16 @@ def test_two_sequences_finish_same_time_as_new_arrive(
480
488
use_cb = True ,
481
489
)
482
490
491
+ check_output_against_hf (model , backend , seqs_max_tokens , cb_outputs ,
492
+ prompts )
493
+
483
494
484
495
@pytest .mark .cb
485
496
@pytest .mark .parametrize ("model" , get_spyre_model_list ())
486
497
@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
487
498
def test_new_sequence_joins_during_decode (model : str , backend : str ,
488
- monkeypatch : pytest .MonkeyPatch ):
499
+ monkeypatch : pytest .MonkeyPatch ,
500
+ set_random_seed ):
489
501
""" Scenario where a new sequence joins while decoding other sequences
490
502
491
503
Configuration:
@@ -731,7 +743,7 @@ def test_new_sequence_joins_during_decode(model: str, backend: str,
731
743
# },
732
744
]
733
745
734
- check_scheduler_inference_steps (
746
+ cb_outputs , prompts = check_scheduler_inference_steps (
735
747
model = model ,
736
748
backend = backend ,
737
749
monkeypatch = monkeypatch ,
@@ -745,12 +757,16 @@ def test_new_sequence_joins_during_decode(model: str, backend: str,
745
757
use_cb = True ,
746
758
)
747
759
760
+ check_output_against_hf (model , backend , seqs_max_tokens , cb_outputs ,
761
+ prompts )
762
+
748
763
749
764
@pytest .mark .cb
750
765
@pytest .mark .parametrize ("model" , get_spyre_model_list ())
751
766
@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
752
767
def test_prompt_too_long_for_current_tkv (model : str , backend : str ,
753
- monkeypatch : pytest .MonkeyPatch ):
768
+ monkeypatch : pytest .MonkeyPatch ,
769
+ set_random_seed ):
754
770
""" Scenario where the requested prompt is too long for current tkv value
755
771
756
772
Configuration:
@@ -880,7 +896,7 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
880
896
},
881
897
]
882
898
883
- check_scheduler_inference_steps (
899
+ cb_outputs , prompts = check_scheduler_inference_steps (
884
900
model = model ,
885
901
backend = backend ,
886
902
monkeypatch = monkeypatch ,
@@ -894,13 +910,18 @@ def test_prompt_too_long_for_current_tkv(model: str, backend: str,
894
910
use_cb = True ,
895
911
)
896
912
913
+ check_output_against_hf (model , backend , seqs_max_tokens , cb_outputs ,
914
+ prompts )
915
+
897
916
898
917
@pytest .mark .cb
899
918
@pytest .mark .parametrize ("model" , get_spyre_model_list ())
900
919
@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
901
920
def test_requested_tokens_not_fitting_remaining_space (
902
- model : str , backend : str , monkeypatch : pytest .MonkeyPatch ):
903
- """ Scenario where the request goes beyond max_model_len
921
+ model : str , backend : str , monkeypatch : pytest .MonkeyPatch ,
922
+ set_random_seed ):
923
+ """ Scenario where the request goes beyond max_model_len and needs to wait
924
+ for a new batch.
904
925
905
926
Configuration:
906
927
* max_num_seqs: 2
@@ -909,7 +930,6 @@ def test_requested_tokens_not_fitting_remaining_space(
909
930
* 2: len = 49, max tokens = 57, step joining = 0
910
931
* 3: len = 41, max tokens = 80, step joining = 0
911
932
"""
912
-
913
933
seqs_max_tokens = [67 , 57 , 80 ]
914
934
prompts_lengths = [70 , 49 , 41 ]
915
935
steps_add_reqs = [0 , 0 , 0 ]
@@ -1067,7 +1087,7 @@ def test_requested_tokens_not_fitting_remaining_space(
1067
1087
},
1068
1088
]
1069
1089
1070
- check_scheduler_inference_steps (
1090
+ cb_outputs , prompts = check_scheduler_inference_steps (
1071
1091
model = model ,
1072
1092
backend = backend ,
1073
1093
monkeypatch = monkeypatch ,
@@ -1081,12 +1101,16 @@ def test_requested_tokens_not_fitting_remaining_space(
1081
1101
use_cb = True ,
1082
1102
)
1083
1103
1104
+ check_output_against_hf (model , backend , seqs_max_tokens , cb_outputs ,
1105
+ prompts )
1106
+
1084
1107
1085
1108
@pytest .mark .cb
1086
1109
@pytest .mark .parametrize ("model" , get_spyre_model_list ())
1087
1110
@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
1088
1111
def test_requests_use_all_available_blocks (model : str , backend : str ,
1089
- monkeypatch : pytest .MonkeyPatch ):
1112
+ monkeypatch : pytest .MonkeyPatch ,
1113
+ set_random_seed ):
1090
1114
""" Scenario where the requests use all of the available blocks
1091
1115
1092
1116
Configuration:
@@ -1098,7 +1122,6 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
1098
1122
* 4: len = 10, max tokens = 3, step joining = 0
1099
1123
* available_blocks: 8
1100
1124
"""
1101
-
1102
1125
seqs_max_tokens = [3 , 3 , 3 , 3 ] # 2 decodes into a new block per sequence
1103
1126
prompts_lengths = [10 , 10 , 10 , 10 ] # 1 block for prefill per sequence
1104
1127
steps_add_reqs = [0 , 0 , 0 , 0 ]
@@ -1201,7 +1224,7 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
1201
1224
},
1202
1225
]
1203
1226
1204
- check_scheduler_inference_steps (
1227
+ cb_outputs , prompts = check_scheduler_inference_steps (
1205
1228
model = model ,
1206
1229
backend = backend ,
1207
1230
monkeypatch = monkeypatch ,
@@ -1215,12 +1238,16 @@ def test_requests_use_all_available_blocks(model: str, backend: str,
1215
1238
use_cb = True ,
1216
1239
)
1217
1240
1241
+ check_output_against_hf (model , backend , seqs_max_tokens , cb_outputs ,
1242
+ prompts )
1243
+
1218
1244
1219
1245
@pytest .mark .cb
1220
1246
@pytest .mark .parametrize ("model" , get_spyre_model_list ())
1221
1247
@pytest .mark .parametrize ("backend" , get_spyre_backend_list ())
1222
1248
def test_requests_use_more_than_available_blocks (
1223
- model : str , backend : str , monkeypatch : pytest .MonkeyPatch ):
1249
+ model : str , backend : str , monkeypatch : pytest .MonkeyPatch ,
1250
+ set_random_seed ):
1224
1251
""" Scenario where some requests need to wait because of the number of
1225
1252
available blocks.
1226
1253
@@ -1361,7 +1388,7 @@ def test_requests_use_more_than_available_blocks(
1361
1388
},
1362
1389
]
1363
1390
1364
- check_scheduler_inference_steps (
1391
+ cb_outputs , prompts = check_scheduler_inference_steps (
1365
1392
model = model ,
1366
1393
backend = backend ,
1367
1394
monkeypatch = monkeypatch ,
@@ -1374,3 +1401,6 @@ def test_requests_use_more_than_available_blocks(
1374
1401
available_blocks = available_blocks ,
1375
1402
use_cb = True ,
1376
1403
)
1404
+
1405
+ check_output_against_hf (model , backend , seqs_max_tokens , cb_outputs ,
1406
+ prompts )
0 commit comments