229 changes: 223 additions & 6 deletions python/sglang/srt/layers/attention/flashinfer_backend.py

Large diffs are not rendered by default.

142 changes: 136 additions & 6 deletions python/sglang/srt/layers/logits_processor.py
@@ -60,7 +60,8 @@
class LogitsProcessorOutput:
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The logits of the next tokens. shape: [#seq, vocab_size]
next_token_logits: torch.Tensor
# Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
next_token_logits: Optional[torch.Tensor]
# Used by speculative decoding (EAGLE)
# The last hidden layers
hidden_states: Optional[torch.Tensor] = None
@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
input_top_logprobs_val: List = None
input_top_logprobs_idx: List = None
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
input_token_ids_logprobs_val: Optional[List] = None
# Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
None
)
input_token_ids_logprobs_idx: Optional[List] = None


@@ -127,6 +131,9 @@ class LogitsMetadata:
# for padding
padded_static_len: int = -1

# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False

@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
if (
@@ -169,6 +176,7 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
token_ids_logprobs=forward_batch.token_ids_logprobs,
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
padded_static_len=forward_batch.padded_static_len,
is_prefill_only=forward_batch.is_prefill_only,
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
dp_local_start_pos=forward_batch.dp_local_start_pos,
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -247,6 +255,108 @@ def __init__(
"debug_tensor_dump_output_folder", None
)

def compute_logprobs_for_multi_item_scoring(
self,
input_ids,
hidden_states,
lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
delimiter_token: int,
):
"""
Compute logprobs for multi-item scoring using delimiter-based token extraction.

This method is designed for scenarios where you want to score multiple items/candidates
against a single query by combining them into one sequence separated by delimiters.

Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
        Scoring positions: logprobs are extracted at the position immediately before each <delimiter>

Args:
input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
hidden_states (torch.Tensor): Hidden states from the model.
Shape: [sequence_length, hidden_dim].
lm_head (VocabParallelEmbedding): Language model head for computing logits.
logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
and token ID specifications for logprob extraction.
delimiter_token (int): Token ID used as delimiter between query and items.

Returns:
LogitsProcessorOutput: Contains:
- next_token_logits: None (not needed for scoring-only requests)
- input_token_logprobs: Logprobs of delimiter tokens at scoring positions
- input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
- input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
- input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
- input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
"""
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
0
] - 1
# Extract hidden states at delimiter positions for multi-item scoring
sliced_hidden = hidden_states[multi_item_indices]

sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)

# Initialize return values
input_token_ids_logprobs_val = []
input_token_ids_logprobs_idx = []
input_top_logprobs_val = None
input_top_logprobs_idx = None

# Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
# Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
if (
logits_metadata.token_ids_logprobs
or logits_metadata.extend_return_top_logprob
):
logits_metadata.extend_logprob_pruned_lens_cpu = []

if logits_metadata.extend_seq_lens_cpu is not None:
# Multi-request batch: count delimiters per request
input_pt = 0
for req_seq_len in logits_metadata.extend_seq_lens_cpu:
req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
delimiter_count = (req_input_ids == delimiter_token).sum().item()
logits_metadata.extend_logprob_pruned_lens_cpu.append(
delimiter_count
)
input_pt += req_seq_len
else:
# Single request case: one request gets all delimiters
total_delimiters = (input_ids == delimiter_token).sum().item()
logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]

# Get the logprobs of specified token ids
if logits_metadata.extend_token_ids_logprob:
(
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
) = self.get_token_ids_logprobs(
sliced_logprobs, logits_metadata, delay_cpu_copy=True
)

# Get the logprob of top-k tokens
if logits_metadata.extend_return_top_logprob:
(
input_top_logprobs_val,
input_top_logprobs_idx,
) = self.get_top_logprobs(sliced_logprobs, logits_metadata)

# For input_token_logprobs, use delimiter token logprobs
input_token_logprobs = sliced_logprobs[:, delimiter_token]

return LogitsProcessorOutput(
next_token_logits=None, # Multi-item scoring doesn't need next token logits
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
)

def forward(
self,
input_ids,
@@ -257,6 +367,16 @@
) -> LogitsProcessorOutput:
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
multi_item_delimiter = global_server_args_dict.get(
"multi_item_scoring_delimiter"
)
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring(
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
)

# Get the last hidden states and last logits for the next token prediction
if (
logits_metadata.forward_mode.is_decode_or_idle()
@@ -584,7 +704,9 @@ def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata

@staticmethod
def get_token_ids_logprobs(
all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
all_logprobs: torch.Tensor,
logits_metadata: LogitsMetadata,
delay_cpu_copy: bool = False,
):
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
pt = 0
@@ -597,9 +719,17 @@ def get_token_ids_logprobs(
input_token_ids_logprobs_idx.append([])
continue

input_token_ids_logprobs_val.append(
[all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
)
position_logprobs = all_logprobs[
pt : pt + pruned_len, token_ids
] # Shape: [pruned_len, num_tokens]

if delay_cpu_copy:
# Keep as tensor to delay GPU-to-CPU transfer
input_token_ids_logprobs_val.append(position_logprobs)
else:
# Convert to list immediately (default behavior)
input_token_ids_logprobs_val.append(position_logprobs.tolist())

input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
pt += pruned_len

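The core of the new compute_logprobs_for_multi_item_scoring path above is the delimiter-based position extraction. A minimal standalone sketch of that indexing (toy tensors and a made-up delimiter ID, not the actual SGLang call path) could look like this:

import torch

# Toy sequence: Query<delim>Item1<delim>Item2<delim>, with 99 as an assumed delimiter ID.
delimiter_token = 99
input_ids = torch.tensor([11, 12, 13, 99, 21, 22, 99, 31, 99])

# Positions immediately before each delimiter; their hidden states predict the
# delimiter, so the logprob of the delimiter there scores the preceding span.
scoring_positions = (input_ids == delimiter_token).nonzero(as_tuple=True)[0] - 1
# scoring_positions == tensor([2, 5, 7])

# With logprobs computed only at those positions (shape [num_positions, vocab_size]),
# the per-item score is the delimiter-token logprob at each scoring position.
logprobs = torch.log_softmax(torch.randn(len(scoring_positions), 128), dim=-1)
item_scores = logprobs[:, delimiter_token]  # shape: [3]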
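The delay_cpu_copy flag added to get_token_ids_logprobs lets the gathered logprobs stay on the GPU so the device-to-host copy can happen once, later, instead of a synchronous .tolist() per request. A rough illustration of the two paths (toy tensors, not the SGLang scheduler):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
logprobs = torch.randn(4, 32000, device=device).log_softmax(dim=-1)
token_ids = [1, 5, 7]

# Eager path (previous behavior): .tolist() forces a synchronous
# device-to-host transfer for every request in the batch.
eager_vals = logprobs[:, token_ids].tolist()

# Delayed path (delay_cpu_copy=True): keep the gather as a tensor and copy to
# the host once, after all requests in the batch have been processed.
delayed_vals = logprobs[:, token_ids]         # still on `device`
cpu_vals_later = delayed_vals.cpu().tolist()  # single deferred transfer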
9 changes: 6 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
@@ -114,6 +114,7 @@
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
"multi_item_scoring_delimiter",
]

# Put some global args for easy access
@@ -666,9 +667,11 @@ def seqlen(self):
def is_prefill_only(self) -> bool:
"""Check if this request is prefill-only (no token generation needed)."""
# NOTE: when spec is enabled, prefill_only optimizations are disabled
return (
self.sampling_params.max_new_tokens == 0
and global_server_args_dict["speculative_algorithm"] is None
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

spec_alg = global_server_args_dict["speculative_algorithm"]
return self.sampling_params.max_new_tokens == 0 and (
spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
)

def add_latency(self, stage: RequestStage):
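The updated is_prefill_only predicate above treats speculation as disabled whether the server arg is unset or explicitly NONE. A self-contained illustration of the intended condition (the string values here are stand-ins, not the actual SpeculativeAlgorithm enum):

from typing import Optional

def is_prefill_only(max_new_tokens: int, spec_alg: Optional[str]) -> bool:
    # Prefill-only: the request asks for zero new tokens and speculative
    # decoding is disabled (either unset or explicitly "NONE").
    return max_new_tokens == 0 and spec_alg in (None, "NONE")

assert is_prefill_only(0, None)          # pure scoring request
assert is_prefill_only(0, "NONE")        # speculation explicitly off
assert not is_prefill_only(16, None)     # normal generation request
assert not is_prefill_only(0, "EAGLE")   # spec decoding disables the fast path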