@@ -60,7 +60,8 @@
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
-    next_token_logits: torch.Tensor
+    # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
+    next_token_logits: Optional[torch.Tensor]
 
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None
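
Note: with this change, consumers of LogitsProcessorOutput can no longer assume next_token_logits is a tensor. A minimal guard, as a sketch (the surrounding variable names are illustrative, not from this PR):

    out = logits_processor_output  # a LogitsProcessorOutput instance
    if out.next_token_logits is None:
        # Prefill-only request (e.g., multi-item scoring): nothing to sample
        next_token_ids = None
    else:
        next_token_ids = out.next_token_logits.argmax(dim=-1)
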
@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
     # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
-    input_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
+    input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
+        None
+    )
     input_token_ids_logprobs_idx: Optional[List] = None
 
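
Note: because input_token_ids_logprobs_val may now hold either plain lists or GPU tensors (when the CPU copy is delayed), consuming code has to normalize the entries. A hedged sketch of such a normalization step (the helper name is hypothetical):

    import torch

    def to_cpu_lists(vals):
        # Normalize entries that may be GPU tensors (delayed copy) or plain lists
        return [v.tolist() if isinstance(v, torch.Tensor) else v for v in vals]
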
@@ -127,6 +131,9 @@ class LogitsMetadata:
     # for padding
     padded_static_len: int = -1
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
@@ -169,6 +176,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
             token_ids_logprobs=forward_batch.token_ids_logprobs,
             extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
             padded_static_len=forward_batch.padded_static_len,
+            is_prefill_only=forward_batch.is_prefill_only,
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -247,6 +255,108 @@ def __init__(
             "debug_tensor_dump_output_folder", None
         )
 
+    def compute_logprobs_for_multi_item_scoring(
+        self,
+        input_ids,
+        hidden_states,
+        lm_head: VocabParallelEmbedding,
+        logits_metadata: Union[LogitsMetadata, ForwardBatch],
+        delimiter_token: int,
+    ):
+        """
+        Compute logprobs for multi-item scoring using delimiter-based token extraction.
+
+        This method is designed for scenarios where you want to score multiple items/candidates
+        against a single query by combining them into one sequence separated by delimiters.
+
+        Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
+        Scoring positions: extracts logprobs at the positions immediately before each <delimiter>.
+
+        Args:
+            input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
+                Shape: [total_sequence_length] for a single request or [batch_total_length] for a batch.
+            hidden_states (torch.Tensor): Hidden states from the model.
+                Shape: [sequence_length, hidden_dim].
+            lm_head (VocabParallelEmbedding): Language model head for computing logits.
+            logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
+                and token ID specifications for logprob extraction.
+            delimiter_token (int): Token ID used as the delimiter between query and items.
+
+        Returns:
+            LogitsProcessorOutput: Contains:
+                - next_token_logits: None (not needed for scoring-only requests)
+                - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
+                - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
+                - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
+                - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
+                - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
+        """
+        multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[0] - 1
+        # Extract hidden states at the positions right before each delimiter
+        sliced_hidden = hidden_states[multi_item_indices]
+
+        sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
+        sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
+
+        # Initialize return values
+        input_token_ids_logprobs_val = []
+        input_token_ids_logprobs_idx = []
+        input_top_logprobs_val = None
+        input_top_logprobs_idx = None
+
+        # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request.
+        # The original contains sequence lengths, but sliced_logprobs is indexed by delimiter count.
+        if (
+            logits_metadata.token_ids_logprobs
+            or logits_metadata.extend_return_top_logprob
+        ):
+            logits_metadata.extend_logprob_pruned_lens_cpu = []
+
+            if logits_metadata.extend_seq_lens_cpu is not None:
+                # Multi-request batch: count delimiters per request
+                input_pt = 0
+                for req_seq_len in logits_metadata.extend_seq_lens_cpu:
+                    req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
+                    delimiter_count = (req_input_ids == delimiter_token).sum().item()
+                    logits_metadata.extend_logprob_pruned_lens_cpu.append(
+                        delimiter_count
+                    )
+                    input_pt += req_seq_len
+            else:
+                # Single-request case: one request gets all delimiters
+                total_delimiters = (input_ids == delimiter_token).sum().item()
+                logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
+
+        # Get the logprobs of the specified token ids
+        if logits_metadata.extend_token_ids_logprob:
+            (
+                input_token_ids_logprobs_val,
+                input_token_ids_logprobs_idx,
+            ) = self.get_token_ids_logprobs(
+                sliced_logprobs, logits_metadata, delay_cpu_copy=True
+            )
+
+        # Get the logprobs of the top-k tokens
+        if logits_metadata.extend_return_top_logprob:
+            (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+            ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
+
+        # For input_token_logprobs, use the delimiter token's logprobs
+        input_token_logprobs = sliced_logprobs[:, delimiter_token]
+
+        return LogitsProcessorOutput(
+            next_token_logits=None,  # Multi-item scoring doesn't need next-token logits
+            input_token_logprobs=input_token_logprobs,
+            input_top_logprobs_val=input_top_logprobs_val,
+            input_top_logprobs_idx=input_top_logprobs_idx,
+            input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+            input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
+        )
+
     def forward(
         self,
         input_ids,
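
To make the index arithmetic above concrete, a toy example (all token IDs invented):

    import torch

    delimiter_token = 9  # invented token id
    # Query = [5, 3], Item1 = [7], Item2 = [8, 2]; a delimiter follows each part
    input_ids = torch.tensor([5, 3, 9, 7, 9, 8, 2, 9])
    multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[0] - 1
    print(multi_item_indices)  # tensor([1, 3, 6]) -> positions just before each delimiter
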
@@ -257,6 +367,16 @@ def forward(
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
+
+        # Check whether multi-item scoring is enabled via server args (only for prefill-only requests)
+        multi_item_delimiter = global_server_args_dict.get(
+            "multi_item_scoring_delimiter"
+        )
+        if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
+            return self.compute_logprobs_for_multi_item_scoring(
+                input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
+            )
+
         # Get the last hidden states and last logits for the next token prediction
         if (
             logits_metadata.forward_mode.is_decode_or_idle()
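
Note: the dispatch above keys off global_server_args_dict; presumably the "multi_item_scoring_delimiter" entry is populated from a server argument of the same name. A minimal sketch of how the gate behaves (the token id and setup are illustrative only):

    global_server_args_dict["multi_item_scoring_delimiter"] = 9  # invented delimiter id
    # With this set, forward() short-circuits into compute_logprobs_for_multi_item_scoring
    # for any batch whose logits_metadata.is_prefill_only is True.
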
@@ -584,7 +704,9 @@ def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
 
     @staticmethod
     def get_token_ids_logprobs(
-        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+        all_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+        delay_cpu_copy: bool = False,
     ):
         input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
         pt = 0
@@ -597,9 +719,17 @@ def get_token_ids_logprobs(
                 input_token_ids_logprobs_idx.append([])
                 continue
 
-            input_token_ids_logprobs_val.append(
-                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
-            )
+            position_logprobs = all_logprobs[
+                pt : pt + pruned_len, token_ids
+            ]  # Shape: [pruned_len, num_tokens]
+
+            if delay_cpu_copy:
+                # Keep as a tensor to delay the GPU-to-CPU transfer
+                input_token_ids_logprobs_val.append(position_logprobs)
+            else:
+                # Convert to a list immediately (default behavior)
+                input_token_ids_logprobs_val.append(position_logprobs.tolist())
+
             input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
             pt += pruned_len
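
Call-site sketch for the new flag (logprobs and metadata assumed prepared elsewhere):

    # Default: immediate GPU->CPU copy; values come back as nested Python lists.
    vals, idx = LogitsProcessor.get_token_ids_logprobs(logprobs, metadata)

    # Delayed copy (used by the multi-item scoring path): values stay as GPU
    # tensors so the caller can batch the transfer to the CPU later.
    vals, idx = LogitsProcessor.get_token_ids_logprobs(
        logprobs, metadata, delay_cpu_copy=True
    )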