Commit 72e0b05

hmellor authored and Yuqi Zhang committed
Update deprecated type hinting in vllm/lora (vllm-project#18128)
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
1 parent bed3f0c commit 72e0b05

19 files changed: +245 −251 lines
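The edits below all follow one pattern: the deprecated typing aliases (List, Tuple, Dict, and typing.Sequence) are replaced with the builtin generics and collections.abc equivalents that PEP 585 made usable in annotations from Python 3.9 onward. A minimal sketch of the before/after shape, using illustrative names that are not taken from the diff:

    from typing import Optional

    # Before: deprecated aliases imported from typing.
    #   from typing import Dict, List, Tuple
    #   def slice_parts(parts: List[Optional[int]]) -> Tuple[int, ...]: ...
    #   offsets: Dict[float, int] = {}

    # After: builtin generics are subscripted directly; Optional/Union still
    # come from typing, so those imports stay.
    def slice_parts(parts: list[Optional[int]]) -> tuple[int, ...]:
        return tuple(p for p in parts if p is not None)

    offsets: dict[float, int] = {}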

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -78,7 +78,6 @@ exclude = [
 "vllm/distributed/**/*.py" = ["UP006", "UP035"]
 "vllm/engine/**/*.py" = ["UP006", "UP035"]
 "vllm/executor/**/*.py" = ["UP006", "UP035"]
-"vllm/lora/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"]
 "vllm/model_executor/models/**/*.py" = ["UP006", "UP035"]
 "vllm/platforms/**/*.py" = ["UP006", "UP035"]

vllm/lora/fully_sharded_layers.py

Lines changed: 11 additions & 11 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # pylint: disable=unused-argument
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import torch
 import torch.nn as nn
@@ -118,7 +118,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         # specifying kwargs so they can be easily accessed in decorator
@@ -141,8 +141,8 @@ class MergedColumnParallelLinearWithShardedLoRA(
     """
 
     def slice_lora_a(
-        self, lora_a: List[Union[torch.Tensor, None]]
-    ) -> List[Union[torch.Tensor, None]]:
+        self, lora_a: list[Union[torch.Tensor, None]]
+    ) -> list[Union[torch.Tensor, None]]:
         #NOTE: lora_a contains 2 subloras, and each sublora could be None.
         output_shard_size = self.lora_a_stacked[0].shape[2]
         output_start_idx = self.tp_rank * output_shard_size
@@ -165,7 +165,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         # specifying kwargs so they can be easily accessed in decorator
@@ -201,7 +201,7 @@ def apply(self,
     @classmethod
     @_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
-                          lora_config: LoRAConfig, packed_modules_list: List,
+                          lora_config: LoRAConfig, packed_modules_list: list,
                           model_config: Optional[PretrainedConfig]) -> bool:
         # specifying kwargs so they can be easily accessed in decorator
         return super().can_replace_layer(
@@ -222,8 +222,8 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
     """
 
     def slice_lora_a(
-        self, lora_a: List[Union[torch.Tensor, None]]
-    ) -> List[Union[torch.Tensor, None]]:
+        self, lora_a: list[Union[torch.Tensor, None]]
+    ) -> list[Union[torch.Tensor, None]]:
         # NOTE: lora_a contains 3 subloras, and each sublora could be None.
         shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
         start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
@@ -248,7 +248,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         # specifying kwargs so they can be easily accessed in decorator
@@ -281,7 +281,7 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
     def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
         if bias is None:
             return bias
-        self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+        self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                       self.lora_bias_stacked)
         shard_size = self.lora_bias_stacked[0].shape[2]
         start_idx = self.tp_rank * shard_size
@@ -341,7 +341,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         # specifying kwargs so they can be easily accessed in decorator

vllm/lora/layers.py

Lines changed: 31 additions & 31 deletions
@@ -3,7 +3,7 @@
 # pylint: disable=unused-argument
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import torch
 import torch.nn as nn
@@ -82,14 +82,14 @@ class LoRAMapping(AdapterMapping):
 class BaseLayerWithLoRA(nn.Module):
 
     def slice_lora_a(
-        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
-    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
+        self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
         """Slice lora a if splitting for tensor parallelism."""
         ...
 
     def slice_lora_b(
-        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
-    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
+        self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
         """Slice lora b if splitting with tensor parallelism."""
         ...
 
@@ -128,7 +128,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         """Returns True if the layer can be replaced by this LoRA layer."""
@@ -140,7 +140,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: VocabParallelEmbedding) -> None:
         super().__init__()
         self.base_layer = base_layer
-        self.embeddings_slice: Optional[Tuple[int, int]]
+        self.embeddings_slice: Optional[tuple[int, int]]
         self.embeddings_weights: Optional[torch.Tensor]
 
     def create_lora_weights(
@@ -279,7 +279,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return type(source_layer) is VocabParallelEmbedding
@@ -296,9 +296,9 @@ def __init__(self, base_layer: LinearBase):
         self.base_layer = base_layer
         self.input_size = self.base_layer.input_size
         self.device = _get_lora_device(self.base_layer)
-        self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None
+        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
 
-        self.output_slices: Tuple[int, ...]
+        self.output_slices: tuple[int, ...]
         self.tp_size: int
         self.output_size: int
         self.n_slices: int
@@ -365,7 +365,7 @@ def reset_lora(self, index: int):
             self.lora_b_stacked[s_index][index] = 0
         if self.lora_config.bias_enabled:
             # Make mypy happy
-            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                           self.lora_bias_stacked)
             self.lora_bias_stacked[s_index][index] = 0
 
@@ -399,7 +399,7 @@ def set_lora(
                 lora_b.T, non_blocking=True)
         if lora_bias is not None:
 
-            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                           self.lora_bias_stacked)
             assert len(self.lora_bias_stacked)
             self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
@@ -497,7 +497,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return type(source_layer) is ReplicatedLinear
@@ -597,7 +597,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return type(source_layer) is ColumnParallelLinear or (
@@ -674,13 +674,13 @@ def create_lora_weights(
             ) for output_size in self.output_slices)
 
     def slice_lora_a(
-        self, lora_a: List[Union[torch.Tensor, None]]
-    ) -> List[Union[torch.Tensor, None]]:
+        self, lora_a: list[Union[torch.Tensor, None]]
+    ) -> list[Union[torch.Tensor, None]]:
         return lora_a
 
     def slice_lora_b(
-        self, lora_b: List[Union[torch.Tensor, None]]
-    ) -> List[Union[torch.Tensor, None]]:
+        self, lora_b: list[Union[torch.Tensor, None]]
+    ) -> list[Union[torch.Tensor, None]]:
         for i, (shard_id, shard_size) in enumerate(
                 zip(self.output_ids, self.output_slices)):
             if (lora_b_i := lora_b[i]) is not None:
@@ -689,8 +689,8 @@ def slice_lora_b(
         return lora_b
 
     def slice_bias(
-        self, bias: List[Union[torch.Tensor,
-                               None]]) -> List[Union[torch.Tensor, None]]:
+        self, bias: list[Union[torch.Tensor,
+                               None]]) -> list[Union[torch.Tensor, None]]:
         for i, (shard_id, shard_size) in enumerate(
                 zip(self.output_ids, self.output_slices)):
             if (bias_i := bias[i]) is not None:
@@ -725,7 +725,7 @@ def set_lora(
                     lora_b_i.T, non_blocking=True)
 
         if lora_bias is not None:
-            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                           self.lora_bias_stacked)
             for i in range(self.n_slices):
                 if (lora_bias_i := lora_bias[i]) is not None:
@@ -740,7 +740,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return (type(source_layer) is MergedColumnParallelLinear
@@ -809,7 +809,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
     @classmethod
     @_not_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
-                          lora_config: LoRAConfig, packed_modules_list: List,
+                          lora_config: LoRAConfig, packed_modules_list: list,
                           model_config: Optional[PretrainedConfig]) -> bool:
         return type(source_layer) is QKVParallelLinear and len(
             packed_modules_list) == 1
@@ -869,7 +869,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return (type(source_layer) is QKVParallelLinear
@@ -923,7 +923,7 @@ def forward(
             - output
            - bias
         """
-        # Set up backprop all-reduce.
+        # set up backprop all-reduce.
         if self.base_layer.input_is_parallel:
             input_parallel = input_
         else:
@@ -958,7 +958,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return type(source_layer) is RowParallelLinear
@@ -981,7 +981,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
 
     def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                  dtype: torch.dtype, device: torch.device,
-                 sharded_to_full_mapping: Optional[List[int]]) -> None:
+                 sharded_to_full_mapping: Optional[list[int]]) -> None:
         super().__init__()
         self.base_layer = base_layer
         self.hidden_size = hidden_size
@@ -1189,7 +1189,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         # Special handling for the LogitsProcessor.
@@ -1256,7 +1256,7 @@ def forward(
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.base_layer(
             positions,
             query,
@@ -1265,15 +1265,15 @@ def forward(
         )
 
     @property
-    def scaling_factor_to_offset(self) -> Dict[float, int]:
+    def scaling_factor_to_offset(self) -> dict[float, int]:
         return self.base_layer.scaling_factor_to_offset
 
     @classmethod
     def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         """Returns True if the layer can be replaced by this LoRA layer."""

vllm/lora/lora.py

Lines changed: 7 additions & 7 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Optional
-from typing import Sequence as GenericSequence
+from collections.abc import Sequence as GenericSequence
+from typing import Optional
 
 import torch
 import torch.types
@@ -125,11 +125,11 @@ def __init__(
         self,
         module_name: str,
         rank: int,
-        lora_alphas: List[Optional[int]],
-        lora_a: List[Optional[torch.Tensor]],
-        lora_b: List[Optional[torch.Tensor]],
-        bias: Optional[List[Optional[torch.Tensor]]] = None,
-        scaling: Optional[List[float]] = None,
+        lora_alphas: list[Optional[int]],
+        lora_a: list[Optional[torch.Tensor]],
+        lora_b: list[Optional[torch.Tensor]],
+        bias: Optional[list[Optional[torch.Tensor]]] = None,
+        scaling: Optional[list[float]] = None,
     ) -> None:
         super().__init__(
             module_name=module_name,
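The import shuffle at the top of this file is the same rule at work: under PEP 585, typing.Sequence is a deprecated alias of collections.abc.Sequence, so the GenericSequence alias is kept but now sourced from collections.abc. A small sketch of the resulting usage (the function name is illustrative, not from the file):

    from collections.abc import Sequence as GenericSequence

    # Behaves the same as the old `from typing import Sequence as GenericSequence`,
    # but without the deprecated alias flagged by UP035.
    def mean_scaling(scalings: GenericSequence[float]) -> float:
        return sum(scalings) / len(scalings)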
