Skip to content

Commit b9a405f

Browse files
dsikka authored and Alvant committed
[Bugfix] Fix PerTensorScaleParameter weight loading for fused models (vllm-project#7376)
Signed-off-by: Alvant <[email protected]>
1 parent 4889398 commit b9a405f

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

vllm/model_executor/layers/linear.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,8 @@
1414
from vllm.model_executor.layers.quantization.base_config import (
1515
QuantizationConfig, QuantizeMethodBase)
1616
from vllm.model_executor.parameter import (BasevLLMParameter,
17-
PackedvLLMParameter)
17+
PackedvLLMParameter,
18+
PerTensorScaleParameter)
1819
from vllm.model_executor.utils import set_weight_attrs
1920

2021
logger = init_logger(__name__)
@@ -573,11 +574,13 @@ def weight_loader_v2(self,
573574
param: BasevLLMParameter,
574575
loaded_weight: torch.Tensor,
575576
loaded_shard_id: Optional[int] = None):
576-
param_data = param.data
577577
if loaded_shard_id is None:
578-
if param.output_dim is None:
579-
assert param_data.shape == loaded_weight.shape
580-
param_data.copy_(loaded_weight)
578+
if isinstance(param, PerTensorScaleParameter):
579+
param.load_merged_column_weight(loaded_weight=loaded_weight,
580+
shard_id=0)
581+
return
582+
elif type(param) is BasevLLMParameter:
583+
param.load_merged_column_weight(loaded_weight=loaded_weight)
581584
return
582585
self._load_fused_module_from_checkpoint(param, loaded_weight)
583586
return
@@ -720,11 +723,13 @@ def weight_loader_v2(self,
720723
param: BasevLLMParameter,
721724
loaded_weight: torch.Tensor,
722725
loaded_shard_id: Optional[str] = None):
723-
param_data = param.data
724726
if loaded_shard_id is None: # special case for certain models
725-
if param.output_dim is None:
726-
assert param_data.shape == loaded_weight.shape
727-
param_data.copy_(loaded_weight)
727+
if isinstance(param, PerTensorScaleParameter):
728+
param.load_merged_column_weight(loaded_weight=loaded_weight,
729+
shard_id=0)
730+
return
731+
elif type(param) is BasevLLMParameter:
732+
param.load_merged_column_weight(loaded_weight=loaded_weight)
728733
return
729734
self._load_fused_module_from_checkpoint(param, loaded_weight)
730735
return

0 commit comments

Comments
 (0)