
Commit 1462881

calvin0327 authored and x22x22 committed
[Model] use AutoWeightsLoader for bart (vllm-project#18299)
Signed-off-by: calvin chen <[email protected]>
Signed-off-by: x22x22 <[email protected]>
1 parent 4b15142 commit 1462881

File tree

1 file changed: +71 -101 lines changed

vllm/model_executor/models/bart.py

Lines changed: 71 additions & 101 deletions
@@ -46,7 +46,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsQuant, SupportsV0Only
-from .utils import maybe_prefix
+from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
 
 logger = logging.get_logger(__name__)
 
@@ -700,7 +700,8 @@ def forward(
 
 class BartModel(nn.Module, SupportsQuant):
     _tied_weights_keys = [
-        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
     ]
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -763,10 +764,54 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
 
         return decoder_outputs
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        other_weights = []
+        loaded_stacked_params = []
+        model_params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if name not in model_params_dict:
+                    continue
+                param = model_params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                loaded_stacked_params.append(name)
+                break
+            else:
+                if name in model_params_dict:
+                    other_weights.append((name, loaded_weight))
+
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(other_weights)
+        loaded_params.update(loaded_stacked_params)
+        return loaded_params
+
 
 class BartForConditionalGeneration(nn.Module, SupportsV0Only, SupportsQuant):
-    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
-    base_model_prefix = "model"
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={
+            "decoder.": "model.decoder.",
+            "encoder.": "model.encoder.",
+            "shared.": "model.shared."
+        },
+        orig_to_new_substr={
+            "beta": "bias",
+            "gamma": "weight",
+            "LayerNorm": "layernorm",
+        },
+    )
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
 
@@ -789,7 +834,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.lm_head = BartParallelLMHead(config.vocab_size,
                                           config.d_model,
                                           embed_scale=embed_scale)
-
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
 
@@ -828,111 +872,37 @@ def compute_logits(
                                        sampling_metadata)
         return logits
 
-    stacked_params_mapping = {
-        "q_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "q",
-        },
-        "k_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "k",
-        },
-        "v_proj": {
-            "param_name": "qkv_proj",
-            "shard_id": "v",
-        },
-    }
-
-    params_mapping = {
-        "beta": "bias",
-        "gamma": "weight",
-        "LayerNorm": "layernorm",
-    }
-
-    def _rename_key(self, key: str):
-        prefix = f"{self.base_model_prefix}."
-        key = key[len(prefix):] if key.startswith(prefix) else key
-
-        for src, dst in self.params_mapping.items():
-            key = key.replace(src, dst)
-
-        return key
-
-    def _rename_stacked_param(
-        self,
-        name: str,
-    ) -> tuple[str, Optional[str]]:
-        for key, mapping in self.stacked_params_mapping.items():
-            if key in name:
-                name = name.replace(key, mapping["param_name"])
-                return name, mapping["shard_id"]
-        return name, None
-
-    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
-
-        model_params_dict = dict(self.model.named_parameters())
-        top_params_dict = dict(self.named_parameters())
-
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         weights_tuple_list = list(weights)
 
         shared_embedding_weight = None
-        shared_embedding_shard_id = None
-
         for name, loaded_weight in weights_tuple_list:
-
-            name = self._rename_key(name)
-            name, shard_id = self._rename_stacked_param(name)
-
             if ('shared.weight' in name
                     or 'encoder.embed_tokens.weight' in name
                    or 'decoder.embed_tokens.weight' in name
                     or 'lm_head.weight' in name):
                 assert shared_embedding_weight is None, (
                     "Conflicting embedding weights.")
                 shared_embedding_weight = loaded_weight
-                shared_embedding_shard_id = shard_id
-            else:
-                # Skip the specific downstream task weight.
-                if name.startswith('cls.'):
-                    continue
-                # use Pooler instead.
-                if name.startswith('pooler.'):
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in model_params_dict:
-                    continue
 
-                param = model_params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                if shard_id:
-                    weight_loader(param, loaded_weight, shard_id)
-                else:
-                    weight_loader(param, loaded_weight)
-
-        # Assign shared weight values
-        encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
-        encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
-        decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        lm_head_in_param = top_params_dict['lm_head.weight']
-        lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
-                                           default_weight_loader)
-
-        assert shared_embedding_weight is not None
-
-        if shared_embedding_shard_id:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
-                                     shared_embedding_shard_id)
-        else:
-            encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
-            decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
-            lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["cls.", "pooler."]),
+        )
+        loaded_params = loader.load_weights(weights_tuple_list,
+                                            mapper=self.hf_to_vllm_mapper)
+
+        if shared_embedding_weight is not None:
+            weight_loader = getattr(self.lm_head.weight, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(self.lm_head.weight, shared_embedding_weight)
+
+            self.model.encoder.embed_tokens.weight = self.lm_head.weight
+            self.model.decoder.embed_tokens.weight = self.lm_head.weight
+            loaded_params.update({
+                'model.encoder.embed_tokens.weight', 'lm_head.weight',
+                'model.decoder.embed_tokens.weight'
+            })
+
+        return loaded_params
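
For readers skimming the diff, the short standalone sketch below mimics the kind of checkpoint-name rewriting that the new hf_to_vllm_mapper configures: prefix rewrites (e.g. "encoder." becomes "model.encoder.") followed by substring rewrites (e.g. "gamma" becomes "weight"). It deliberately does not import vLLM or its WeightsMapper class; the dictionaries mirror the ones added in the diff, while the remap_name helper and the example checkpoint key are hypothetical, and the order in which the real mapper applies prefixes versus substrings may differ.

# Illustrative sketch only: mirrors the mapping dictionaries added above,
# but does not use vLLM's WeightsMapper. remap_name is a hypothetical helper.
orig_to_new_prefix = {
    "decoder.": "model.decoder.",
    "encoder.": "model.encoder.",
    "shared.": "model.shared.",
}
orig_to_new_substr = {
    "beta": "bias",
    "gamma": "weight",
    "LayerNorm": "layernorm",
}


def remap_name(name: str) -> str:
    # Rewrite a matching prefix, then apply substring replacements.
    # (The real WeightsMapper may apply these rules in a different order.)
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            name = new + name[len(old):]
            break
    for old, new in orig_to_new_substr.items():
        name = name.replace(old, new)
    return name


# Hypothetical TF-style checkpoint key, used only to show the rewrite:
print(remap_name("encoder.layers.0.self_attn_LayerNorm.gamma"))
# -> model.encoder.layers.0.self_attn_layernorm.weight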
