|
14 | 14 | from vllm.model_executor.layers.quantization.base_config import (
|
15 | 15 | QuantizationConfig, QuantizeMethodBase)
|
16 | 16 | from vllm.model_executor.parameter import (BasevLLMParameter,
|
17 |
| - PackedvLLMParameter) |
| 17 | + PackedvLLMParameter, |
| 18 | + PerTensorScaleParameter) |
18 | 19 | from vllm.model_executor.utils import set_weight_attrs
|
19 | 20 |
|
20 | 21 | logger = init_logger(__name__)
|
@@ -573,11 +574,13 @@ def weight_loader_v2(self,
|
573 | 574 | param: BasevLLMParameter,
|
574 | 575 | loaded_weight: torch.Tensor,
|
575 | 576 | loaded_shard_id: Optional[int] = None):
|
576 |
| - param_data = param.data |
577 | 577 | if loaded_shard_id is None:
|
578 |
| - if param.output_dim is None: |
579 |
| - assert param_data.shape == loaded_weight.shape |
580 |
| - param_data.copy_(loaded_weight) |
| 578 | + if isinstance(param, PerTensorScaleParameter): |
| 579 | + param.load_merged_column_weight(loaded_weight=loaded_weight, |
| 580 | + shard_id=0) |
| 581 | + return |
| 582 | + elif type(param) is BasevLLMParameter: |
| 583 | + param.load_merged_column_weight(loaded_weight=loaded_weight) |
581 | 584 | return
|
582 | 585 | self._load_fused_module_from_checkpoint(param, loaded_weight)
|
583 | 586 | return
|
@@ -720,11 +723,13 @@ def weight_loader_v2(self,
|
720 | 723 | param: BasevLLMParameter,
|
721 | 724 | loaded_weight: torch.Tensor,
|
722 | 725 | loaded_shard_id: Optional[str] = None):
|
723 |
| - param_data = param.data |
724 | 726 | if loaded_shard_id is None: # special case for certain models
|
725 |
| - if param.output_dim is None: |
726 |
| - assert param_data.shape == loaded_weight.shape |
727 |
| - param_data.copy_(loaded_weight) |
| 727 | + if isinstance(param, PerTensorScaleParameter): |
| 728 | + param.load_merged_column_weight(loaded_weight=loaded_weight, |
| 729 | + shard_id=0) |
| 730 | + return |
| 731 | + elif type(param) is BasevLLMParameter: |
| 732 | + param.load_merged_column_weight(loaded_weight=loaded_weight) |
728 | 733 | return
|
729 | 734 | self._load_fused_module_from_checkpoint(param, loaded_weight)
|
730 | 735 | return
|
|
0 commit comments