Merged
Commits
24 commits
43dfe76
Implement LRU eviction policy for LoRA adapters
ConnorLi96 Sep 29, 2025
06eaf34
feat: Add LRU eviction policy with comprehensive unit tests
ConnorLi96 Sep 29, 2025
8e7afe1
style: Fix code formatting and spelling issues
ConnorLi96 Sep 30, 2025
e6ae718
change default fifo to lru
ConnorLi96 Sep 30, 2025
77d8051
address comments and optmize code
ConnorLi96 Oct 1, 2025
65628b5
correct the engine command help=
ConnorLi96 Oct 1, 2025
3d6fdd2
support integration test
ConnorLi96 Oct 1, 2025
f1c446e
evict None base model
ConnorLi96 Oct 3, 2025
bfe931f
add DEFAULT_LORA_EVICTION_POLICY
ConnorLi96 Oct 3, 2025
76290d0
address some comments
ConnorLi96 Oct 7, 2025
a0b3a74
add more unit tests
ConnorLi96 Oct 7, 2025
15e1d8a
Merge branch 'main' into feature/sglang_lora_lru
ConnorLi96 Oct 7, 2025
03584ba
Update server_args.py
ConnorLi96 Oct 7, 2025
ca58c7f
fix format
ConnorLi96 Oct 7, 2025
11a6a1b
Merge branch 'main' into feature/sglang_lora_lru
Fridge003 Oct 7, 2025
721421e
Merge branch 'main' into feature/sglang_lora_lru
ConnorLi96 Oct 7, 2025
f221f60
fix wrong import
ConnorLi96 Oct 10, 2025
0921776
Merge branch 'main' into feature/sglang_lora_lru
ConnorLi96 Oct 11, 2025
56efbfb
Merge branch 'main' into feature/sglang_lora_lru
Fridge003 Oct 13, 2025
94b8bd2
Merge branch 'main' into feature/sglang_lora_lru
ConnorLi96 Oct 13, 2025
9094954
delete integration test in test_lora_eviction_policy.py
ConnorLi96 Oct 13, 2025
9349cb7
update arguments
ConnorLi96 Oct 13, 2025
8ccef17
update server_arguments.md
ConnorLi96 Oct 13, 2025
ed5de18
update lora.ipynb
ConnorLi96 Oct 13, 2025
115 changes: 115 additions & 0 deletions python/sglang/srt/lora/eviction_policy.py
@@ -0,0 +1,115 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Eviction policies for LoRA adapter memory management.
"""

import logging
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Set

logger = logging.getLogger(__name__)


class EvictionPolicy(ABC):
"""Abstract base class for LoRA adapter eviction policies."""

@abstractmethod
def mark_used(self, uid: Optional[str]) -> None:
"""Marks an adapter as used."""
pass

@abstractmethod
def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
"""Selects an adapter to evict from candidates."""
pass

@abstractmethod
def remove(self, uid: Optional[str]) -> None:
"""Removes an adapter from the policy's tracking."""
pass


class LRUEvictionPolicy(EvictionPolicy):
"""LRU eviction policy - evicts the least recently used adapter."""

def __init__(self):
self.access_order = OrderedDict() # key=uid, value=last_access_time
self.total_accesses = 0
self.eviction_count = 0

def mark_used(self, uid: Optional[str]) -> None:
if uid is not None:
current_time = time.monotonic()
# Remove and re-add to move to end (most recent)
self.access_order.pop(uid, None)
self.access_order[uid] = current_time
self.total_accesses += 1
logger.debug(f"LoRA {uid} marked as used at {current_time}")

def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
"""Select the least recently used adapter from candidates."""
# Iterate through access_order (oldest first) to find LRU victim
for uid in list(self.access_order.keys()):
if uid in candidates:
logger.debug(f"Selected LoRA {uid} for eviction (LRU)")
self.eviction_count += 1
return uid
return None

def remove(self, uid: Optional[str]) -> None:
if uid is not None:
self.access_order.pop(uid, None)
logger.debug(f"Removed LoRA {uid} from LRU tracking")


class FIFOEvictionPolicy(EvictionPolicy):
"""FIFO eviction policy - for backward compatibility."""

def __init__(self):
self.insertion_order = OrderedDict() # key=uid, value=insertion_time
self.eviction_count = 0

def mark_used(self, uid: Optional[str]) -> None:
"""For FIFO, we only track insertion (timestamp)"""
if uid is not None and uid not in self.insertion_order:
self.insertion_order[uid] = time.monotonic()

def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
"""Select the first inserted adapter from candidates."""
# Iterate through insertion_order (oldest first) to find FIFO victim
for uid in list(self.insertion_order.keys()):
if uid in candidates:
logger.debug(f"Selected LoRA {uid} for eviction (FIFO)")
self.eviction_count += 1
return uid
return None

def remove(self, uid: Optional[str]) -> None:
if uid is not None:
self.insertion_order.pop(uid, None)


def get_eviction_policy(policy_name: str) -> EvictionPolicy:
"""Factory function to create eviction policy instances."""
policies = {
"fifo": FIFOEvictionPolicy,
"lru": LRUEvictionPolicy,
}
if policy_name not in policies:
raise ValueError(f"Unknown eviction policy: {policy_name}")
return policies[policy_name]()
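To make the new module's contract concrete, here is a minimal usage sketch of the factory and the LRU policy; the adapter UIDs ("lora_a", "lora_b") are hypothetical placeholders, and the import path assumes the file location added in this diff.

from sglang.srt.lora.eviction_policy import get_eviction_policy

policy = get_eviction_policy("lru")

# Record accesses; a repeated use moves an adapter back to "most recent".
policy.mark_used("lora_a")
policy.mark_used("lora_b")
policy.mark_used("lora_a")  # "lora_a" is now the most recently used

# Only adapters passed in as candidates (e.g. not in the current batch and
# not pinned) are eligible; the least recently used among them is returned.
victim = policy.select_victim({"lora_a", "lora_b"})
assert victim == "lora_b"

# Once the adapter is unloaded, drop it from the policy's bookkeeping.
policy.remove(victim)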
4 changes: 4 additions & 0 deletions python/sglang/srt/lora/lora_manager.py
@@ -68,6 +68,9 @@ def __init__(
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank

# Store eviction policy from server args
self.eviction_policy = server_args.lora_eviction_policy

# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
backend_type = get_backend_from_name(lora_backend)
@@ -411,6 +414,7 @@ def init_memory_pool(self):
max_lora_rank=self.max_lora_rank,
target_modules=self.target_modules,
base_model=self.base_model,
eviction_policy=self.eviction_policy,
)

def set_lora_module(self, module_name, module):
52 changes: 35 additions & 17 deletions python/sglang/srt/lora/mem_pool.py
@@ -5,6 +5,7 @@

from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.eviction_policy import get_eviction_policy
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
@@ -54,6 +55,7 @@ def __init__(
max_lora_rank: int,
target_modules: Set[str],
base_model: torch.nn.Module,
eviction_policy: str,
):
self.base_hf_config: AutoConfig = base_hf_config
self.num_layer: int = base_hf_config.num_hidden_layers
@@ -64,6 +66,9 @@
self.max_lora_rank: int = max_lora_rank
self.target_modules: Set[str] = target_modules

# Initialize eviction policy
self.eviction_policy = get_eviction_policy(eviction_policy)

# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
@@ -189,31 +194,44 @@ def prepare_lora_batch(
lora_refs: Dict[str, LoRARef],
):
def get_available_buffer_slot():
# 1. Prioritize empty slots
for buffer_id in range(self.max_loras_per_batch):
# Prioritize empty slots
if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
return buffer_id

# 2. Memory pool is full, need to evict using policy
# Collect eviction candidates (not in current batch and not pinned)
candidates = set()

for buffer_id in range(self.max_loras_per_batch):
uid = self.buffer_id_to_uid[buffer_id]
if uid not in cur_uids and uid is not None:
lora_ref = lora_refs.get(uid)
# Only add to candidates if not pinned
if lora_ref is None or not lora_ref.pinned:
candidates.add(uid)

# Use eviction policy to select victim
victim_uid = self.eviction_policy.select_victim(candidates)

if victim_uid is None:
raise ValueError(
"No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
)

# Evict unneeded lora
if uid not in cur_uids:
# Skip pinned LoRAs
# TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
if uid is not None:
lora_ref = lora_refs.get(uid)
if lora_ref is not None and lora_ref.pinned:
continue

self.uid_to_buffer_id.pop(uid)
logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
return buffer_id

raise ValueError(
"No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
# Evict the selected victim
victim_buffer_id = self.uid_to_buffer_id[victim_uid]
self.uid_to_buffer_id.pop(victim_uid)
self.eviction_policy.remove(victim_uid)
self.buffer_id_to_uid[victim_buffer_id] = EMPTY_SLOT
logger.debug(
f"Evicting LoRA {victim_uid} from buffer slot {victim_buffer_id}."
)
return victim_buffer_id

# Mark all adapters in current batch as used (for LRU tracking)
for uid in cur_uids:
self.eviction_policy.mark_used(uid)

for uid in cur_uids:
if uid not in self.uid_to_buffer_id:
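Since the new slot-selection flow is split across added and removed lines above, here is a stand-alone sketch of the same two-step logic — prefer an empty slot, otherwise let the eviction policy pick a victim among adapters that are neither in the current batch nor pinned. This is not the actual memory-pool class: EMPTY_SLOT, LoRARefStub, and the adapter names are hypothetical stand-ins.

from dataclasses import dataclass
from typing import Dict, List, Set

from sglang.srt.lora.eviction_policy import EvictionPolicy, get_eviction_policy

EMPTY_SLOT = object()  # hypothetical stand-in for the real empty-slot sentinel


@dataclass
class LoRARefStub:  # hypothetical stand-in for LoRARef
    pinned: bool = False


def get_available_buffer_slot(
    buffer_id_to_uid: List[object],
    uid_to_buffer_id: Dict[str, int],
    cur_uids: Set[str],
    lora_refs: Dict[str, LoRARefStub],
    policy: EvictionPolicy,
) -> int:
    # 1. Prioritize empty slots.
    for buffer_id, uid in enumerate(buffer_id_to_uid):
        if uid is EMPTY_SLOT:
            return buffer_id

    # 2. Pool is full: collect candidates that are neither in the current
    #    batch nor pinned, then let the eviction policy select the victim.
    candidates = {
        uid
        for uid in buffer_id_to_uid
        if uid not in cur_uids and not (uid in lora_refs and lora_refs[uid].pinned)
    }
    victim_uid = policy.select_victim(candidates)
    if victim_uid is None:
        raise ValueError("No available buffer slots found.")

    victim_buffer_id = uid_to_buffer_id.pop(victim_uid)
    policy.remove(victim_uid)
    buffer_id_to_uid[victim_buffer_id] = EMPTY_SLOT
    return victim_buffer_id


# Example: a full 2-slot pool holding "lora_a" and "lora_b".
policy = get_eviction_policy("lru")
slots = ["lora_a", "lora_b"]
uid_to_slot = {"lora_a": 0, "lora_b": 1}
for uid in ("lora_a", "lora_b", "lora_a"):  # "lora_a" ends up most recent
    policy.mark_used(uid)

# A new batch needs "lora_c": the LRU adapter "lora_b" gives up its slot.
free_slot = get_available_buffer_slot(slots, uid_to_slot, {"lora_c"}, {}, policy)
assert free_slot == 1 and slots[1] is EMPTY_SLOT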
8 changes: 8 additions & 0 deletions python/sglang/srt/server_args.py
@@ -265,6 +265,7 @@ class ServerArgs:
] = None
max_loaded_loras: Optional[int] = None
max_loras_per_batch: int = 8
lora_eviction_policy: str = "lru"  # LoRA memory pool eviction policy: "lru" (default) or "fifo"
lora_backend: str = "triton"
max_lora_chunk_size: Optional[int] = 16

@@ -1795,6 +1796,13 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=ServerArgs.max_loaded_loras,
help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
)
parser.add_argument(
"--lora-eviction-policy",
type=str,
default="lru",
choices=["lru", "fifo"],
help="LoRA adapter eviction policy when memory pool is full. 'lru': Least Recently Used (default, better cache efficiency). 'fifo': First-In-First-Out.",
)
parser.add_argument(
"--lora-backend",
type=str,
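As a quick, self-contained sanity check of the new flag's contract (a plain argparse sketch, not SGLang's actual parser wiring), the option accepts only "lru" or "fifo" and defaults to "lru":

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lora-eviction-policy",
    type=str,
    default="lru",
    choices=["lru", "fifo"],
)

assert parser.parse_args([]).lora_eviction_policy == "lru"  # default
assert parser.parse_args(["--lora-eviction-policy", "fifo"]).lora_eviction_policy == "fifo"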
2 changes: 2 additions & 0 deletions python/sglang/test/runners.py
@@ -519,6 +519,7 @@ def __init__(
lora_target_modules: Optional[List[str]] = None,
enable_lora: Optional[bool] = None,
max_loaded_loras: Optional[int] = None,
lora_eviction_policy: str = "lru",
):
self.model_type = model_type
self.is_generation = model_type == "generation"
@@ -565,6 +566,7 @@ def __init__(
lora_target_modules=lora_target_modules,
enable_lora=enable_lora,
max_loaded_loras=max_loaded_loras,
lora_eviction_policy=lora_eviction_policy,
**spec_kwargs,
)
