@@ -642,7 +642,11 @@ def set_state_dict(self, state_dict):
                     ffn1_weight_tensor[i], algo=self.quant_algo
                 )
                 ffn1_quanted_weight_list.append(
-                    ffn1_quanted_weight_list_i.reshape([self.transformer_block.config.embed_dim, -1])
+                    ffn1_quanted_weight_list_i.reshape(
+                        [self.transformer_block.embed_dim, self.transformer_block.dim_feedforward * 2]
+                        if self.quant_type == "weight_only_int8"
+                        else [self.transformer_block.embed_dim, self.transformer_block.dim_feedforward]
+                    )
                 )
                 ffn1_quanted_weight_scale.append(ffn1_quanted_weight_scale_i)
                 ffn1_quanted_weight_tensor = paddle.to_tensor(ffn1_quanted_weight_list)
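
Side note, not part of the diff: the int4 branch keeps half as many columns because weight-only int4 packs two 4-bit values into one byte. Below is a minimal NumPy sketch of that shape arithmetic; the nibble layout shown is an assumption for illustration only, and Paddle's weight_quantize may pack differently.

```python
# Illustrative sketch, independent of this PR: pack pairs of 4-bit values
# along the last axis into single bytes, halving that dimension.
import numpy as np

def pack_int4(w):
    """Pack an int array of 4-bit values (-8..7), two per byte, along the last axis."""
    assert w.shape[-1] % 2 == 0
    lo = w[..., 0::2] & 0x0F          # even columns -> low nibble
    hi = (w[..., 1::2] & 0x0F) << 4   # odd columns  -> high nibble
    return (lo | hi).astype(np.uint8)

embed_dim, dim_feedforward = 4, 8     # toy sizes for the demo
w = np.random.randint(-8, 8, size=(embed_dim, dim_feedforward * 2))
packed = pack_int4(w)

# int8 quantization keeps the full [embed_dim, dim_feedforward * 2] shape;
# int4 packing halves the last dimension to [embed_dim, dim_feedforward].
print(w.shape, "->", packed.shape)    # (4, 16) -> (4, 8)
```
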
@@ -677,7 +681,11 @@ def set_state_dict(self, state_dict):
                     ffn2_weight_tensor[i], algo=self.quant_algo
                 )
                 ffn2_quanted_weight_list.append(
-                    ffn2_quanted_weight_list_i.reshape([-1, self.transformer_block.config.embed_dim])
+                    ffn2_quanted_weight_list_i.reshape(
+                        [self.transformer_block.dim_feedforward, self.transformer_block.embed_dim]
+                        if self.quant_type == "weight_only_int8"
+                        else [self.transformer_block.dim_feedforward, self.transformer_block.embed_dim // 2]
+                    )
                 )
                 ffn2_quanted_weight_scale.append(ffn2_quanted_weight_scale_i)
                 ffn2_quanted_weight_tensor = paddle.to_tensor(ffn2_quanted_weight_list)
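
The ffn2 hunk applies the same rule with the dimensions swapped: int8 keeps the full [dim_feedforward, embed_dim] shape, while int4 halves the packed trailing dimension. A small sketch of the target shapes the two conditionals encode follows; quanted_shape and the example sizes are illustrative, not from this change.

```python
# Illustrative only: mirrors the shape selection in the two hunks above.
# "quanted_shape" is a hypothetical helper, not code from this PR.
def quanted_shape(rows, cols, quant_type):
    """Expected shape of a weight_only-quantized [rows, cols] matrix."""
    if quant_type == "weight_only_int8":
        return [rows, cols]      # one int8 per original element
    return [rows, cols // 2]     # int4: two values packed per byte

embed_dim, dim_feedforward = 4096, 11008   # example sizes, not from the PR

# ffn1 weight is [embed_dim, dim_feedforward * 2] per the int8 branch above
assert quanted_shape(embed_dim, dim_feedforward * 2, "weight_only_int8") == [4096, 22016]
assert quanted_shape(embed_dim, dim_feedforward * 2, "weight_only_int4") == [4096, 11008]

# ffn2 weight is [dim_feedforward, embed_dim]
assert quanted_shape(dim_feedforward, embed_dim, "weight_only_int8") == [11008, 4096]
assert quanted_shape(dim_feedforward, embed_dim, "weight_only_int4") == [11008, 2048]
```

Presumably the explicit per-quant_type shapes replace the `-1` wildcard so that the int8 and int4 layouts are stated outright rather than inferred by the reshape.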