Skip to content

Commit 8aea36d

Browse files
njhill authored and Alvant committed
[Bugfix] Fix speculative decode seeded test (vllm-project#6743)
Signed-off-by: Alvant <[email protected]>
1 parent 9a4081b commit 8aea36d

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

tests/spec_decode/e2e/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,8 @@ def generator_inner():
191191
and llm.llm_engine.log_stats):
192192
for sate_logger in llm.llm_engine.stat_loggers.values():
193193
sate_logger.local_interval = 0
194-
set_random_seed(seed)
194+
if seed is not None:
195+
set_random_seed(seed)
195196

196197
yield llm
197198
del llm

tests/spec_decode/e2e/test_seed.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
"num_speculative_tokens": 3,
2222
}])
2323
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
24-
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
24+
@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
25+
@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
2526
@pytest.mark.parametrize("batch_size", [1, 8, 32])
2627
@pytest.mark.parametrize("temperature", [0.1, 1.0])
2728
@pytest.mark.parametrize(
@@ -30,15 +31,26 @@
3031
# Use smaller output len for fast test.
3132
10,
3233
])
33-
@pytest.mark.parametrize("seed", [1])
34-
def test_seeded_consistency(baseline_llm_generator, batch_size: int,
35-
temperature: float, output_len: int):
34+
@pytest.mark.parametrize("seed", [None])
35+
def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
36+
batch_size: int, temperature: float,
37+
output_len: int):
3638
"""Verify outputs are consistent across multiple runs with same seed
3739
"""
3840
run_equality_correctness_test(baseline_llm_generator,
39-
baseline_llm_generator,
41+
test_llm_generator,
4042
batch_size,
4143
max_output_len=output_len,
4244
temperature=temperature,
4345
seeded=True,
4446
force_output_len=True)
47+
48+
# Ensure this same test does fail if we _don't_ include per-request seeds
49+
with pytest.raises(AssertionError):
50+
run_equality_correctness_test(baseline_llm_generator,
51+
test_llm_generator,
52+
batch_size,
53+
max_output_len=output_len,
54+
temperature=temperature,
55+
seeded=False,
56+
force_output_len=True)

0 commit comments

Comments (0)