[Generative Score API] Multi-Item scoring with custom attention mask. #10979
Conversation
LGTM in general, thanks for the efforts! Two general comments: (1) is the radix cache properly handled for the prefix-caching case, especially cache eviction in the radix tree? (2) can we add a small test covering three cases: single request, batched requests, and prefix caching? That should be enough to cover most cases when hooking into the flashinfer kernels we added previously. Btw, make sure we always switch to the flashinfer backend and disable chunked prefill when using this mode. Thank you!
lgtm
🚀 Motivation
✍️ Example
Given:
Performance Impact:
Flashinfer Custom Mask
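The mask structure behind multi-item scoring can be sketched in plain NumPy. This is a hedged illustration of the idea only, not the flashinfer kernel: tokens in the shared prefix attend causally, and each item's tokens see the full prefix plus earlier tokens of the same item, never tokens from other items. The function name and layout are assumptions for this sketch.

```python
import numpy as np

def multi_item_mask(prefix_len, item_lens):
    """Boolean attention mask for multi-item scoring (illustration only).

    Rows are query positions, columns are key positions. The shared
    prefix is causal; each item sees the whole prefix plus itself
    causally, and is blind to the other items.
    """
    total = prefix_len + sum(item_lens)
    mask = np.zeros((total, total), dtype=bool)
    # Causal attention inside the shared prefix.
    for i in range(prefix_len):
        mask[i, : i + 1] = True
    start = prefix_len
    for length in item_lens:
        for i in range(start, start + length):
            mask[i, :prefix_len] = True    # item sees the whole prefix
            mask[i, start : i + 1] = True  # causal within this item
        start += length
    return mask
```

Because the off-item blocks are all False, a single prefill can score many items against one prefix without the items leaking into each other.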
🔧 Modifications
- Added `--multi-item-scoring-delimiter <int>` (token id), with validation that requires:
  - `--disable-radix-cache`
  - `--chunked-prefill-size -1`
- `python/sglang/srt/layers/attention/flashinfer_backend.py`: added `MultiItemScoringParams` and compute the per-prompt tensors:
  - `prefix_len_ptr` (uint32): prefix length before the first delimiter
  - `token_pos_in_items_ptr` (uint16): concatenated relative positions, with 0 at each delimiter
  - `token_pos_in_items_len` (int): padded length for the batch
  - `max_item_len_ptr` (uint16): max item length per prompt
- Pass `prefix_len_ptr`, `token_pos_in_items_ptr`, `token_pos_in_items_len`, and `max_item_len_ptr` to `BatchPrefillWithPagedKVCacheWrapper.begin_forward`.
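The flag validation described above might look roughly like this. The helper name and attribute names are hypothetical; the actual server-args check in sglang may differ.

```python
def validate_multi_item_args(args):
    """Hypothetical sketch of the constraint: the delimiter mode
    requires radix cache disabled and chunked prefill turned off."""
    if getattr(args, "multi_item_scoring_delimiter", None) is None:
        return  # mode not enabled, nothing to check
    if not args.disable_radix_cache:
        raise ValueError(
            "--multi-item-scoring-delimiter requires --disable-radix-cache"
        )
    if args.chunked_prefill_size != -1:
        raise ValueError(
            "--multi-item-scoring-delimiter requires --chunked-prefill-size -1"
        )
```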
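The per-prompt tensors listed above can be sketched in plain Python/NumPy. This is a hedged illustration of the layout as described, not the sglang implementation; the function name and exact packing are assumptions.

```python
import numpy as np

def build_multi_item_params(token_ids, delimiter_id):
    """Derive the multi-item scoring metadata for one prompt
    (illustration only): prefix length before the first delimiter,
    relative token positions with 0 at each delimiter, their count,
    and the maximum item length."""
    ids = list(token_ids)
    try:
        prefix_len = ids.index(delimiter_id)
    except ValueError:
        prefix_len = len(ids)  # no delimiter: the whole prompt is prefix
    token_pos, item_lens, cur = [], [], 0
    for tok in ids[prefix_len:]:
        if tok == delimiter_id:
            if cur:
                item_lens.append(cur)
            token_pos.append(0)  # 0 marks each delimiter
            cur = 0
        else:
            cur += 1
            token_pos.append(cur)  # 1, 2, ... within the item
    if cur:
        item_lens.append(cur)
    return {
        "prefix_len_ptr": np.asarray([prefix_len], dtype=np.uint32),
        "token_pos_in_items_ptr": np.asarray(token_pos, dtype=np.uint16),
        "token_pos_in_items_len": len(token_pos),
        "max_item_len_ptr": np.asarray([max(item_lens, default=0)],
                                       dtype=np.uint16),
    }
```

In a batch, `token_pos_in_items_len` would be padded to a common length across prompts, per the description above.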
Accuracy Tests
Benchmarking and Profiling
🧪 Benchmark Comparison: Qwen3-0.6B on H100 (CUDA 12.8)
Setup:
Server Start:
Benchmark Script:
Checklist