 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsQuant, SupportsV0Only
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 logger = logging.get_logger(__name__)
 
@@ -700,7 +700,8 @@ def forward(
 
 class BartModel(nn.Module, SupportsQuant):
     _tied_weights_keys = [
-        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
     ]
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -763,10 +764,54 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
 
         return decoder_outputs
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
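+        # q_proj/k_proj/v_proj checkpoint shards are folded into the packed
+        # qkv_proj parameter below; every other weight is collected and
+        # handed off to AutoWeightsLoader.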
+        other_weights = []
+        loaded_stacked_params = []
+        model_params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if name not in model_params_dict:
+                    continue
+                param = model_params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
+                break
+            else:
+                if name in model_params_dict:
+                    other_weights.append((name, loaded_weight))
+
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
+        return loaded_params
+
 
 class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
-    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
-    base_model_prefix = "model"
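+    # Rename Hugging Face checkpoint names to vLLM module names: top-level
+    # encoder/decoder/shared prefixes move under "model.", and legacy
+    # beta/gamma/LayerNorm substrings map to bias/weight/layernorm.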
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "decoder.": "model.decoder.",
+            "encoder.": "model.encoder.",
+            "shared.": "model.shared."
+        },
+        orig_to_new_substr={
+            "beta": "bias",
+            "gamma": "weight",
+            "LayerNorm": "layernorm",
+        },
+    )
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
@@ -789,7 +834,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.lm_head = BartParallelLMHead(config.vocab_size,
                                           config.d_model,
                                           embed_scale=embed_scale)
-
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
 
@@ -828,111 +872,37 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    stacked_params_mapping = {
-        "q_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "q",
-        },
-        "k_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "k",
-        },
-        "v_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "v",
-        },
-    }
-
-    params_mapping = {
-        "beta": "bias",
-        "gamma": "weight",
-        "LayerNorm": "layernorm",
-    }
-
-    def _rename_key(self, key: str):
-        prefix = f"{self.base_model_prefix}."
-        key = key[len(prefix):] if key.startswith(prefix) else key
-
-        for src, dst in self.params_mapping.items():
-            key = key.replace(src, dst)
-
-        return key
-
-    def _rename_stacked_param(
-        self,
-        name: str,
-    ) -> tuple[str, Optional[str]]:
-        for key, mapping in self.stacked_params_mapping.items():
-            if key in name:
-                name = name.replace(key, mapping["param_name"])
-                return name, mapping["shard_id"]
-        return name, None
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
-        model_params_dict = dict(self.model.named_parameters())
-        top_params_dict = dict(self.named_parameters())
-
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         weights_tuple_list = list(weights)
 
         shared_embedding_weight = None
-        shared_embedding_shard_id = None
-
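+        # First pass: pick out the shared embedding weight (shared.weight,
+        # encoder/decoder embed_tokens, or lm_head), which is tied across
+        # all three modules below.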
         for name, loaded_weight in weights_tuple_list:
-
-            name = self._rename_key(name)
-            name, shard_id = self._rename_stacked_param(name)
-
             if ('shared.weight' in name
                     or 'encoder.embed_tokens.weight' in name
                     or 'decoder.embed_tokens.weight' in name
                     or 'lm_head.weight' in name):
                 assert shared_embedding_weight is None, (
                     "Conflicting embedding weights.")
                 shared_embedding_weight = loaded_weight
-                shared_embedding_shard_id = shard_id
-            else:
-                # Skip the specific downstream task weight.
-                if name.startswith('cls.'):
-                    continue
-                # use Pooler instead.
-                if name.startswith('pooler.'):
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in model_params_dict:
-                    continue
 
-                param = model_params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                if shard_id:
-                    weight_loader(param, loaded_weight, shard_id)
-                else:
-                    weight_loader(param, loaded_weight)
-
-        # Assign shared weight values
-        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
-        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
-        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        lm_head_in_param = top_params_dict['lm_head.weight']
-        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        assert shared_embedding_weight is not None
-
-        if shared_embedding_shard_id:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-        else:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["cls.", "pooler."]),
+        )
+        loaded_params = loader.load_weights(weights_tuple_list,
+                                            mapper=self.hf_to_vllm_mapper)
+
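+        # Load the shared embedding once into lm_head, then tie the encoder
+        # and decoder embed_tokens to the same tensor.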
+        if shared_embedding_weight is not None:
+            weight_loader = getattr(self.lm_head.weight, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(self.lm_head.weight, shared_embedding_weight)
+
+            self.model.encoder.embed_tokens.weight = self.lm_head.weight
+            self.model.decoder.embed_tokens.weight = self.lm_head.weight
+            loaded_params.update({
+                'model.encoder.embed_tokens.weight', 'lm_head.weight',
+                'model.decoder.embed_tokens.weight'
+            })
+
+        return loaded_params
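
Note: the hf_to_vllm_mapper renames above behave roughly as in the sketch below. This is a minimal, self-contained illustration written for this note; the helper remap_hf_name and the sample checkpoint keys are hypothetical, and the order in which prefix and substring renames are applied is assumed here rather than taken from vLLM's WeightsMapper implementation.

# Illustrative sketch only: mirrors the rename tables configured in
# hf_to_vllm_mapper; it is not the real vllm WeightsMapper API.
PREFIX_MAP = {
    "decoder.": "model.decoder.",
    "encoder.": "model.encoder.",
    "shared.": "model.shared.",
}
SUBSTR_MAP = {"beta": "bias", "gamma": "weight", "LayerNorm": "layernorm"}


def remap_hf_name(name: str) -> str:
    # Hypothetical helper: apply at most one prefix rename, then the
    # substring renames (this ordering is an assumption).
    for old, new in PREFIX_MAP.items():
        if name.startswith(old):
            name = new + name[len(old):]
            break
    for old, new in SUBSTR_MAP.items():
        name = name.replace(old, new)
    return name


# Sample (hypothetical) checkpoint keys and the names they would load under:
assert remap_hf_name("shared.weight") == "model.shared.weight"
assert remap_hf_name("encoder.LayerNorm.gamma") == "model.encoder.layernorm.weight"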