@@ -32,6 +32,7 @@
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import (
     RadixKey,
+    _convert_to_bigram_key,
     _key_match_page_size1,
     _key_match_paged,
     get_child_key,
@@ -327,12 +328,14 @@ def __init__(
         sliding_window_size: int,
         page_size: int,
         disable: bool = False,
+        is_eagle: bool = False,
     ):
         assert isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator)
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
         self.disable = disable
+        self.is_eagle = is_eagle

         if self.token_to_kv_pool_allocator:
             self.device = self.token_to_kv_pool_allocator.device
@@ -346,6 +349,11 @@ def __init__(
             self.key_match_fn = partial(_key_match_paged, page_size=page_size)
             self.get_child_key_fn = partial(get_child_key, page_size=page_size)

+        if is_eagle:
+            self.key_convert_fn = _convert_to_bigram_key
+        else:
+            self.key_convert_fn = lambda key: key
+
         self.sliding_window_size = sliding_window_size
         self.reset()
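For reference, the bigram conversion selected into key_convert_fn here can be pictured with a minimal sketch. _convert_to_bigram_key is imported from radix_cache above; the sketch only illustrates the behavior implied by the [1, 2, 3, 4] -> [(1, 2), (2, 3), (3, 4)] example in the comments further down, not the library's actual implementation:

    # Hypothetical sketch, not sglang's implementation: pair each token with
    # its successor, so an n-token key becomes n-1 bigrams (n <= 1 yields an
    # empty key).
    def convert_to_bigram_key_sketch(token_ids):
        return list(zip(token_ids, token_ids[1:]))

    assert convert_to_bigram_key_sketch([1, 2, 3, 4]) == [(1, 2), (2, 3), (3, 4)]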
@@ -376,6 +384,8 @@ def match_prefix(self, key: RadixKey, **kwargs) -> MatchResult:
         The last node creates a new child if the prefix is shorter
         than the last node's value.
         """
+        key.token_ids = self.key_convert_fn(key.token_ids)
+
         if self.disable or len(key) == 0:
             return MatchResult(
                 device_indices=torch.empty(
@@ -406,8 +416,15 @@ def insert(self, key: RadixKey, value=None, prev_prefix_len: int = 0) -> int:
         if self.disable:
             return 0

+        key.token_ids = self.key_convert_fn(key.token_ids)
+
         if value is None:
             value = torch.tensor([x for x in key.token_ids], dtype=torch.int64)
+
+        if self.is_eagle:
+            # Make sure the value length equals the EAGLE bigram key length
+            value = value[: len(key)]
+
         return self._insert_helper(self.root_node, key, value, prev_prefix_len)

     def cache_finished_req(self, req: Req) -> None:
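The truncation above matters because a caller-supplied value keeps one KV index per token, while the converted key has one entry per bigram, i.e. one fewer. A small worked example with illustrative numbers:

    # 4 tokens give 4 KV indices but only 3 bigram keys, so the trailing
    # KV index is dropped to keep key and value the same length.
    bigram_key = [(1, 2), (2, 3), (3, 4)]    # len 3
    kv_indices = [10, 11, 12, 13]            # len 4 (illustrative slot ids)
    value = kv_indices[: len(bigram_key)]    # [10, 11, 12]
    assert len(value) == len(bigram_key)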
@@ -422,25 +439,41 @@ def cache_finished_req(self, req: Req) -> None:
             return

         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        all_token_len = len(token_ids)
+        # For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1, 2, 3, 4] -> [(1, 2), (2, 3), (3, 4)], so the key length shrinks by one.
+        # The corresponding KV length must shrink by one as well; actual_kv_len captures this and is used for the calculations and slicing below.
+        actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :all_token_len
         ]

         if self.page_size != 1:
-            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
+            page_aligned_len = actual_kv_len // self.page_size * self.page_size
             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
-            page_aligned_len = len(kv_indices)
+            page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.clone()
+            if self.is_eagle:
+                self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
+
+        page_aligned_token_len = (
+            page_aligned_len + 1 if self.is_eagle else page_aligned_len
+        )
+
+        old_prefix_len = len(req.prefix_indices)
+        if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
+            # In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
+            # Subtract one here so the KV of that unmatched token is freed correctly and does not leak.
+            old_prefix_len -= 1

         # Radix Cache takes one ref in memory pool
         # insert the token_ids and kv_indices into the radix tree
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
-            RadixKey(token_ids[:page_aligned_len], req.extra_key),
+            RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
             page_aligned_kv_indices,
-            len(req.prefix_indices),
+            old_prefix_len,
         )

         # Remove req slot, release the cache lock
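To make the length bookkeeping in cache_finished_req concrete, here is the arithmetic for a hypothetical EAGLE request with ten tokens and page_size 4 (numbers are illustrative only):

    all_token_len = 10                             # len(token_ids)
    actual_kv_len = all_token_len - 1              # 9 bigram keys for EAGLE
    page_size = 4
    page_aligned_len = actual_kv_len // page_size * page_size    # 8 bigrams kept
    page_aligned_token_len = page_aligned_len + 1                # 9 tokens form them
    # token_ids[:9] and the 8 page-aligned KV indices go into the tree;
    # kv_indices[8:10] is freed back to the allocator.
    assert page_aligned_token_len == 9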
@@ -459,39 +492,56 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
             return

         token_ids = req.fill_ids
+        all_token_len = len(token_ids)
+        # For the EAGLE radix cache, the key is converted to a bigram key, e.g. [1, 2, 3, 4] -> [(1, 2), (2, 3), (3, 4)], so the key length shrinks by one.
+        # The corresponding KV length must shrink by one as well; actual_kv_len captures this and is used for the calculations and slicing below.
+        actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, : len(token_ids)
+            req.req_pool_idx, :all_token_len
         ]

         if self.page_size != 1:
-            page_aligned_len = len(kv_indices) // self.page_size * self.page_size
+            page_aligned_len = actual_kv_len // self.page_size * self.page_size
             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
         else:
-            page_aligned_len = len(kv_indices)
+            page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.clone()
-        page_aligned_token_ids = token_ids[:page_aligned_len]
+
+        # For EAGLE, page_aligned_len refers to the bigram key, so the normal key length is one longer
+        page_aligned_token_len = (
+            page_aligned_len + 1 if self.is_eagle else page_aligned_len
+        )
+        page_aligned_token_ids = token_ids[:page_aligned_token_len]
+
+        old_prefix_len = len(req.prefix_indices)
+        if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
+            # In the EAGLE chunked prefill case, prefix_indices includes one unmatched token (kv_indices[actual_kv_len:]).
+            # Subtract one here so the KV of that unmatched token is freed correctly and does not leak.
+            old_prefix_len -= 1

         # Radix Cache takes one ref in memory pool
         # Note: the insert function already frees the overlapped kv_indices
         new_prefix_len = self.insert(
             RadixKey(page_aligned_token_ids, req.extra_key),
             page_aligned_kv_indices,
-            len(req.prefix_indices),
+            old_prefix_len,
         )

         # The prefix indices could be updated, reuse them
         new_indices, new_last_node, _, _ = self.match_prefix(
             RadixKey(page_aligned_token_ids, req.extra_key)
         )
-        assert len(req.prefix_indices) <= len(
+        assert old_prefix_len <= len(
             new_indices
         ), f"{req.prefix_indices=}, {new_indices=}"
         assert new_prefix_len <= len(new_indices), f"{new_prefix_len=}, {new_indices=}"
         self.req_to_token_pool.write(
-            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
-            new_indices[len(req.prefix_indices) :],
+            (req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
+            new_indices[old_prefix_len:],
         )

+        req.last_matched_prefix_len = len(new_indices)
+
         self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
         swa_uuid_for_lock = self.inc_lock_ref(new_last_node)
@@ -501,7 +551,13 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None:
                 [new_indices, kv_indices[len(new_indices) :]]
             )
         else:
-            req.prefix_indices = new_indices
+            if self.is_eagle:
+                # Attach the KV index of the last token for EAGLE; chunked prefill uses it later
+                req.prefix_indices = torch.cat(
+                    [new_indices, kv_indices[actual_kv_len:]]
+                )
+            else:
+                req.prefix_indices = new_indices
         req.last_node = new_last_node
         req.swa_uuid_for_lock = swa_uuid_for_lock
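Read as an invariant, the chunked-prefill bookkeeping in the two hunks above works like this: after each chunk, req.prefix_indices holds the matched entries plus, for EAGLE, one trailing KV index whose bigram does not exist yet, while req.last_matched_prefix_len records only the matched count. The next call detects the extra entry and backs it out. A schematic check with illustrative lengths:

    # State left by the previous chunk (EAGLE case).
    last_matched_prefix_len = 8    # len(new_indices) at the last call
    prefix_indices_len = 8 + 1     # matched entries + trailing token's KV index

    # Start of the next cache_unfinished_req / cache_finished_req call:
    old_prefix_len = prefix_indices_len
    if old_prefix_len > last_matched_prefix_len:
        old_prefix_len -= 1        # un-count the unmatched trailing KV index
    assert old_prefix_len == last_matched_prefix_len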