38 changes: 33 additions & 5 deletions python/sglang/srt/managers/schedule_batch.py
@@ -1538,20 +1538,48 @@ def prepare_for_idle(self):
self.model_config.vocab_size,
)

def prepare_for_decode(self):
def prepare_for_speculative_decode(self):
self.forward_mode = ForwardMode.DECODE

speculative_num_steps = global_server_args_dict["speculative_num_steps"]
speculative_eagle_topk = global_server_args_dict["speculative_eagle_topk"]
speculative_num_draft_tokens = global_server_args_dict["speculative_num_draft_tokens"]
page_size = self.token_to_kv_pool_allocator.page_size
bs = len(self.reqs)

if self.spec_algorithm.is_eagle():
assert self.token_to_kv_pool_allocator.page_size == 1, "Eagle only supports page size 1"
if page_size == 1:
self.draft_out_cache_loc, backup_state = self.alloc_token_slots(
bs * global_server_args_dict["speculative_num_steps"] * global_server_args_dict["speculative_eagle_topk"],
bs * speculative_num_steps * speculative_eagle_topk,
backup_state=True
)
self.token_to_kv_pool_allocator.restore_state(backup_state)
self.out_cache_loc = self.alloc_token_slots(bs * speculative_num_draft_tokens)
else:
max_draft_len = page_size - 1 + speculative_num_steps
num_new_pages_per_topk = (max_draft_len - 1) // page_size + 1
self.draft_out_cache_loc, backup_state = self.alloc_paged_token_slots_extend(
prefix_lens=torch.zeros_like(self.seq_lens),
seq_lens=torch.full_like(self.seq_lens, speculative_eagle_topk * num_new_pages_per_topk * page_size),
last_loc=torch.full_like(self.seq_lens, -1),
extend_num_tokens=bs * speculative_eagle_topk * num_new_pages_per_topk * page_size,
backup_state=True
)
self.token_to_kv_pool_allocator.restore_state(backup_state)
self.out_cache_loc = self.alloc_token_slots(bs * global_server_args_dict["speculative_num_draft_tokens"])
self.out_cache_loc = self.alloc_paged_token_slots_extend(
prefix_lens=torch.zeros_like(self.seq_lens),
seq_lens=torch.full_like(self.seq_lens, speculative_num_draft_tokens),
last_loc=torch.full_like(self.seq_lens, -1),
extend_num_tokens=bs * speculative_num_draft_tokens,
)

def prepare_for_decode(self):
if self.spec_algorithm.is_eagle():
self.prepare_for_speculative_decode()
return

self.forward_mode = ForwardMode.DECODE
bs = len(self.reqs)

if self.sampling_info.penalizer_orchestrator.is_required:
if self.enable_overlap:
# TODO: this can be slow, optimize this.
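For readers reviewing the paged branch of `prepare_for_speculative_decode`, here is a minimal sketch of the sizing arithmetic, using made-up values for `page_size`, `speculative_num_steps`, `speculative_eagle_topk`, and the batch size (none of these numbers come from a real config):

```python
# Illustrative values only; the real ones come from global_server_args_dict
# and the KV-pool allocator.
page_size = 4
speculative_num_steps = 5    # draft steps per decode round
speculative_eagle_topk = 2   # draft branches kept per step
bs = 3                       # batch size

# One reading of the formula above: a draft chain may start up to
# (page_size - 1) slots into a partially filled page, hence the extra term.
max_draft_len = page_size - 1 + speculative_num_steps          # 8

# Ceiling division: pages needed per top-k branch to hold max_draft_len tokens.
num_new_pages_per_topk = (max_draft_len - 1) // page_size + 1  # 2

# Slots reserved per request for the draft forward, and for the whole batch
# (the extend_num_tokens argument above).
slots_per_req = speculative_eagle_topk * num_new_pages_per_topk * page_size  # 16
extend_num_tokens = bs * slots_per_req                                       # 48
print(max_draft_len, num_new_pages_per_topk, slots_per_req, extend_num_tokens)
```

Note that with `page_size == 1` these quantities collapse to `bs * speculative_num_steps * speculative_eagle_topk` draft slots per batch, which is exactly what the unpaged branch allocates directly.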
7 changes: 4 additions & 3 deletions python/sglang/srt/managers/scheduler.py
@@ -173,7 +173,7 @@
class GenerationBatchResult:
logits_output: Optional[LogitsProcessorOutput]
pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
free_cache_loc_cpu: Optional[torch.Tensor]
evict_cache_loc: Optional[torch.Tensor]
next_token_ids: Optional[List[int]]
extend_input_len_per_req: List[int]
extend_logprob_start_len_per_req: List[int]
@@ -1764,6 +1764,7 @@ def run_batch(

model_worker_batch = batch.get_model_worker_batch()
if self.enable_overlap:
# TODO (timmy): Do not alias seq_lens between forward and scheduler threads.
# Optimistically estimate the seq_lens_cpu for the next draft forward
model_worker_batch.seq_lens_cpu.add_(self.server_args.speculative_num_steps + 1)

@@ -1775,7 +1776,7 @@
(
logits_output,
next_token_ids,
free_cache_loc_cpu,
evict_cache_loc,
bid,
can_run_cuda_graph,
next_spec_info,
@@ -1806,7 +1807,7 @@
if not self.pp_group.is_last_rank
else None
),
free_cache_loc_cpu=free_cache_loc_cpu if self.spec_algorithm.is_eagle() else None,
evict_cache_loc=evict_cache_loc if self.spec_algorithm.is_eagle() else None,
next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
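As a side note on the overlap path in `run_batch`, the snippet below sketches what the optimistic `seq_lens_cpu` adjustment does; the tensor values are made up, and the reading that `speculative_num_steps + 1` is an upper bound on tokens appended per round is an assumption, not stated in the diff:

```python
import torch

# Illustrative lengths for three in-flight requests (made-up numbers).
seq_lens_cpu = torch.tensor([17, 42, 9], dtype=torch.int64)
speculative_num_steps = 5

# Optimistic estimate used when overlap scheduling is enabled: assume every
# draft step is accepted plus the verified token, i.e. presumably an upper
# bound of speculative_num_steps + 1 new tokens per request, applied before
# the real accept_length is known.
seq_lens_cpu.add_(speculative_num_steps + 1)
print(seq_lens_cpu)  # tensor([23, 48, 15])
```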
34 changes: 14 additions & 20 deletions python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -202,16 +202,16 @@ def process_batch_result_decode(
result: GenerationBatchResult,
launch_done: Optional[threading.Event] = None,
):
logits_output, next_token_ids, free_cache_loc_cpu, can_run_cuda_graph = (
logits_output, next_token_ids, evict_cache_loc, can_run_cuda_graph = (
result.logits_output,
result.next_token_ids,
result.free_cache_loc_cpu,
result.evict_cache_loc,
result.can_run_cuda_graph,
)

if self.enable_overlap:
if self.spec_algorithm.is_eagle():
logits_output, next_token_ids, free_cache_loc_cpu, _, can_run_cuda_graph = (
logits_output, next_token_ids, evict_cache_loc, _, can_run_cuda_graph = (
self.draft_worker.resolve_last_batch_result(launch_done)
)
else:
@@ -226,42 +226,36 @@

self.token_to_kv_pool_allocator.free_group_begin()

if free_cache_loc_cpu is not None:
free_cache_loc_cpu = free_cache_loc_cpu[free_cache_loc_cpu != 0]
self.token_to_kv_pool_allocator.free(free_cache_loc_cpu.to("cuda", non_blocking=True))
if evict_cache_loc is not None:
evict_cache_loc = evict_cache_loc[evict_cache_loc != 0]
self.token_to_kv_pool_allocator.free(evict_cache_loc)

if self.spec_algorithm.is_eagle():
# TODO (timmy): when does this happen?
if batch.seq_lens is not None:
batch.seq_lens.add_(logits_output.accept_length + 1)

accept_length = logits_output.accept_length.tolist()
idx_to_batch = [i for i, length in enumerate(accept_length) for _ in range(length + 1)]
bids = [(bid, step) for bid, length in enumerate(accept_length) for step in range(length + 1)]
else:
idx_to_batch = list(range(len(batch.reqs)))
bids = [(bid, 0) for bid in range(len(batch.reqs))]

num_generated_tokens_this_batch = len(idx_to_batch)
num_generated_tokens_this_batch = len(bids)
self.num_generated_tokens += num_generated_tokens_this_batch
if self.spec_algorithm.is_eagle():
self.spec_num_total_accepted_tokens += num_generated_tokens_this_batch
self.spec_num_total_forward_ct += len(batch.reqs)

# Check finish condition
for i, (b, next_token_id) in enumerate(zip(idx_to_batch, next_token_ids)):
req = batch.reqs[b]
prev_seq_lens = [batch.reqs[bid].seqlen for bid in range(len(batch.reqs))]
for i, ((bid, step), next_token_id) in enumerate(zip(bids, next_token_ids)):
req = batch.reqs[bid]
if req.is_retracted:
continue

if (self.enable_overlap or self.spec_algorithm.is_eagle()) and req.finished():
# Free the one extra delayed token
if self.page_size == 1:
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
elif self.spec_algorithm.is_none():
else:
# Only free when the extra token is in a new page
# NOTE (timmy): do we do anything for eagle?
if (
len(req.origin_input_ids) + len(req.output_ids) - 1
) % self.page_size == 0:
if (prev_seq_lens[bid] + step - 1) % self.page_size == 0:
self.token_to_kv_pool_allocator.free(
batch.out_cache_loc[i : i + 1]
)
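Finally, a small self-contained sketch of the per-token bookkeeping in the rewritten `process_batch_result_decode` loop: how `bids` expands `accept_length` into `(request, step)` pairs, and how the page-boundary check decides whether a finished request's delayed extra token sits in a fresh page. All numbers below are made up.

```python
# Made-up values: accepted draft tokens per request (EAGLE) and each
# request's sequence length before this decode round.
accept_length = [2, 0, 1]
prev_seq_lens = [15, 8, 31]
page_size = 4

# One (request index, step) pair per generated token: request i contributes
# accept_length[i] + 1 tokens (accepted drafts plus the verified token).
bids = [
    (bid, step)
    for bid, length in enumerate(accept_length)
    for step in range(length + 1)
]
assert bids == [(0, 0), (0, 1), (0, 2), (1, 0), (2, 0), (2, 1)]
num_generated_tokens_this_batch = len(bids)  # 6

# Same modulo test as in the diff: for page_size > 1 and a finished request,
# the delayed token is freed only when (prev_seq_lens[bid] + step - 1) falls
# on a page boundary, which the diff's comment describes as the extra token
# being in a new page.
for i, (bid, step) in enumerate(bids):
    in_new_page = (prev_seq_lens[bid] + step - 1) % page_size == 0
    print(f"token {i}: req {bid}, step {step}, in_new_page={in_new_page}")
```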