1
1
from functools import cached_property
2
- from typing import Tuple
2
+ from typing import List , Optional , Tuple
3
3
4
4
import torch
5
5
import torch .jit
6
6
7
7
from vllm .model_executor .layers .spec_decode_base_sampler import (
8
- SpecDecodeBaseSampler )
8
+ SpecDecodeStochasticBaseSampler )
9
9
10
10
11
- class RejectionSampler (SpecDecodeBaseSampler ):
11
+ class RejectionSampler (SpecDecodeStochasticBaseSampler ):
12
12
"""Apply modified rejection sampling as described in "Accelerating Large
13
13
Language Model Decoding with Speculative Sampling"
14
14
https://arxiv.org/pdf/2302.01318.pdf.
@@ -36,6 +36,7 @@ def forward(
36
36
bonus_token_ids : torch .Tensor ,
37
37
draft_probs : torch .Tensor ,
38
38
draft_token_ids : torch .Tensor ,
39
+ generators : List [Optional [torch .Generator ]],
39
40
) -> torch .Tensor :
40
41
"""Sample token ids using rejection sampling. This accepts or rejects
41
42
tokens proposed by the draft model using the probability of each token
@@ -82,6 +83,7 @@ def forward(
82
83
target_probs ,
83
84
draft_probs ,
84
85
draft_token_ids ,
86
+ generators ,
85
87
))
86
88
87
89
output_token_ids = self ._create_output (
@@ -94,10 +96,11 @@ def forward(
94
96
return output_token_ids
95
97
96
98
def _batch_modified_rejection_sampling (
97
- self ,
98
- target_probs : torch .Tensor , # [batch_size, k, vocab_size]
99
- draft_probs : torch .Tensor , # [batch_size, k, vocab_size]
100
- draft_token_ids : torch .Tensor , # [batch_size, k]
99
+ self ,
100
+ target_probs : torch .Tensor , # [batch_size, k, vocab_size]
101
+ draft_probs : torch .Tensor , # [batch_size, k, vocab_size]
102
+ draft_token_ids : torch .Tensor , # [batch_size, k]
103
+ generators : List [Optional [torch .Generator ]],
101
104
) -> Tuple [torch .Tensor , torch .Tensor ]:
102
105
"""Perform modified rejection sampling on each sequence.
103
106
@@ -114,22 +117,33 @@ def _batch_modified_rejection_sampling(
114
117
115
118
# shape [batch_size, k]
116
119
accepted = self ._get_accepted (target_probs , draft_probs ,
117
- draft_token_ids )
120
+ draft_token_ids , generators )
118
121
119
122
recovered_probs = self ._get_recovered_probs (
120
123
target_probs , draft_probs ).reshape (batch_size * k , vocab_size )
121
124
125
+ seed_indices , non_seed_indices = self ._split_batch_by_seeded (
126
+ generators , k = k )
127
+
122
128
# NOTE: the recovered_probs are overwritten by this method.
123
- recovered_token_ids = _multinomial (recovered_probs ,
124
- num_samples = 1 ).reshape (
125
- batch_size , k )
129
+ recovered_token_ids = _multinomial (
130
+ recovered_probs ,
131
+ num_samples = 1 ,
132
+ k = k ,
133
+ generators = generators ,
134
+ seed_indices = seed_indices ,
135
+ # this arg is unused when None but torch.jit requires a list
136
+ non_seed_indices = non_seed_indices or [],
137
+ ).reshape (batch_size , k )
138
+
126
139
return accepted , recovered_token_ids
127
140
128
141
def _get_accepted (
129
- self ,
130
- target_probs : torch .Tensor , # [batch_size, k, vocab_size]
131
- draft_probs : torch .Tensor , # [batch_size, k, vocab_size]
132
- draft_token_ids : torch .Tensor , # [batch_size, k]
142
+ self ,
143
+ target_probs : torch .Tensor , # [batch_size, k, vocab_size]
144
+ draft_probs : torch .Tensor , # [batch_size, k, vocab_size]
145
+ draft_token_ids : torch .Tensor , # [batch_size, k]
146
+ generators : List [Optional [torch .Generator ]],
133
147
) -> torch .Tensor :
134
148
r"""Create bool matrix over the proposed draft tokens. If
135
149
True, then a token can be accepted, else it should be
@@ -164,10 +178,28 @@ def _get_accepted(
164
178
selected_target_probs = target_probs [batch_indices , probs_indicies ,
165
179
draft_token_ids ]
166
180
167
- uniform_rand = torch .rand (batch_size ,
168
- k ,
169
- dtype = self .probs_dtype ,
170
- device = target_probs .device )
181
+ seed_indices , non_seed_indices = self ._split_batch_by_seeded (
182
+ generators )
183
+
184
+ if len (seed_indices ) == 0 :
185
+ uniform_rand = torch .rand_like (selected_target_probs )
186
+ else :
187
+ uniform_rand = torch .empty_like (selected_target_probs )
188
+
189
+ for idx in seed_indices :
190
+ uniform_rand [idx , :] = torch .rand (1 ,
191
+ k ,
192
+ dtype = self .probs_dtype ,
193
+ device = target_probs .device ,
194
+ generator = generators [idx ])
195
+
196
+ if non_seed_indices :
197
+ uniform_rand [non_seed_indices , :] = torch .rand (
198
+ len (non_seed_indices ),
199
+ k ,
200
+ dtype = self .probs_dtype ,
201
+ device = target_probs .device )
202
+
171
203
capped_ratio = torch .minimum (
172
204
selected_target_probs / selected_draft_probs ,
173
205
torch .full ((1 , ), 1 , device = target_probs .device ))
@@ -240,6 +272,27 @@ def _smallest_positive_value(self) -> float:
240
272
"""
241
273
return torch .finfo (self .probs_dtype ).tiny
242
274
275
+ # partition batch into indices for which a generator is provided
276
+ # and indices for which no generator is provided
277
@staticmethod
def _split_batch_by_seeded(
    generators: List[Optional[torch.Generator]],
    k: int = 1,
) -> Tuple[List[int], Optional[List[int]]]:
    """Partition row indices by whether their sequence has a generator.

    Sequence ``i`` owns the ``k`` consecutive row indices
    ``k * i .. k * (i + 1) - 1``.

    Args:
        generators: One entry per sequence; ``None`` means the sequence
            has no generator of its own.
        k: Number of rows owned by each sequence.

    Returns:
        ``(seeded_rows, unseeded_rows)``. When no entry of ``generators``
        is a generator (including an empty list) the result is
        ``([], None)``.
    """
    # Fast path: nothing is seeded, so there is nothing to partition.
    if not any(gen is not None for gen in generators):
        return [], None

    seeded_rows: List[int] = []
    unseeded_rows: List[int] = []
    for seq_idx, gen in enumerate(generators):
        first_row = k * seq_idx
        bucket = unseeded_rows if gen is None else seeded_rows
        bucket.extend(range(first_row, first_row + k))

    return seeded_rows, unseeded_rows
295
+
243
296
244
297
# torch.multinomial forces a GPU<->CPU sync.
245
298
# Therefore, we use an optimized implementation instead that skips the sync.
@@ -250,12 +303,25 @@ def _smallest_positive_value(self) -> float:
250
303
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
    k: int,
    generators: List[Optional[torch.Generator]],
    seed_indices: List[int],
    non_seed_indices: List[int],
) -> torch.Tensor:
    """Sample indices from each row of ``probs`` without torch.multinomial.

    Each row is divided by exponential(1.0) noise and the argmax of the
    result is taken as the sample.

    Args:
        probs: 2-D tensor with one distribution per row. When
            ``num_samples == 1`` it is overwritten in place by ``div_``.
        num_samples: Number of samples to draw per input row.
        k: Number of consecutive rows that share one entry of
            ``generators`` (row ``r`` uses ``generators[r // k]``).
        generators: Per-sequence generators; ``None`` entries are allowed.
        seed_indices: Row indices of ``probs`` (before any expansion for
            ``num_samples``) whose noise must come from their generator.
        non_seed_indices: Unused. Kept so existing callers keep working;
            every row not listed in ``seed_indices`` draws from the
            default RNG.

    Returns:
        Tensor of sampled indices with shape ``[probs.shape[0], num_samples]``.

    NOTE(review): the caller's comment suggests this function is compiled
    with torch.jit; the body only adds integer slicing on top of what the
    original used -- confirm scripting still succeeds.
    """
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])

    # Fill every row from the default RNG first. Indexing ``q`` with a
    # list of indices returns a copy, so an in-place fill on that result
    # would leave the unseeded rows of ``q`` uninitialised.
    q = torch.empty_like(probs)
    q.exponential_(1.0)

    # Overwrite the seeded rows. A slice is a view, so the in-place fill
    # lands in ``q``. Input row ``idx`` occupies rows
    # ``idx * num_samples .. (idx + 1) * num_samples - 1`` after expansion.
    for idx in seed_indices:
        start = idx * num_samples
        q[start:start + num_samples].exponential_(
            1.0, generator=generators[idx // k])

    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
0 commit comments