@@ -21,7 +21,8 @@
         "num_speculative_tokens": 3,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{"seed": 1}])
+@pytest.mark.parametrize("test_llm_kwargs", [{"seed": 5}])
 @pytest.mark.parametrize("batch_size", [1, 8, 32])
 @pytest.mark.parametrize("temperature", [0.1, 1.0])
 @pytest.mark.parametrize(
@@ -30,15 +31,26 @@
         # Use smaller output len for fast test.
         10,
     ])
-@pytest.mark.parametrize("seed", [1])
-def test_seeded_consistency(baseline_llm_generator, batch_size: int,
-                            temperature: float, output_len: int):
+@pytest.mark.parametrize("seed", [None])
+def test_seeded_consistency(baseline_llm_generator, test_llm_generator,
+                            batch_size: int, temperature: float,
+                            output_len: int):
     """Verify outputs are consistent across multiple runs with same seed
     """
     run_equality_correctness_test(baseline_llm_generator,
-                                  baseline_llm_generator,
+                                  test_llm_generator,
                                   batch_size,
                                   max_output_len=output_len,
                                   temperature=temperature,
                                   seeded=True,
                                   force_output_len=True)
+
+    # Ensure this same test does fail if we _don't_ include per-request seeds
+    with pytest.raises(AssertionError):
+        run_equality_correctness_test(baseline_llm_generator,
+                                      test_llm_generator,
+                                      batch_size,
+                                      max_output_len=output_len,
+                                      temperature=temperature,
+                                      seeded=False,
+                                      force_output_len=True)