Conversation

@sundar24295s (Collaborator) commented on Sep 27, 2025

🚀 Motivation

  • Score multiple items against a single query in one forward pass while preventing items from attending to each other.
  • We leverage FlashInfer’s custom item-aware masking (see FlashInfer PR: custom item boundaries and masks), passing item-boundary metadata directly to prefill kernels.
  • Comparison with single-item scoring: Instead of building N sequences of Query + Item_i, we create one sequence with a delimiter marking item boundaries. This reduces duplicated compute/memory traffic for the shared query and enables better batching behavior.

✍️ Example

Given:

{
  "query": "What is the capital of California? Answer Yes or No for each of the following options:",
  "items": ["Sacramento", "San Jose", "San Francisco"]
}
  • Single-item scoring (existing): N sequences
    • "…query… Sacramento"
    • "…query… San Jose"
    • "…query… San Francisco"
  • Multi-item scoring (this PR): one sequence with a delimiter token marking item boundaries:
    What is the capital of California? Answer Yes or No for each of the following options:
    <DELIM>Sacramento<DELIM>San Jose<DELIM>San Francisco<DELIM>
  • We extract scores at each item's closing delimiter (skipping the first delimiter, which follows the query); a tokenization sketch is shown below.
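A minimal sketch of how the single multi-item sequence could be assembled, assuming a Hugging Face tokenizer for Qwen3-0.6B; the delimiter id 151655 comes from the server commands later in this PR, and the variable names here are illustrative rather than the actual SGLang request path:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
DELIM = 151655  # token id passed via --multi-item-scoring-delimiter

query = ("What is the capital of California? "
         "Answer Yes or No for each of the following options:")
items = ["Sacramento", "San Jose", "San Francisco"]

# One sequence: query, then <DELIM> + item_i for each item, then a trailing <DELIM>.
input_ids = tokenizer.encode(query, add_special_tokens=False)
for item in items:
    input_ids.append(DELIM)
    input_ids.extend(tokenizer.encode(item, add_special_tokens=False))
input_ids.append(DELIM)

# Scores are read at each item's closing delimiter, i.e. every delimiter
# except the first one right after the query.
score_positions = [i for i, t in enumerate(input_ids) if t == DELIM][1:]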

Performance Impact:

  • On Qwen3-0.6B with 300 input tokens, at QPS 120 and 10 items per request, P99 latency improved from 8276 ms to 511 ms (~16.2× faster, ~93.8% reduction) with this PR.
  • With a P99 latency threshold of 500 ms, throughput increased from 950 to 1200 items/s per H100 GPU (~26.3% increase).
[Figure: P99 latency and throughput comparison, single-item vs multi-item scoring]

FlashInfer Custom Mask

[Figure: FlashInfer custom mask for multi-item scoring, showing item-boundary masking]

🔧 Modifications

  • New server arg: --multi-item-scoring-delimiter <int> (token id) with validation that requires:
    • --disable-radix-cache
    • --chunked-prefill-size -1
  • FlashInfer backend integration (python/sglang/srt/layers/attention/flashinfer_backend.py):
    • Introduce MultiItemScoringParams and compute per-prompt tensors (a minimal sketch follows this list):
      • prefix_len_ptr (uint32): prefix length before the first delimiter
      • token_pos_in_items_ptr (uint16): concatenated relative positions with 0 at each delimiter
      • token_pos_in_items_len (int): padded length for batch
      • max_item_len_ptr (uint16): max item length per prompt
    • Force paged prefill wrapper and disable ragged when multi-item scoring is active.
    • Disable sliding window attention for multi-item sequences to avoid crossing item boundaries.
    • Pass prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr to BatchPrefillWithPagedKVCacheWrapper.begin_forward.
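Below is a rough reconstruction of how the per-prompt metadata could be derived from a tokenized multi-item sequence. It follows the field descriptions above, but the actual layout and batching logic live in python/sglang/srt/layers/attention/flashinfer_backend.py, so treat this as an illustration only:

import numpy as np

def build_multi_item_metadata(input_ids, delim):
    # prefix_len_ptr: number of tokens before the first delimiter (the shared query).
    first_delim = input_ids.index(delim)
    prefix_len = np.uint32(first_delim)

    # token_pos_in_items_ptr: relative position of every token from the first
    # delimiter onward, reset to 0 at each delimiter and counting up inside each item.
    token_pos_in_items = []
    pos = 0
    for tok in input_ids[first_delim:]:
        if tok == delim:
            pos = 0
        token_pos_in_items.append(pos)
        pos += 1
    token_pos_in_items = np.asarray(token_pos_in_items, dtype=np.uint16)

    # max_item_len_ptr: longest item in this prompt (the largest relative position).
    max_item_len = np.uint16(token_pos_in_items.max())

    # token_pos_in_items_len is the padded length of token_pos_in_items across the
    # batch, so it is computed over all prompts rather than per prompt.
    return prefix_len, token_pos_in_items, max_item_len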

⚠️ Limitations / Future work

  • Only the FlashInfer backend is supported, and only for prefill.
  • Sliding window disabled in multi-item mode by design.
  • Future work:
    • Support enabling radix cache.
    • Use FA3 attention in FlashInfer.
    • Provide helper to auto-select a delimiter token per tokenizer.

Accuracy Tests

  • Scores with Single Item Scoring
$ python -m sglang.launch_server --model-path /shared/public/elr-models/Qwen/Qwen3-0.6B/c1899de289a04d12100db370d81485cdf75e47ca --port 30000 --host 0.0.0.0 --chunked-prefill-size -1 --enable-torch-compile --dtype float16 --max-prefill-tokens 30000 --mem-fraction-static 0.3 --enable-dynamic-batch-tokenizer --disable-radix-cache --disable-cuda-graph   --attention-backend flashinfer
$ curl -X POST "http://localhost:30000/v1/score"   -H "Content-Type: application/json"   -d '{
    "query": "What is the capital of California? Answer Yes or No for each of the following options:",
    "items": ["Scaramento", "San Jose", "San Francisco"],
    "label_token_ids": [9454, 2753],
    "model": "/shared/public/elr-models/Qwen/Qwen3-0.6B/c1899de289a04d12100db370d81485cdf75e47ca"
  }' | jq
{
  "scores": [
    [
      4.2720747888470666e-06,
      1.255632607448268e-05
    ],
    [
      7.248083212607261e-05,
      0.00032230865503649693
    ],
    [
      0.00012109026433961085,
      0.0003068084936522053
    ]
  ],
  "model": "/shared/public/elr-models/Qwen/Qwen3-0.6B/c1899de289a04d12100db370d81485cdf75e47ca",
  "usage": null,
  "object": "scoring"
}
  • Scores with Multi Item Scoring
$ python -m sglang.launch_server --model-path /shared/public/elr-models/Qwen/Qwen3-0.6B/c1899de289a04d12100db370d81485cdf75e47ca --port 30000 --host 0.0.0.0 --chunked-prefill-size -1 --enable-torch-compile --dtype float16 --max-prefill-tokens 30000 --mem-fraction-static 0.3 --enable-dynamic-batch-tokenizer --disable-radix-cache --disable-cuda-graph --multi-item-scoring-delimiter 151655  --attention-backend flashinfer
$ curl -X POST "http://localhost:30000/v1/score"   -H "Content-Type: application/json"   -d '{
    "query": "What is the capital of California? Answer Yes or No for each of the following options:",
    "items": ["Scaramento", "San Jose", "San Francisco"],
    "label_token_ids": [9454, 2753],
    "model": "/shared/public/elr-models/Qwen/Qwen3-0.6B/c1899de289a04d12100db370d81485cdf75e47ca"                             
  }' | jq
{
  "scores": [
    [
      4.261198406544089e-06,
      1.2330186426487146e-05
    ],
    [
      7.449767639638031e-05,
      0.0003286991634328124
    ],
    [
      0.0001230619247565629,
      0.00031424963374729217
    ]
  ],
  "model": "/shared/public/elr-models/Qwen/Qwen3-0.6B/c1899de289a04d12100db370d81485cdf75e47ca",
  "usage": null,
  "object": "scoring"
}
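For reference, a small Python client equivalent to the curl calls above; it assumes a server launched with one of the commands in this PR is listening on localhost:30000, and the model string is shortened here for readability (the runs above use a local snapshot path):

import requests

payload = {
    "query": ("What is the capital of California? "
              "Answer Yes or No for each of the following options:"),
    "items": ["Sacramento", "San Jose", "San Francisco"],
    "label_token_ids": [9454, 2753],  # same label token ids as the curl calls above
    "model": "Qwen/Qwen3-0.6B",       # illustrative; use your served model path
}

resp = requests.post("http://localhost:30000/v1/score", json=payload, timeout=60)
resp.raise_for_status()

# "scores" is one [p(label_0), p(label_1)] pair per item, in label_token_ids order.
for item, probs in zip(payload["items"], resp.json()["scores"]):
    print(item, probs)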

Benchmarking and Profiling

🧪 Benchmark Comparison: Qwen3-0.6B on H100 (CUDA 12.8)

Setup:

  • Model: Qwen3-0.6B
  • Prompt length: 300 tokens
  • Hardware: H100 GPU
  • Duration: 120s
  • Target RPS: 70, 80, 90, 100
  • Item Count: 10 per request
  • Distribution: Poisson

Server Start:

  • For Single Item Scoring using flashinfer backend (not FA3)
$ python -m sglang.launch_server --model-path /shared/public/elr-models/Qwen/Qwen3-0.6B/c1899de289a04d12100db370d81485cdf75e47ca --port 30000 --host 0.0.0.0 --chunked-prefill-size -1 --enable-torch-compile --dtype float16 --max-prefill-tokens 30000 --mem-fraction-static 0.3 --enable-tokenizer-batch-encode --disable-radix-cache --disable-cuda-graph   --attention-backend flashinfer
  • For Multi-item scoring
$ python -m sglang.launch_server --model-path /shared/public/elr-models/Qwen/Qwen3-0.6B/c1899de289a04d12100db370d81485cdf75e47ca --port 30000 --host 0.0.0.0 --chunked-prefill-size -1 --enable-torch-compile --dtype float16 --max-prefill-tokens 30000 --mem-fraction-static 0.3 --enable-dynamic-batch-tokenizer --disable-radix-cache --disable-cuda-graph --multi-item-scoring-delimiter 151655  --attention-backend flashinfer

Benchmark Script:

python3.10 sglang/benchmark/score/bench_score.py
Items Per Second | Single Item Scoring P99 Latency (ms) | Multi-Item Scoring P99 Latency (ms)
800              | 257                                   | 146
900              | 307                                   | 152
1000             | 601                                   | 227
1100             | 4226                                  | 315
1200             | 8276                                  | 511
1300             | 12556                                 | 2407


@qingquansong (Collaborator) left a comment:

LGTM in general, thanks for the efforts! Two general comments: (1) is the radix cache for prefix caching case properly handled especially cache eviction in radix tree? (2) can we add a small test case for three cases: single req, batched req, prefix caching cases. That should be enough to cover most cases, when hooking with the flashinfer kernels we added previously. Btw, make sure we switch to flashinfer backend always and disable chunked prefill when using this mode. Thank you!

@sundar24295s (Collaborator, Author) replied:

> LGTM in general, thanks for the efforts! Two general comments: (1) is the radix cache for prefix caching case properly handled especially cache eviction in radix tree? (2) can we add a small test case for three cases: single req, batched req, prefix caching cases. That should be enough to cover most cases, when hooking with the flashinfer kernels we added previously. Btw, make sure we switch to flashinfer backend always and disable chunked prefill when using this mode. Thank you!

@qingquansong

  1. I have mentioned in the "Future work" section of PR description to add radix cache support. Will put up a separate PR for it.
  2. Yeah, adding unit tests after the initial review and alignment.

@qingquansong (Collaborator) left a comment:

lgtm

@zhyncs zhyncs merged commit 53bd00d into sgl-project:main Oct 9, 2025
125 of 132 checks passed