Commit c36d4d2

cadedaniel authored and alexeykondrat committed
[Speculative decoding 7/9] Speculative decoding end-to-end correctness tests. (vllm-project#3951)
1 parent 44aa867 commit c36d4d2

22 files changed: 1,164 additions and 175 deletions

tests/samplers/test_rejection_sampler.py

Lines changed: 6 additions & 2 deletions
@@ -91,12 +91,16 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
         bonus_token_ids,
     )
 
+    # Bonus tokens are currently disabled. Verify they're set to -1.
+    # See https://github.com/vllm-project/vllm/issues/4212
+    expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
+
     if which_tokens_accepted == "all_tokens_accepted":
         # Expect all tokens to be equal to draft tokens.
         assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
 
         # Expect all bonus tokens to be included.
-        assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
+        assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)
     elif which_tokens_accepted == "no_tokens_accepted":
         # Expect first token to be equal to recovered tokens.
         assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
@@ -106,7 +110,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
                            torch.ones_like(output_token_ids[:, 1:]) * -1)
     elif which_tokens_accepted == "some_tokens_accepted":
         recovered_plus_bonus = torch.cat(
-            (recovered_token_ids, bonus_token_ids), dim=-1)
+            (recovered_token_ids, expected_bonus_token_ids), dim=-1)
         # Assert first rejected token is a recovered token or bonus token.
         assert torch.equal(
             recovered_plus_bonus[torch.arange(0, batch_size),
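
For readers skimming the hunk: expected_bonus_token_ids is simply an all-minus-one tensor with the same shape as bonus_token_ids. A minimal standalone sketch of that convention follows; the batch size and vocabulary size are illustrative, not taken from the test.

```python
import torch

# Illustrative shapes; the real test constructs these tensors elsewhere.
batch_size, vocab_size = 4, 32_000
bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1))

# Same trick as in the diff: zero the values, then subtract one. The result is
# a tensor of -1 sentinels matching bonus_token_ids in shape, dtype and device,
# which is what the last output column must equal while bonus tokens are
# disabled (see vllm-project/vllm#4212).
expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1
assert torch.equal(expected_bonus_token_ids,
                   torch.full_like(bonus_token_ids, -1))
```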

tests/samplers/test_sampler.py

Lines changed: 2 additions & 1 deletion
@@ -636,7 +636,8 @@ def test_sampler_top_k_top_p(seed: int, device: str):
     def mock_sample(probs, *args, **kwargs):
         nonlocal sample_probs
         sample_probs = probs
-        return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
+        return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
+                 for prob in probs], None)
 
     with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
         sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
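
The mock changes shape because the patched vllm.model_executor.layers.sampler._sample now appears to return a 2-tuple rather than a bare list of sample results; the test only consumes the first element, so the second is stubbed with None. Below is a minimal sketch of the mock in isolation, assuming a 2-D tensor of per-sequence probabilities; the nonlocal capture from the real test is dropped here.

```python
import torch


def mock_sample(probs, *args, **kwargs):
    # Greedy "sampling": take the argmax of each row's distribution, wrapped
    # in the (sample_results, extra) tuple shape the patched _sample now uses.
    # The extra element is unused by the test, so it is stubbed with None.
    return ([[prob.topk(1, dim=-1).indices.tolist(), [0]]
             for prob in probs], None)


# Two fake sequences over a tiny vocabulary.
fake_probs = torch.softmax(torch.randn(2, 16), dim=-1)
sample_results, extra = mock_sample(fake_probs)
assert extra is None
assert len(sample_results) == 2  # one [next_token_ids, parent_ids] pair per row
```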

tests/spec_decode/e2e/__init__.py

Whitespace-only changes.

tests/spec_decode/e2e/conftest.py

Lines changed: 35 additions & 10 deletions
@@ -1,3 +1,5 @@
+from typing import List, Tuple
+
 import pytest
 
 from tests.conftest import cleanup
@@ -6,28 +8,34 @@
 
 
 @pytest.fixture
-def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
-                           baseline_llm_kwargs, seed):
-    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+def baseline_llm_generator(request, common_llm_kwargs,
+                           per_test_common_llm_kwargs, baseline_llm_kwargs,
+                           seed):
+    return create_llm_generator("baseline", request, common_llm_kwargs,
+                                per_test_common_llm_kwargs,
                                 baseline_llm_kwargs, seed)
 
 
 @pytest.fixture
-def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
+def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
                        test_llm_kwargs, seed):
-    return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
-                                test_llm_kwargs, seed)
+    return create_llm_generator("test", request, common_llm_kwargs,
+                                per_test_common_llm_kwargs, test_llm_kwargs,
+                                seed)
 
 
-def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
-                         distinct_llm_kwargs, seed):
+def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
+                         per_test_common_llm_kwargs, distinct_llm_kwargs,
+                         seed):
     kwargs = {
         **common_llm_kwargs,
         **per_test_common_llm_kwargs,
         **distinct_llm_kwargs,
     }
+    test_name = request.node.name
 
     def generator_inner():
+        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
         llm = LLM(**kwargs)
 
         set_random_seed(seed)
@@ -36,6 +44,23 @@ def generator_inner():
         del llm
         cleanup()
 
-    for llm in generator_inner():
-        yield llm
+    def generator_outer():
+        for llm in generator_inner():
+            yield llm
+            del llm
+
+    return generator_outer
+
+
+def get_output_from_llm_generator(
+        llm_generator, prompts,
+        sampling_params) -> Tuple[List[str], List[List[int]]]:
+    tokens = []
+    token_ids = []
+    for llm in llm_generator():
+        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
+        token_ids = [output.outputs[0].token_ids for output in outputs]
+        tokens = [output.outputs[0].text for output in outputs]
         del llm
+
+    return tokens, token_ids
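
Taken together, the fixtures now return a zero-argument callable that lazily constructs an LLM per test, and get_output_from_llm_generator runs one generation pass and tears the engine down. A hypothetical correctness-style test showing how these pieces are meant to be consumed is sketched below; the test name, model choices, and greedy-equality assertion are illustrative and are not part of this diff.

```python
import pytest

from vllm import SamplingParams

from .conftest import get_output_from_llm_generator


@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "JackFram/llama-68m",
    "use_v2_block_manager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 5,
}])
@pytest.mark.parametrize("seed", [1])
def test_greedy_equality_sketch(baseline_llm_generator, test_llm_generator):
    """Hypothetical example: with temperature=0, the spec-decode LLM should
    emit exactly the same token ids as the baseline LLM."""
    prompts = ["Hello, my name is"]
    sampling_params = SamplingParams(max_tokens=32,
                                     ignore_eos=True,
                                     temperature=0.0)

    _, baseline_token_ids = get_output_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)
    _, spec_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    assert baseline_token_ids == spec_token_ids
```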

tests/spec_decode/e2e/test_compatibility.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
+import pytest
+
+from vllm import SamplingParams
+
+from .conftest import get_output_from_llm_generator
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [
+        {
+            # Expect failure as spec decode not supported by
+            # Ray backend.
+            "worker_use_ray": True,
+        },
+    ])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_ray(test_llm_generator):
+    """Verify that speculative decoding with Ray fails.
+    """
+    output_len = 128
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    with pytest.raises(AssertionError,
+                       match="Speculative decoding not yet supported for "):
+        get_output_from_llm_generator(test_llm_generator, prompts,
+                                      sampling_params)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "enable_chunked_prefill": True,
+    },
+])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
+    """Verify that speculative decoding with chunked prefill fails.
+    """
+    output_len = 128
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    with pytest.raises(ValueError,
+                       match="Speculative decoding and chunked prefill"):
+        get_output_from_llm_generator(test_llm_generator, prompts,
+                                      sampling_params)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "meta-llama/Llama-2-7b-chat-hf",
+        "speculative_model": "JackFram/llama-68m",
+        "num_speculative_tokens": 5,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True
+    }])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [
+        {
+            # Speculative max model len > overridden max model len should raise.
+            "max_model_len": 128,
+            "speculative_max_model_len": 129,
+        },
+        {
+            # Speculative max model len > draft max model len should raise.
+            # https://huggingface.co/JackFram/llama-68m/blob/3b606af5198a0b26762d589a3ee3d26ee6fa6c85/config.json#L12
+            "speculative_max_model_len": 2048 + 1,
+        },
+        {
+            # Speculative max model len > target max model len should raise.
+            # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
+            "speculative_max_model_len": 4096 + 1,
+        },
+    ])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
+    """Verify that speculative decoding validates speculative_max_model_len.
+    """
+    output_len = 128
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    with pytest.raises(ValueError, match="cannot be larger than"):
+        get_output_from_llm_generator(test_llm_generator, prompts,
+                                      sampling_params)
+
+
+@pytest.mark.parametrize("common_llm_kwargs", [{
+    "model": "JackFram/llama-68m",
+    "speculative_model": "JackFram/llama-68m",
+    "num_speculative_tokens": 5,
+}])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
+    """Verify that speculative decoding with block manager v1 fails.
+    """
+    output_len = 128
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    with pytest.raises(ValueError,
+                       match="Speculative decoding requires usage of the V2"):
+        get_output_from_llm_generator(test_llm_generator, prompts,
+                                      sampling_params)
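
Each of these xfail cases leans on the kwarg layering from conftest.py: the parametrized dicts are merged before LLM(**kwargs) is constructed, with later layers overriding earlier ones. As an illustration, here is roughly how the Ray case resolves; values are copied from the parametrization above, and the merge order comes from create_llm_generator.

```python
# How the layered parametrization for test_spec_decode_xfail_ray resolves.
common_llm_kwargs = {
    "model": "JackFram/llama-68m",
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 5,
    "use_v2_block_manager": True,
}
per_test_common_llm_kwargs = {"worker_use_ray": True}
test_llm_kwargs = {}

# create_llm_generator merges the layers in this order; later dicts win.
kwargs = {
    **common_llm_kwargs,
    **per_test_common_llm_kwargs,
    **test_llm_kwargs,
}
assert kwargs["worker_use_ray"] and kwargs["num_speculative_tokens"] == 5
# Constructing LLM(**kwargs) is then expected to fail, matching
# "Speculative decoding not yet supported for ".
```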
