
Commit 76442c2

tdoublep authored and njhill committed
[Bugfix] Make spec. decode respect per-request seed. (vllm-project#6034)
Signed-off-by: Thomas Parnell <[email protected]>
Co-authored-by: Nick Hill <[email protected]>
1 parent 6011d52 commit 76442c2
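
In effect, speculative decoding now honors a per-request `seed` supplied through `SamplingParams`. A minimal usage sketch (the model names and engine kwargs mirror the new e2e test below; the prompts and output handling are illustrative, not part of this commit):

    from vllm import LLM, SamplingParams

    # Engine configuration mirroring tests/spec_decode/e2e/test_seed.py.
    llm = LLM(model="JackFram/llama-68m",
              speculative_model="JackFram/llama-160m",
              num_speculative_tokens=3,
              use_v2_block_manager=True,
              enforce_eager=True)

    prompts = ["Hello, my name is", "The president of the United States is"]
    # One SamplingParams per prompt, each carrying its own seed.
    params = [
        SamplingParams(temperature=1.0, max_tokens=10, seed=i)
        for i in range(len(prompts))
    ]

    # With this fix, repeating a seeded request reproduces the same tokens
    # even though temperature > 0 and tokens come from rejection sampling.
    first = [out.outputs[0].token_ids for out in llm.generate(prompts, params)]
    second = [out.outputs[0].token_ids for out in llm.generate(prompts, params)]
    assert first == second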

File tree: 8 files changed, +293 -46 lines changed

tests/samplers/test_rejection_sampler.py

Lines changed: 53 additions & 3 deletions
@@ -150,9 +150,54 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                     high=vocab_size,
                                     size=(batch_size, k),
                                     dtype=torch.int64)
+    generators = [None] * batch_size
 
     rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                      draft_token_ids)
+                      draft_token_ids, generators)
+
+
+@pytest.mark.parametrize("frac_seeded", [0.0, 0.25, 0.5, 1.0])
+@pytest.mark.parametrize("k", [1, 3, 6])
+@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
+@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
+@pytest.mark.parametrize("n_rep", [100])
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
+                                   frac_seeded: float, n_rep: int,
+                                   device: str):
+    torch.set_default_device(device)
+    rejection_sampler = RejectionSampler()
+    rejection_sampler.init_gpu_tensors(rank=0)
+
+    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
+    bonus_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, 1),
+                                    dtype=torch.int64)
+    draft_token_ids = torch.randint(low=0,
+                                    high=vocab_size,
+                                    size=(batch_size, k),
+                                    dtype=torch.int64)
+
+    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
+
+    results = []
+    for _ in range(n_rep):
+        generators = [
+            torch.Generator(
+                device=device).manual_seed(i) if seeded_mask[i] else None
+            for i in range(batch_size)
+        ]
+        results.append(
+            rejection_sampler(target_probs, bonus_token_ids, draft_probs,
+                              draft_token_ids, generators))
+
+    for i in range(batch_size):
+        if seeded_mask[i]:
+            for j in range(1, n_rep):
+                assert torch.equal(results[j][i], results[0][i])
 
 
 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -197,10 +242,11 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
         raise AssertionError()
 
     oob_token_ids[0][0] = rogue_token_id
+    generators = [None] * batch_size
 
     with pytest.raises(AssertionError):
         rejection_sampler(target_probs, bonus_token_ids, draft_probs,
-                          draft_token_ids)
+                          draft_token_ids, generators)
 
 
 @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@@ -371,11 +417,15 @@ def _estimate_rejection_sampling_pdf(
                                        dtype=torch.int64,
                                        device="cuda").repeat(num_samples, 1)
 
+        # unseeded
+        generators = [None]
+
         # Get output tokens via rejection sampling.
         output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
                                                   bonus_token_ids.to("cuda"),
                                                   draft_probs.to("cuda"),
-                                                  draft_token_ids.to("cuda"))
+                                                  draft_token_ids.to("cuda"),
+                                                  generators)
 
         # Remove bonus tokens
         output_token_ids = output_token_ids[:, :-1].flatten()
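
Condensed from the tests above, the updated sampler call with a mixed seeded/unseeded batch looks like this (a sketch, assuming vLLM is installed and a CUDA device is available; the `sample_once` helper and the shapes are illustrative):

    import torch
    from vllm.model_executor.layers.rejection_sampler import RejectionSampler

    torch.set_default_device("cuda:0")
    batch_size, k, vocab_size = 2, 3, 32_000

    sampler = RejectionSampler()
    sampler.init_gpu_tensors(rank=0)

    target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
    bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), dtype=torch.int64)
    draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), dtype=torch.int64)

    def sample_once():
        # Request 0 carries a per-request generator; request 1 uses the global RNG.
        generators = [torch.Generator(device="cuda:0").manual_seed(123), None]
        return sampler(target_probs, bonus_token_ids, draft_probs,
                       draft_token_ids, generators)

    # The seeded request's row of output token ids is identical across calls.
    assert torch.equal(sample_once()[0], sample_once()[0])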

tests/spec_decode/e2e/conftest.py

Lines changed: 44 additions & 8 deletions
@@ -1,6 +1,6 @@
 import asyncio
 from itertools import cycle
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import pytest
 import ray
@@ -128,7 +128,9 @@ async def get_output(prompt, sampling_param) -> RequestOutput:
     try:
         for i in range(num_requests):
             prompt = prompts[i] if prompts is not None else None
-            res = asyncio.run(get_output(prompt, sampling_params))
+            params = sampling_params[i] if isinstance(
+                sampling_params, Sequence) else sampling_params
+            res = asyncio.run(get_output(prompt, params))
             outputs.append(res)
     finally:
         ray.shutdown()
@@ -267,7 +269,31 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
     the same when temperature is zero.
     """
-    temperature = 0.0
+
+    run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len,
+                                  temperature=0.0,
+                                  seeded=False,
+                                  print_tokens=print_tokens,
+                                  ensure_all_accepted=ensure_all_accepted)
+
+
+def run_equality_correctness_test(baseline_llm_generator,
+                                  test_llm_generator,
+                                  batch_size,
+                                  max_output_len,
+                                  force_output_len: bool,
+                                  temperature: float,
+                                  seeded: bool,
+                                  print_tokens: bool = False,
+                                  ensure_all_accepted: bool = False):
+    """Helper method that compares the outputs of both the baseline LLM and
+    the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
+    the same when temperature is zero (or when temperature is > 0 and seeded).
+    """
 
     prompts = [
         "Hello, my name is",
@@ -286,11 +312,21 @@ def run_greedy_equality_correctness_test(baseline_llm_generator,
     # sampling params to ignore eos token.
     ignore_eos = force_output_len
 
-    sampling_params = SamplingParams(
-        max_tokens=max_output_len,
-        ignore_eos=ignore_eos,
-        temperature=temperature,
-    )
+    if seeded:
+        sampling_params = [
+            SamplingParams(
+                max_tokens=max_output_len,
+                ignore_eos=ignore_eos,
+                temperature=temperature,
+                seed=i,
+            ) for i in range(len(prompts))
+        ]
+    else:
+        sampling_params = SamplingParams(
+            max_tokens=max_output_len,
+            ignore_eos=ignore_eos,
+            temperature=temperature,
+        )
 
     (spec_batch_tokens, spec_batch_token_ids,
      acceptance_rate) = get_output_from_llm_generator(test_llm_generator,

tests/spec_decode/e2e/test_seed.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+import pytest
+
+from .conftest import run_equality_correctness_test
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-68m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # speculative model
+        "speculative_model": "JackFram/llama-160m",
+
+        # num speculative tokens
+        "num_speculative_tokens": 3,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("batch_size", [1, 8, 32])
+@pytest.mark.parametrize("temperature", [0.1, 1.0])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use smaller output len for fast test.
+        10,
+    ])
+@pytest.mark.parametrize("seed", [1])
+def test_seeded_consistency(baseline_llm_generator, batch_size: int,
+                            temperature: float, output_len: int):
+    """Verify outputs are consistent across multiple runs with same seed
+    """
+    run_equality_correctness_test(baseline_llm_generator,
+                                  baseline_llm_generator,
+                                  batch_size,
+                                  max_output_len=output_len,
+                                  temperature=temperature,
+                                  seeded=True,
+                                  force_output_len=True)

vllm/model_executor/layers/rejection_sampler.py

Lines changed: 86 additions & 20 deletions
@@ -1,14 +1,14 @@
 from functools import cached_property
-from typing import Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.jit
 
 from vllm.model_executor.layers.spec_decode_base_sampler import (
-    SpecDecodeBaseSampler)
+    SpecDecodeStochasticBaseSampler)
 
 
-class RejectionSampler(SpecDecodeBaseSampler):
+class RejectionSampler(SpecDecodeStochasticBaseSampler):
     """Apply modified rejection sampling as described in "Accelerating Large
     Language Model Decoding with Speculative Sampling"
     https://arxiv.org/pdf/2302.01318.pdf.
@@ -36,6 +36,7 @@ def forward(
         bonus_token_ids: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
+        generators: List[Optional[torch.Generator]],
     ) -> torch.Tensor:
         """Sample token ids using rejection sampling. This accepts or rejects
         tokens proposed by the draft model using the probability of each token
@@ -82,6 +83,7 @@ def forward(
                 target_probs,
                 draft_probs,
                 draft_token_ids,
+                generators,
             ))
 
         output_token_ids = self._create_output(
@@ -94,10 +96,11 @@ def forward(
         return output_token_ids
 
     def _batch_modified_rejection_sampling(
-        self,
-        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_token_ids: torch.Tensor,  # [batch_size, k]
+            self,
+            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_token_ids: torch.Tensor,  # [batch_size, k]
+            generators: List[Optional[torch.Generator]],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform modified rejection sampling on each sequence.
 
@@ -114,22 +117,33 @@ def _batch_modified_rejection_sampling(
 
         # shape [batch_size, k]
         accepted = self._get_accepted(target_probs, draft_probs,
-                                      draft_token_ids)
+                                      draft_token_ids, generators)
 
         recovered_probs = self._get_recovered_probs(
             target_probs, draft_probs).reshape(batch_size * k, vocab_size)
 
+        seed_indices, non_seed_indices = self._split_batch_by_seeded(
+            generators, k=k)
+
         # NOTE: the recovered_probs are overwritten by this method.
-        recovered_token_ids = _multinomial(recovered_probs,
-                                           num_samples=1).reshape(
-                                               batch_size, k)
+        recovered_token_ids = _multinomial(
+            recovered_probs,
+            num_samples=1,
+            k=k,
+            generators=generators,
+            seed_indices=seed_indices,
+            # this arg is unused when None but torch.jit requires a list
+            non_seed_indices=non_seed_indices or [],
+        ).reshape(batch_size, k)
+
         return accepted, recovered_token_ids
 
     def _get_accepted(
-        self,
-        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
-        draft_token_ids: torch.Tensor,  # [batch_size, k]
+            self,
+            target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
+            draft_token_ids: torch.Tensor,  # [batch_size, k]
+            generators: List[Optional[torch.Generator]],
     ) -> torch.Tensor:
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
@@ -164,10 +178,28 @@ def _get_accepted(
         selected_target_probs = target_probs[batch_indices, probs_indicies,
                                              draft_token_ids]
 
-        uniform_rand = torch.rand(batch_size,
-                                  k,
-                                  dtype=self.probs_dtype,
-                                  device=target_probs.device)
+        seed_indices, non_seed_indices = self._split_batch_by_seeded(
+            generators)
+
+        if len(seed_indices) == 0:
+            uniform_rand = torch.rand_like(selected_target_probs)
+        else:
+            uniform_rand = torch.empty_like(selected_target_probs)
+
+            for idx in seed_indices:
+                uniform_rand[idx, :] = torch.rand(1,
+                                                  k,
+                                                  dtype=self.probs_dtype,
+                                                  device=target_probs.device,
+                                                  generator=generators[idx])
+
+            if non_seed_indices:
+                uniform_rand[non_seed_indices, :] = torch.rand(
+                    len(non_seed_indices),
+                    k,
+                    dtype=self.probs_dtype,
+                    device=target_probs.device)
+
         capped_ratio = torch.minimum(
             selected_target_probs / selected_draft_probs,
             torch.full((1, ), 1, device=target_probs.device))
@@ -240,6 +272,27 @@ def _smallest_positive_value(self) -> float:
         """
         return torch.finfo(self.probs_dtype).tiny
 
+    # partition batch into indices for which a generator is provided
+    # and indices for which no generator is provided
+    @staticmethod
+    def _split_batch_by_seeded(
+        generators: List[Optional[torch.Generator]],
+        k: int = 1,
+    ) -> Tuple[List[int], Optional[List[int]]]:
+
+        if all(generator is None for generator in generators):
+            seed_indices: List[int] = []
+            non_seed_indices: Optional[List[int]] = None
+        else:
+            seed_indices, non_seed_indices = [], []
+            for i, generator in enumerate(generators):
+                if generator is None:
+                    non_seed_indices.extend(range(k * i, k * (i + 1)))
+                else:
+                    seed_indices.extend(range(k * i, k * (i + 1)))
+
+        return seed_indices, non_seed_indices
+
 
 # torch.multinomial forces a GPU<->CPU sync.
 # Therefore, we use an optimized implementation instead that skips the sync.
@@ -250,12 +303,25 @@ def _smallest_positive_value(self) -> float:
 def _multinomial(
     probs: torch.Tensor,
     num_samples: int,
+    k: int,
+    generators: List[Optional[torch.Generator]],
+    seed_indices: List[int],
+    non_seed_indices: List[int],
 ) -> torch.Tensor:
+
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
-    q = torch.empty_like(probs).exponential_(1.0)
+
+    q = torch.empty_like(probs)
+    if len(seed_indices) == 0:
+        q.exponential_(1.0)
+    else:
+        q[non_seed_indices].exponential_(1.0)
+        for idx in seed_indices:
+            q[idx].exponential_(1.0, generator=generators[idx // k])
+
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
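
For intuition: `_multinomial` avoids `torch.multinomial` by dividing the probabilities by i.i.d. Exp(1) noise and taking the argmax, which selects index i with probability p_i / sum(p). Each request owns k consecutive rows of the flattened `probs`, so `generators[idx // k]` maps a row back to its request's generator, and seeding that noise makes the draw reproducible. A standalone, simplified sketch (not vLLM code; the function name and the per-row loop are illustrative):

    import torch

    def seeded_multinomial(probs: torch.Tensor,
                           generators: list,  # one Optional[torch.Generator] per request
                           k: int) -> torch.Tensor:
        """Draw one sample per row of probs ([num_requests * k, vocab_size]).
        Rows belonging to a seeded request use that request's generator."""
        q = torch.empty_like(probs)
        for req, gen in enumerate(generators):
            for row in range(req * k, (req + 1) * k):
                # Integer indexing returns a view, so the in-place fill writes
                # into q; generator=None falls back to the global RNG.
                q[row].exponential_(1.0, generator=gen)
        return probs.div(q).argmax(dim=1)

    k = 3
    probs = torch.rand(2 * k, 16)  # two requests, k draws each
    make_gens = lambda: [torch.Generator().manual_seed(42), None]  # request 0 seeded

    first = seeded_multinomial(probs, make_gens(), k)
    second = seeded_multinomial(probs, make_gens(), k)
    assert torch.equal(first[:k], second[:k])  # seeded request reproduces exactly

The sketch fills each row individually for clarity; the commit instead batches the unseeded rows into a single `exponential_` call and loops only over the seeded rows.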
