
Commit d6e878e

LiuXiaoxuanPKU authored and Alvant committed
[SpecDec][Misc] Cleanup, remove bonus token logic. (vllm-project#8701)
Signed-off-by: Alvant <[email protected]>
1 parent 2d0e736 commit d6e878e

File tree

7 files changed: +33 −115 lines changed

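At a glance, this cleanup removes the `disable_bonus_tokens` flag from the spec-decode samplers. Below is a minimal sketch of the post-change constructor calls, pieced together from the hunks that follow; the `typical_acceptance_sampler` import path is assumed (that module is only referenced, not shown, in this diff), and the argument values are the defaults used in the tests.

# Sketch only: constructor usage after this commit.
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
# Assumed module path; TypicalAcceptanceSampler is not part of this diff.
from vllm.model_executor.layers.typical_acceptance_sampler import TypicalAcceptanceSampler

# RejectionSampler no longer takes disable_bonus_tokens; only strict_mode and
# use_flashinfer remain.
rejection_sampler = RejectionSampler(use_flashinfer=False)
rejection_sampler.init_gpu_tensors(device="cuda:0")

# TypicalAcceptanceSampler keeps posterior_threshold, posterior_alpha, strict_mode.
typical_acceptance_sampler = TypicalAcceptanceSampler(
    posterior_threshold=0.03, posterior_alpha=0.9, strict_mode=False)
typical_acceptance_sampler.init_gpu_tensors(device="cuda:0")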

tests/samplers/test_rejection_sampler.py

Lines changed: 7 additions & 23 deletions
@@ -42,18 +42,13 @@ def mock_causal_accepted_tensor(
 @pytest.mark.parametrize(
     "which_tokens_accepted",
     ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("use_flashinfer", [True, False])
 @torch.inference_mode()
 def test_correct_output_format(which_tokens_accepted: str, seed: int,
-                               disable_bonus_tokens: bool, device: str,
-                               use_flashinfer: bool):
+                               device: str, use_flashinfer: bool):
     """Verify the output has correct format given predetermined accepted matrix.
     """
-    if use_flashinfer and disable_bonus_tokens:
-        pytest.skip("Flashinfer rejection sampler must enable bonus token.")
-
     set_random_seed(seed)
     torch.set_default_device(device)

@@ -88,9 +83,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
                                     size=(batch_size, 1),
                                     dtype=torch.int64)

-    rejection_sampler = RejectionSampler(
-        disable_bonus_tokens=disable_bonus_tokens,
-        use_flashinfer=use_flashinfer)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)
     output_token_ids = rejection_sampler._create_output(  # pylint: disable=protected-access
         accepted,
@@ -100,10 +93,6 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
     )

     expected_bonus_token_ids = bonus_token_ids.clone()
-    # If bonus tokens disabled. Verify they are set to -1.
-    # See https://github.com/vllm-project/vllm/issues/4212
-    if disable_bonus_tokens:
-        expected_bonus_token_ids = expected_bonus_token_ids * 0 - 1

     if which_tokens_accepted == "all_tokens_accepted":
         # Expect all tokens to be equal to draft tokens.
@@ -143,8 +132,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
 def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     device: str, use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
-                                         use_flashinfer=use_flashinfer)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)

     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -177,8 +165,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
                                    frac_seeded: float, n_rep: int, device: str,
                                    use_flashinfer: bool):
     torch.set_default_device(device)
-    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
-                                         use_flashinfer=use_flashinfer)
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
     rejection_sampler.init_gpu_tensors(device=device)

     draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -251,8 +238,7 @@ def get_seeded_seqs():
     }

     for use_flashinfer in [True, False]:
-        rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
-                                             use_flashinfer=use_flashinfer)
+        rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
         rejection_sampler.init_gpu_tensors(device=device)
         # We use seeded sequences to ensure the same tokens are accepted
         # for both flashinfer and nonflashinfer backends.
@@ -282,8 +268,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     vocab_size = 30_000
     torch.set_default_device(device)

-    rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
-                                         use_flashinfer=use_flashinfer,
+    rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
                                          strict_mode=True)
     rejection_sampler.init_gpu_tensors(device=device)

@@ -359,8 +344,7 @@ def test_rejection_sampling_approximates_target_distribution(
     set_random_seed(seed)
     helper = _CorrectnessTestHelper(
         vocab_size=10,
-        rejection_sampler=RejectionSampler(disable_bonus_tokens=False,
-                                           use_flashinfer=use_flashinfer),
+        rejection_sampler=RejectionSampler(use_flashinfer=use_flashinfer),
     )

     draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(

tests/samplers/test_typical_acceptance_sampler.py

Lines changed: 20 additions & 59 deletions
@@ -55,14 +55,13 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
 def get_acceptance_sampler(
     posterior_threshold: float = 0.03,
     posterior_alpha: float = 0.9,
-    disable_bonus_tokens: bool = False,
     strict_mode: bool = False,
 ) -> TypicalAcceptanceSampler:
     """
     Initializes and returns a TypicalAcceptanceSampler.
     """
     return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
-                                    disable_bonus_tokens, strict_mode)
+                                    strict_mode)


 @pytest.mark.parametrize("k", list(range(1, 6)))
@@ -154,29 +153,25 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,


 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_uniform_target_distribution_accepts_all_tokens(
-        seed: int, disable_bonus_tokens: bool, device: str):
+        seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler with a uniform target probability
     distribution.

     This test verifies that when provided with a uniform target probability
     distribution, the TypicalAcceptanceSampler accepts all draft tokens. The
     entropy of the uniform target distribution being high should lead to all
-    draft tokens being accepted. The test also ensures that the behavior
-    regarding bonus tokens is consistent with the `disable_bonus_tokens`
-    flag.
+    draft tokens being accepted.
     """
     set_random_seed(seed)
     k = 3
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_with_bonus_probs = torch.rand(batch_size,
                                          k + 1,
@@ -200,21 +195,15 @@ def test_uniform_target_distribution_accepts_all_tokens(
     # should lead to all draft tokens being accepted. Verify that.
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
-    if disable_bonus_tokens:
-        assert torch.all(output_token_ids[:, -1] == -1)
-    else:
-        assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())
+    assert torch.all(output_token_ids[:, -1] == bonus_token_ids.squeeze())

     assert torch.all(output_token_ids[:, :k] == draft_token_ids)


 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_temperature_zero_target_distribution(seed: int,
-                                              disable_bonus_tokens: bool,
-                                              device: str):
+def test_temperature_zero_target_distribution(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler with a zero-temperature target
     probability distribution.
@@ -232,8 +221,7 @@ def test_temperature_zero_target_distribution(seed: int,
     vocab_size = 30_000
     torch.set_default_device(device)

-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     # Simulate temperature 0 probability distribution for target probabilities
     # and create target probabilities such that only 1 token id has
@@ -267,11 +255,9 @@ def test_temperature_zero_target_distribution(seed: int,


 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
-                                   device: str):
+def test_mixed_target_distribution(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler with a mixed target probability
     distribution.
@@ -285,16 +271,13 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     with a probability of 1.0 is accepted, and all other tokens are rejected.
     - For sequences with a uniform distribution, all draft tokens are
       accepted.
-    - When `disable_bonus_tokens` is False, the bonus tokens are also accepted
-      for sequences with a uniform distribution.
     """
     set_random_seed(seed)
     k = 3
     batch_size = 4
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     # For sequences 0 and 2 set the distribution to a temperature
     # zero distribution. For sequences 1 and 3 set it to a uniform
@@ -328,21 +311,16 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                            0]))
     # For sequences 1 and 3 verify that all tokens are accepted since the
     # target probability distribution is uniform. In addition verify that
-    # if disable_bonus_tokens is false then we also accept the bonus tokens.
+    # we also accept the bonus tokens.
     assert torch.all(
         output_token_ids[[1, 3], :-1] == draft_token_ids[[1, 3], :])
-    if disable_bonus_tokens:
-        assert torch.all(output_token_ids[[1, 3], -1] == -1)
-    else:
-        assert torch.all(output_token_ids[[1, 3], -1] != -1)
+    assert torch.all(output_token_ids[[1, 3], -1] != -1)


 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
-                                 device: str):
+def test_accept_tokens_partially(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler's behavior when only a subset of draft
     tokens should be accepted.
@@ -362,8 +340,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     # Create a temperature zero target probability distribution and ensure
     # all draft token ids correspond to the tokens with 1.0 probability.
@@ -384,10 +361,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
-    if disable_bonus_tokens:
-        assert torch.all(output_token_ids[:, -1] == -1)
-    else:
-        assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
+    assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
     # Next only keep the first 2 draft tokens same as the zero temperature
     # tokens. For the remaining 3 choose some other tokens. In the
     # response we will expect the first 2 tokens to be the same as the
@@ -408,12 +382,9 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,


 @pytest.mark.parametrize("seed", list(range(1)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_accept_tokens_set_non_default_posteriors(seed: int,
-                                                  disable_bonus_tokens: bool,
-                                                  device: str):
+def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler with custom posterior thresholds and
     alpha values. This test verifies that by modifying the posterior
@@ -425,8 +396,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     # Simulate temperature 0 probability distribution for target
     # probabilities and create target probabilities such that only 1 token
@@ -457,10 +427,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     # now accept even draft tokens with very low probability in the
     # target distribution. Simulate and verify the same.
     typical_acceptance_sampler = TypicalAcceptanceSampler(
-        strict_mode=True,
-        disable_bonus_tokens=disable_bonus_tokens,
-        posterior_threshold=0.0,
-        posterior_alpha=0.0)
+        strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     output_token_ids = typical_acceptance_sampler(
         target_probs,
@@ -470,18 +437,13 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
-    if disable_bonus_tokens:
-        assert torch.all(output_token_ids[:, -1] == -1)
-    else:
-        assert torch.all(output_token_ids[:, -1] == bonus_token_ids)
+    assert torch.all(output_token_ids[:, -1] == bonus_token_ids)


 @pytest.mark.parametrize("seed", list(range(10)))
-@pytest.mark.parametrize("disable_bonus_tokens", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
-def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
-                               device: str):
+def test_replacement_token_ids(seed: int, device: str):
     """
     Test the TypicalAcceptanceSampler's method for generating
     replacement token IDs.
@@ -497,8 +459,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = get_acceptance_sampler(
-        strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(device=device)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     expected_replacement_tokens = -torch.ones(
expected_replacement_tokens = -torch.ones(

tests/spec_decode/e2e/test_medusa_correctness.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@
 # speculative model
 SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"

-# max. number of speculative tokens: this corresponds to
+# max number of speculative tokens: this corresponds to
 # num_heads in the config.json of the speculator model.
 MAX_SPEC_TOKENS = 5


vllm/model_executor/layers/rejection_sampler.py

Lines changed: 1 addition & 8 deletions
@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
     """

     def __init__(self,
-                 disable_bonus_tokens: bool = True,
                  strict_mode: bool = False,
                  use_flashinfer: Optional[bool] = None):
         """Create a rejection sampler.

         Args:
-            disable_bonus_tokens: Whether or not to disable the bonus token.
-            Require when bonus tokens will cause corrupt KV cache for
-            proposal methods that require KV cache.
             strict_mode: Whether or not to perform shape/device/dtype checks
             during sampling. This catches correctness issues but adds
             nontrivial latency.
@@ -48,17 +44,14 @@ def __init__(self,
             None, we will use the default value from the environment variable.
             This parameter is only used for testing purposes.
         """
-        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
-                         strict_mode=strict_mode)
+        super().__init__(strict_mode=strict_mode)
         if use_flashinfer is None:
             self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
                 chain_speculative_sampling is not None)
         else:
             self.use_flashinfer = use_flashinfer

         if self.use_flashinfer:
-            assert not disable_bonus_tokens, \
-                "flashinfer will enable bonus token by default"
             logger.info("Use flashinfer for rejection sampling.")
         else:
             logger.info("Use pytorch for rejection sampling.")

vllm/model_executor/layers/spec_decode_base_sampler.py

Lines changed: 1 addition & 14 deletions
@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
     step.
     """

-    def __init__(self,
-                 disable_bonus_tokens: bool = True,
-                 strict_mode: bool = False):
+    def __init__(self, strict_mode: bool = False):
         """Base class constructor.
         Args:
-            disable_bonus_tokens: Whether or not to disable the bonus token.
-            Require when bonus tokens will cause corrupt KV cache for
-            proposal methods that require KV cache.
             strict_mode: Whether or not to perform shape/device/dtype checks
             during sampling. This catches correctness issues but adds
             nontrivial latency.
         """
         super().__init__()
-        self._disable_bonus_tokens = disable_bonus_tokens
         self._strict_mode = strict_mode

         # NOTE: A "bonus token" is accepted iff all proposal tokens are
@@ -111,13 +105,6 @@ def _create_output(
         output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                       bonus_token_ids, -1)

-        # We disable bonus tokens because it causes corrupt KV cache for
-        # proposal methods that require KV cache. We can fix it by "prefilling"
-        # the bonus token in the proposer. The following issue tracks the fix.
-        # https://github.com/vllm-project/vllm/issues/4212
-        if self._disable_bonus_tokens:
-            output_with_bonus_tokens[:, -1] = -1
-
         # Fill the recovered token ids.
         output.mul_(~after_false_mask).add_(
             substitute_token_ids.mul(after_false_mask))
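With the self._disable_bonus_tokens branch removed, the only bonus-token rule left in _create_output is the torch.where line kept above: a sequence keeps its bonus token only if its final proposal position was accepted, otherwise the slot becomes -1. A standalone sketch of that behavior (tensor values are illustrative):

import torch

# `output` holds per-position token ids with -1 marking positions at and after
# the first rejection; `bonus_token_ids` holds one bonus token per sequence.
output = torch.tensor([[11, 12, 13],    # all proposal tokens accepted
                       [21, -1, -1]])   # rejected at position 1
bonus_token_ids = torch.tensor([101, 201])

# Keep the bonus token only where the last proposal position was accepted.
bonus_slot = torch.where(output[:, -1] != -1, bonus_token_ids, -1)
print(bonus_slot)  # tensor([101,  -1])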
