Skip to content

Update deprecated type hinting in vllm/adapter_commons #18073

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ exclude = [
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
# Python 3.8 typing. TODO: Remove these excludes after v1.0.0
"vllm/adapter_commons/**/*.py" = ["UP006", "UP035"]
"vllm/attention/**/*.py" = ["UP006", "UP035"]
"vllm/compilation/**/*.py" = ["UP006", "UP035"]
"vllm/core/**/*.py" = ["UP006", "UP035"]
Expand Down
5 changes: 2 additions & 3 deletions vllm/adapter_commons/layers.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdapterMapping:
# Per every token in input_ids:
index_mapping: Tuple[int, ...]
index_mapping: tuple[int, ...]
# Per sampled token:
prompt_mapping: Tuple[int, ...]
prompt_mapping: tuple[int, ...]

def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
Expand Down
8 changes: 4 additions & 4 deletions vllm/adapter_commons/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Optional, TypeVar
from typing import Any, Callable, Optional, TypeVar

from torch import nn

Expand Down Expand Up @@ -49,9 +49,9 @@ def __init__(
model: the model to be adapted.
"""
self.model: nn.Module = model
self._registered_adapters: Dict[int, Any] = {}
self._registered_adapters: dict[int, Any] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_adapters: Dict[int, None] = {}
self._active_adapters: dict[int, None] = {}
self.adapter_type = 'Adapter'
self._last_mapping = None

Expand Down Expand Up @@ -97,7 +97,7 @@ def get_adapter(self, adapter_id: int) -> Optional[Any]:
raise NotImplementedError

@abstractmethod
def list_adapters(self) -> Dict[int, Any]:
def list_adapters(self) -> dict[int, Any]:
raise NotImplementedError

@abstractmethod
Expand Down
18 changes: 9 additions & 9 deletions vllm/adapter_commons/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Callable, Dict, Optional, Set
from typing import Any, Callable, Optional


## model functions
def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None],
deactivate_func: Callable) -> bool:
if adapter_id in active_adapters:
deactivate_func(adapter_id)
Expand All @@ -13,7 +13,7 @@ def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None],
return False


def add_adapter(adapter: Any, registered_adapters: Dict[int, Any],
def add_adapter(adapter: Any, registered_adapters: dict[int, Any],
capacity: int, add_func: Callable) -> bool:
if adapter.id not in registered_adapters:
if len(registered_adapters) >= capacity:
Expand All @@ -32,23 +32,23 @@ def set_adapter_mapping(mapping: Any, last_mapping: Any,
return last_mapping


def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any],
def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any],
deactivate_func: Callable) -> bool:
deactivate_func(adapter_id)
return bool(registered_adapters.pop(adapter_id, None))


def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]:
return dict(registered_adapters)


def get_adapter(adapter_id: int,
registered_adapters: Dict[int, Any]) -> Optional[Any]:
registered_adapters: dict[int, Any]) -> Optional[Any]:
return registered_adapters.get(adapter_id)


## worker functions
def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any],
def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any],
apply_adapters_func,
set_adapter_mapping_func) -> None:
apply_adapters_func(requests)
Expand All @@ -66,7 +66,7 @@ def add_adapter_worker(adapter_request: Any, list_adapters_func,
return loaded


def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func,
adapter_slots: int, remove_adapter_func,
add_adapter_func) -> None:
models_that_exist = list_adapters_func()
Expand All @@ -88,5 +88,5 @@ def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func,
add_adapter_func(models_map[adapter_id])


def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]:
def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]:
return set(adapter_manager_list_adapters_func())
6 changes: 3 additions & 3 deletions vllm/adapter_commons/worker_manager.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Any, Optional, Set
from typing import Any, Optional

import torch

Expand All @@ -17,7 +17,7 @@ def is_enabled(self) -> bool:
raise NotImplementedError

@abstractmethod
def set_active_adapters(self, requests: Set[Any],
def set_active_adapters(self, requests: set[Any],
mapping: Optional[Any]) -> None:
raise NotImplementedError

Expand All @@ -34,5 +34,5 @@ def remove_all_adapters(self) -> None:
raise NotImplementedError

@abstractmethod
def list_adapters(self) -> Set[int]:
def list_adapters(self) -> set[int]:
raise NotImplementedError
Loading