 # pylint: disable=unused-argument
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Optional, Union, cast
 
 import torch
 import torch.nn as nn
@@ -82,14 +82,14 @@ class LoRAMapping(AdapterMapping):
 class BaseLayerWithLoRA(nn.Module):
 
     def slice_lora_a(
-        self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
-    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
+        self, lora_a: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
         """Slice lora a if splitting for tensor parallelism."""
         ...
 
     def slice_lora_b(
-        self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]]
-    ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]:
+        self, lora_b: Union[torch.Tensor, list[Union[torch.Tensor, None]]]
+    ) -> Union[torch.Tensor, list[Union[torch.Tensor, None]]]:
         """Slice lora b if splitting with tensor parallelism."""
         ...
 
@@ -128,7 +128,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         """Returns True if the layer can be replaced by this LoRA layer."""
@@ -140,7 +140,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
     def __init__(self, base_layer: VocabParallelEmbedding) -> None:
         super().__init__()
         self.base_layer = base_layer
-        self.embeddings_slice: Optional[Tuple[int, int]]
+        self.embeddings_slice: Optional[tuple[int, int]]
         self.embeddings_weights: Optional[torch.Tensor]
 
     def create_lora_weights(
@@ -279,7 +279,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return type(source_layer) is VocabParallelEmbedding
@@ -296,9 +296,9 @@ def __init__(self, base_layer: LinearBase):
         self.base_layer = base_layer
         self.input_size = self.base_layer.input_size
         self.device = _get_lora_device(self.base_layer)
-        self.lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]] = None
+        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
 
-        self.output_slices: Tuple[int, ...]
+        self.output_slices: tuple[int, ...]
         self.tp_size: int
         self.output_size: int
         self.n_slices: int
@@ -365,7 +365,7 @@ def reset_lora(self, index: int):
             self.lora_b_stacked[s_index][index] = 0
             if self.lora_config.bias_enabled:
                 # Make mypy happy
-                self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+                self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                               self.lora_bias_stacked)
                 self.lora_bias_stacked[s_index][index] = 0
 
@@ -399,7 +399,7 @@ def set_lora(
                                    lora_b.T, non_blocking=True)
         if lora_bias is not None:
 
-            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                           self.lora_bias_stacked)
             assert len(self.lora_bias_stacked)
             self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
@@ -497,7 +497,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return type(source_layer) is ReplicatedLinear
@@ -597,7 +597,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return type(source_layer) is ColumnParallelLinear or (
@@ -674,13 +674,13 @@ def create_lora_weights(
             ) for output_size in self.output_slices)
 
     def slice_lora_a(
-        self, lora_a: List[Union[torch.Tensor, None]]
-    ) -> List[Union[torch.Tensor, None]]:
+        self, lora_a: list[Union[torch.Tensor, None]]
+    ) -> list[Union[torch.Tensor, None]]:
         return lora_a
 
     def slice_lora_b(
-        self, lora_b: List[Union[torch.Tensor, None]]
-    ) -> List[Union[torch.Tensor, None]]:
+        self, lora_b: list[Union[torch.Tensor, None]]
+    ) -> list[Union[torch.Tensor, None]]:
         for i, (shard_id, shard_size) in enumerate(
                 zip(self.output_ids, self.output_slices)):
             if (lora_b_i := lora_b[i]) is not None:
@@ -689,8 +689,8 @@ def slice_lora_b(
         return lora_b
 
     def slice_bias(
-        self, bias: List[Union[torch.Tensor,
-                               None]]) -> List[Union[torch.Tensor, None]]:
+        self, bias: list[Union[torch.Tensor,
+                               None]]) -> list[Union[torch.Tensor, None]]:
         for i, (shard_id, shard_size) in enumerate(
                 zip(self.output_ids, self.output_slices)):
             if (bias_i := bias[i]) is not None:
@@ -725,7 +725,7 @@ def set_lora(
                         lora_b_i.T, non_blocking=True)
 
         if lora_bias is not None:
-            self.lora_bias_stacked = cast(Tuple[torch.Tensor, ...],
+            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                           self.lora_bias_stacked)
             for i in range(self.n_slices):
                 if (lora_bias_i := lora_bias[i]) is not None:
@@ -740,7 +740,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return (type(source_layer) is MergedColumnParallelLinear
@@ -809,7 +809,7 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
     @classmethod
     @_not_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
-                          lora_config: LoRAConfig, packed_modules_list: List,
+                          lora_config: LoRAConfig, packed_modules_list: list,
                           model_config: Optional[PretrainedConfig]) -> bool:
         return type(source_layer) is QKVParallelLinear and len(
             packed_modules_list) == 1
@@ -869,7 +869,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return (type(source_layer) is QKVParallelLinear
@@ -923,7 +923,7 @@ def forward(
             - output
             - bias
         """
-        # Set up backprop all-reduce.
+        # set up backprop all-reduce.
         if self.base_layer.input_is_parallel:
             input_parallel = input_
         else:
@@ -958,7 +958,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         return type(source_layer) is RowParallelLinear
@@ -981,7 +981,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
 
     def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                  dtype: torch.dtype, device: torch.device,
-                 sharded_to_full_mapping: Optional[List[int]]) -> None:
+                 sharded_to_full_mapping: Optional[list[int]]) -> None:
         super().__init__()
         self.base_layer = base_layer
         self.hidden_size = hidden_size
@@ -1189,7 +1189,7 @@ def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         # Special handling for the LogitsProcessor.
@@ -1256,7 +1256,7 @@ def forward(
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         return self.base_layer(
             positions,
             query,
@@ -1265,15 +1265,15 @@ def forward(
         )
 
     @property
-    def scaling_factor_to_offset(self) -> Dict[float, int]:
+    def scaling_factor_to_offset(self) -> dict[float, int]:
         return self.base_layer.scaling_factor_to_offset
 
     @classmethod
     def can_replace_layer(
         cls,
         source_layer: nn.Module,
         lora_config: LoRAConfig,
-        packed_modules_list: List,
+        packed_modules_list: list,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
         """Returns True if the layer can be replaced by this LoRA layer."""