Skip to content

Commit 68df9c4

Browse files
authored
feat: more sampling operator options (#431)
1. Implement top-k-then-top-p sampling (top-k filtering applied before top-p) to align with vLLM's and Hugging Face's behavior; see vllm-project/vllm#7137 (comment). 2. Add the option to pass either a scalar or a per-request tensor for the top-p/top-k thresholds in all sampling operators.
1 parent daa5566 commit 68df9c4

File tree

10 files changed

+691
-196
lines changed

10 files changed

+691
-196
lines changed

docs/api/python/sampling.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ Kernels for LLM sampling.
1414
top_p_sampling_from_probs
1515
top_k_sampling_from_probs
1616
min_p_sampling_from_probs
17+
top_k_top_p_sampling_from_logits
1718
top_k_top_p_sampling_from_probs
1819
top_p_renorm_prob
1920
top_k_renorm_prob
21+
top_k_mask_logits
2022
chain_speculative_sampling

include/flashinfer/sampling.cuh

Lines changed: 217 additions & 94 deletions
Large diffs are not rendered by default.

python/csrc/flashinfer_ops.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
3434
"Top-k and top-p sampling from probabilities");
3535
m.def("top_k_renorm_prob", &top_k_renorm_prob, "Renormalize probabilities by top-k mask");
3636
m.def("top_p_renorm_prob", &top_p_renorm_prob, "Renormalize probabilities by top-p mask");
37+
m.def("top_k_mask_logits", &top_k_mask_logits, "Mask logits by top-k mask");
3738
m.def("chain_speculative_sampling", &chain_speculative_sampling,
3839
"Speculative sampling from sequence of probabilities");
3940
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");

python/csrc/flashinfer_ops.h

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,25 +39,33 @@ torch::Tensor sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_sam
3939
bool deterministic);
4040

4141
std::vector<torch::Tensor> top_p_sampling_from_probs(torch::Tensor probs,
42-
torch::Tensor uniform_samples, double top_p,
43-
bool deterministic);
42+
torch::Tensor uniform_samples,
43+
std::optional<torch::Tensor> maybe_top_p_arr,
44+
double top_p_val, bool deterministic);
4445

4546
std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs,
4647
torch::Tensor uniform_samples,
47-
unsigned int top_k, bool deterministic);
48+
std::optional<torch::Tensor> maybe_top_k_arr,
49+
unsigned int top_k_val, bool deterministic);
4850

4951
std::vector<torch::Tensor> min_p_sampling_from_probs(torch::Tensor probs,
5052
torch::Tensor uniform_samples,
51-
torch::Tensor min_p, bool deterministic);
53+
std::optional<torch::Tensor> maybe_min_p_arr,
54+
double min_p_val, bool deterministic);
55+
56+
std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
57+
torch::Tensor probs, torch::Tensor uniform_samples,
58+
std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
59+
std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val, bool deterministic);
5260

53-
std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(torch::Tensor probs,
54-
torch::Tensor uniform_samples,
55-
torch::Tensor top_k, torch::Tensor top_p,
56-
bool deterministic);
61+
torch::Tensor top_p_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_p_arr,
62+
double top_p_val, double eps);
5763

58-
torch::Tensor top_p_renorm_prob(torch::Tensor probs, double top_p, double eps);
64+
torch::Tensor top_k_renorm_prob(torch::Tensor probs, std::optional<torch::Tensor> maybe_top_k_arr,
65+
unsigned int top_k_val, double eps);
5966

60-
torch::Tensor top_k_renorm_prob(torch::Tensor probs, unsigned int top_k, double eps);
67+
torch::Tensor top_k_mask_logits(torch::Tensor logits, std::optional<torch::Tensor> maybe_top_k_arr,
68+
unsigned int top_k_val, double eps);
6169

6270
torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tensor draft_token_ids,
6371
torch::Tensor uniform_samples, torch::Tensor target_probs,

python/csrc/sampling.cu

Lines changed: 128 additions & 38 deletions
Large diffs are not rendered by default.

python/flashinfer/__init__.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,46 @@
1414
limitations under the License.
1515
"""
1616

17-
from .cascade import (BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
18-
BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
19-
merge_state, merge_state_in_place, merge_states)
20-
from .decode import (BatchDecodeWithPagedKVCacheWrapper,
21-
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
22-
single_decode_with_kv_cache)
17+
from .cascade import (
18+
BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
19+
BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
20+
merge_state,
21+
merge_state_in_place,
22+
merge_states,
23+
)
24+
from .decode import (
25+
BatchDecodeWithPagedKVCacheWrapper,
26+
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
27+
single_decode_with_kv_cache,
28+
)
2329
from .group_gemm import SegmentGEMMWrapper
2430
from .norm import fused_add_rmsnorm, rmsnorm
2531
from .page import append_paged_kv_cache
26-
from .prefill import (BatchPrefillWithPagedKVCacheWrapper,
27-
BatchPrefillWithRaggedKVCacheWrapper,
28-
single_prefill_with_kv_cache,
29-
single_prefill_with_kv_cache_return_lse)
32+
from .prefill import (
33+
BatchPrefillWithPagedKVCacheWrapper,
34+
BatchPrefillWithRaggedKVCacheWrapper,
35+
single_prefill_with_kv_cache,
36+
single_prefill_with_kv_cache_return_lse,
37+
)
3038
from .quantization import packbits, segment_packbits
31-
from .rope import (apply_llama31_rope, apply_llama31_rope_inplace, apply_rope,
32-
apply_rope_inplace)
33-
from .sampling import (chain_speculative_sampling, sampling_from_probs,
34-
top_k_renorm_prob, top_k_sampling_from_probs,
35-
top_k_top_p_sampling_from_probs, top_p_renorm_prob,
36-
top_p_sampling_from_probs)
39+
from .rope import (
40+
apply_llama31_rope,
41+
apply_llama31_rope_inplace,
42+
apply_rope,
43+
apply_rope_inplace,
44+
)
45+
from .sampling import (
46+
chain_speculative_sampling,
47+
sampling_from_probs,
48+
top_k_renorm_prob,
49+
top_k_mask_logits,
50+
top_k_sampling_from_probs,
51+
top_k_top_p_sampling_from_probs,
52+
top_k_top_p_sampling_from_logits,
53+
top_p_renorm_prob,
54+
top_p_sampling_from_probs,
55+
min_p_sampling_from_probs,
56+
)
3757
from .sparse import BlockSparseAttentionWrapper
3858

3959
try:

0 commit comments

Comments (0)