
Commit 84b0c85

sundar24295sch-tiger1 authored and committed
[Generative Score API] Multi-Item scoring with custom attention mask. (sgl-project#10979)
1 parent 988be49 commit 84b0c85

File tree

10 files changed: +1121 −129 lines

python/sglang/srt/layers/attention/flashinfer_backend.py

Lines changed: 223 additions & 6 deletions (large diff not rendered by default)
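The flashinfer diff itself is not rendered, but the commit title names its purpose: a custom attention mask for multi-item scoring. As a rough illustration only (not the actual backend code; the segment layout and names are assumptions), such a mask lets each item attend causally to the shared query prefix and to its own tokens, while items stay invisible to one another:

```python
# Illustrative sketch only -- not the flashinfer_backend.py code, which is not
# rendered in this diff. Assumes the layout Query<delim>Item1<delim>Item2<delim>...
# where each item may attend causally to the query and to itself, but not to
# other items.
import torch

def build_multi_item_mask(input_ids: torch.Tensor, delimiter_token: int) -> torch.Tensor:
    """Return a [seq_len, seq_len] boolean mask where True means "may attend"."""
    seq_len = input_ids.shape[0]
    is_delim = (input_ids == delimiter_token).long()
    # Segment 0 is the query; each delimiter closes the segment it terminates.
    segment = torch.cumsum(is_delim, dim=0) - is_delim
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # same_segment[i, j]: token j belongs to the same segment as token i.
    same_segment = segment.unsqueeze(0) == segment.unsqueeze(1)
    # query_cols[i, j]: token j is part of the shared query prefix.
    query_cols = (segment == 0).unsqueeze(0).expand(seq_len, seq_len)
    return causal & (same_segment | query_cols)
```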

python/sglang/srt/layers/logits_processor.py

Lines changed: 136 additions & 6 deletions
```diff
@@ -60,7 +60,8 @@
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
-    next_token_logits: torch.Tensor
+    # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
+    next_token_logits: Optional[torch.Tensor]
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None
@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
     # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
-    input_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
+    input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
+        None
+    )
     input_token_ids_logprobs_idx: Optional[List] = None
```
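Because `input_token_ids_logprobs_val` may now carry GPU tensors instead of plain lists, downstream consumers have to normalize before serializing. A minimal consumer-side sketch (hypothetical helper, not part of this diff):

```python
# Hypothetical helper, not part of this diff: materialize any delayed GPU
# tensors held in input_token_ids_logprobs_val into plain Python lists.
import torch
from typing import List, Optional, Union

def to_lists(
    vals: Optional[List[Union[List[float], torch.Tensor]]],
) -> Optional[list]:
    if vals is None:
        return None
    return [v.tolist() if isinstance(v, torch.Tensor) else v for v in vals]
```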
```diff
@@ -127,6 +131,9 @@ class LogitsMetadata:
     # for padding
     padded_static_len: int = -1
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
@@ -169,6 +176,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
             token_ids_logprobs=forward_batch.token_ids_logprobs,
             extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
             padded_static_len=forward_batch.padded_static_len,
+            is_prefill_only=forward_batch.is_prefill_only,
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
```
@@ -247,6 +255,108 @@ def __init__(
247255
"debug_tensor_dump_output_folder", None
248256
)
249257

258+
def compute_logprobs_for_multi_item_scoring(
259+
self,
260+
input_ids,
261+
hidden_states,
262+
lm_head: VocabParallelEmbedding,
263+
logits_metadata: Union[LogitsMetadata, ForwardBatch],
264+
delimiter_token: int,
265+
):
266+
"""
267+
Compute logprobs for multi-item scoring using delimiter-based token extraction.
268+
269+
This method is designed for scenarios where you want to score multiple items/candidates
270+
against a single query by combining them into one sequence separated by delimiters.
271+
272+
Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
273+
Scoring positions: Extracts logprobs at positions before each <delimiter>
274+
275+
Args:
276+
input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
277+
Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
278+
hidden_states (torch.Tensor): Hidden states from the model.
279+
Shape: [sequence_length, hidden_dim].
280+
lm_head (VocabParallelEmbedding): Language model head for computing logits.
281+
logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
282+
and token ID specifications for logprob extraction.
283+
delimiter_token (int): Token ID used as delimiter between query and items.
284+
285+
Returns:
286+
LogitsProcessorOutput: Contains:
287+
- next_token_logits: None (not needed for scoring-only requests)
288+
- input_token_logprobs: Logprobs of delimiter tokens at scoring positions
289+
- input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
290+
- input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
291+
- input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
292+
- input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
293+
"""
294+
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
295+
0
296+
] - 1
297+
# Extract hidden states at delimiter positions for multi-item scoring
298+
sliced_hidden = hidden_states[multi_item_indices]
299+
300+
sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
301+
sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
302+
303+
# Initialize return values
304+
input_token_ids_logprobs_val = []
305+
input_token_ids_logprobs_idx = []
306+
input_top_logprobs_val = None
307+
input_top_logprobs_idx = None
308+
309+
# Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
310+
# Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
311+
if (
312+
logits_metadata.token_ids_logprobs
313+
or logits_metadata.extend_return_top_logprob
314+
):
315+
logits_metadata.extend_logprob_pruned_lens_cpu = []
316+
317+
if logits_metadata.extend_seq_lens_cpu is not None:
318+
# Multi-request batch: count delimiters per request
319+
input_pt = 0
320+
for req_seq_len in logits_metadata.extend_seq_lens_cpu:
321+
req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
322+
delimiter_count = (req_input_ids == delimiter_token).sum().item()
323+
logits_metadata.extend_logprob_pruned_lens_cpu.append(
324+
delimiter_count
325+
)
326+
input_pt += req_seq_len
327+
else:
328+
# Single request case: one request gets all delimiters
329+
total_delimiters = (input_ids == delimiter_token).sum().item()
330+
logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
331+
332+
# Get the logprobs of specified token ids
333+
if logits_metadata.extend_token_ids_logprob:
334+
(
335+
input_token_ids_logprobs_val,
336+
input_token_ids_logprobs_idx,
337+
) = self.get_token_ids_logprobs(
338+
sliced_logprobs, logits_metadata, delay_cpu_copy=True
339+
)
340+
341+
# Get the logprob of top-k tokens
342+
if logits_metadata.extend_return_top_logprob:
343+
(
344+
input_top_logprobs_val,
345+
input_top_logprobs_idx,
346+
) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
347+
348+
# For input_token_logprobs, use delimiter token logprobs
349+
input_token_logprobs = sliced_logprobs[:, delimiter_token]
350+
351+
return LogitsProcessorOutput(
352+
next_token_logits=None, # Multi-item scoring doesn't need next token logits
353+
input_token_logprobs=input_token_logprobs,
354+
input_top_logprobs_val=input_top_logprobs_val,
355+
input_top_logprobs_idx=input_top_logprobs_idx,
356+
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
357+
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
358+
)
359+
250360
def forward(
251361
self,
252362
input_ids,
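To make the indexing concrete, here is a standalone toy run of the extraction step above: every delimiter occurrence marks one scoring position, one token to its left.

```python
# Standalone toy run of the delimiter-position extraction (same logic as above).
import torch

delimiter_token = 0
# Query = [5, 6], Item1 = [7], Item2 = [8, 9], each segment closed by delimiter 0:
input_ids = torch.tensor([5, 6, 0, 7, 0, 8, 9, 0])

# Delimiters sit at positions 2, 4, 7, so scoring positions are one to the left.
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[0] - 1
print(multi_item_indices.tolist())  # [1, 3, 6]

# hidden_states[multi_item_indices] then feeds the lm_head, so only these
# positions ever produce logits -- the rest of the sequence is skipped.
```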
```diff
@@ -257,6 +367,16 @@ def forward(
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
+
+        # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
+        multi_item_delimiter = global_server_args_dict.get(
+            "multi_item_scoring_delimiter"
+        )
+        if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
+            return self.compute_logprobs_for_multi_item_scoring(
+                input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
+            )
+
         # Get the last hidden states and last logits for the next token prediction
         if (
             logits_metadata.forward_mode.is_decode_or_idle()
```
```diff
@@ -584,7 +704,9 @@ def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
 
     @staticmethod
     def get_token_ids_logprobs(
-        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+        all_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+        delay_cpu_copy: bool = False,
     ):
         input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
         pt = 0
@@ -597,9 +719,17 @@ def get_token_ids_logprobs(
                 input_token_ids_logprobs_idx.append([])
                 continue
 
-            input_token_ids_logprobs_val.append(
-                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
-            )
+            position_logprobs = all_logprobs[
+                pt : pt + pruned_len, token_ids
+            ]  # Shape: [pruned_len, num_tokens]
+
+            if delay_cpu_copy:
+                # Keep as tensor to delay GPU-to-CPU transfer
+                input_token_ids_logprobs_val.append(position_logprobs)
+            else:
+                # Convert to list immediately (default behavior)
+                input_token_ids_logprobs_val.append(position_logprobs.tolist())
+
             input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
             pt += pruned_len
```
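The new `delay_cpu_copy` flag exists because `.tolist()` forces a synchronizing GPU-to-CPU copy for every request inside the loop; keeping tensors lets the caller batch the transfer once at output time. A toy contrast of the two paths (standalone sketch, made-up shapes):

```python
# Toy contrast of the two paths (standalone; shapes are made up).
import torch

all_logprobs = torch.randn(4, 32)                 # [pruned_len, vocab_size]
token_ids = [1, 3]                                # requested token ids
position_logprobs = all_logprobs[0:4, token_ids]  # [pruned_len, num_tokens]

eager = position_logprobs.tolist()   # default: synchronizing copy to CPU now
delayed = position_logprobs          # delay_cpu_copy=True: stays on device
# ...later, one batched transfer instead of one per request:
later = delayed.cpu().tolist()
```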
python/sglang/srt/managers/schedule_batch.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -118,6 +118,7 @@
     "fake_node",
     "nsa_prefill",
     "nsa_decode",
+    "multi_item_scoring_delimiter",
 ]
 
 # Put some global args for easy access
@@ -670,9 +671,11 @@ def seqlen(self):
     def is_prefill_only(self) -> bool:
         """Check if this request is prefill-only (no token generation needed)."""
         # NOTE: when spec is enabled, prefill_only optimizations are disabled
-        return (
-            self.sampling_params.max_new_tokens == 0
-            and global_server_args_dict["speculative_algorithm"] is None
+        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+
+        spec_alg = global_server_args_dict["speculative_algorithm"]
+        return self.sampling_params.max_new_tokens == 0 and (
+            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
         )
 
     def add_latency(self, stage: RequestStage):
```
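Putting the pieces together from the client side: the flag spelling below is inferred from the `multi_item_scoring_delimiter` server-args key, and the request shape follows the existing Generative Score API; treat both as assumptions and verify against the sglang docs.

```python
# Hypothetical end-to-end usage (flag spelling and request fields are inferred,
# not confirmed by this diff). Launch the server with a delimiter token id, e.g.:
#
#   python -m sglang.launch_server --model-path <model> \
#       --multi-item-scoring-delimiter 128009
#
# A score request is prefill-only (max_new_tokens == 0), so it takes the
# multi-item path in LogitsProcessor.forward() above.
import requests

resp = requests.post(
    "http://localhost:30000/v1/score",
    json={
        "query": "Is the following review positive?",
        "items": ["Loved it!", "Terrible.", "It was fine."],
        "label_token_ids": [9454, 2822],  # example ids for "Yes"/"No" tokens
        "apply_softmax": True,
    },
)
print(resp.json())  # per-item scores over the label tokens
```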
