
Commit b13fd58

penPenf28 authored and lixcli committed
[Feature] fused mixtral wint4 (PaddlePaddle#9013)
* [Feature] fused mixtral wint4
* [Refactor] refine code
1 parent 2bca503 commit b13fd58

File tree: 2 files changed (+14, -2 lines)


paddlenlp/experimental/transformers/fused_transformer_layers.py (4 additions, 0 deletions)

@@ -1137,6 +1137,10 @@ def init_weight_shape(self, config):
         )
         self.moe_ffn2_weight_shape = [self.config.moe_config.num_experts, self.dim_feedforward, self.embed_dim]
 
+        if config.quant_type == "weight_only_int4":
+            self.moe_ffn1_weight_shape[2] //= 2
+            self.moe_ffn2_weight_shape[2] //= 2
+
     def compute_qkv_linear(self, ln_out, i):
         return weight_only_linear(
             ln_out,
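Why the halving: weight_only_int4 packs two 4-bit weights into each stored int8 element, so the packed buffer's last axis is half its logical width. A minimal runnable sketch of the shape arithmetic, using illustrative Mixtral-like sizes (the values below are assumptions for illustration, not taken from the repo's config):

    # Illustrative sizes, not the repo's actual config values.
    num_experts, dim_feedforward, embed_dim = 8, 14336, 4096

    # ffn1 plausibly fuses the gate and up projections, consistent with
    # the dim_feedforward * 2 reshape in the modeling diff below.
    moe_ffn1_weight_shape = [num_experts, embed_dim, dim_feedforward * 2]
    moe_ffn2_weight_shape = [num_experts, dim_feedforward, embed_dim]

    quant_type = "weight_only_int4"
    if quant_type == "weight_only_int4":
        moe_ffn1_weight_shape[2] //= 2  # two int4 values per int8 element
        moe_ffn2_weight_shape[2] //= 2

    print(moe_ffn1_weight_shape)  # [8, 4096, 14336]
    print(moe_ffn2_weight_shape)  # [8, 14336, 2048]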

paddlenlp/experimental/transformers/mixtral/modeling.py (10 additions, 2 deletions)

@@ -642,7 +642,11 @@ def set_state_dict(self, state_dict):
                 ffn1_weight_tensor[i], algo=self.quant_algo
             )
             ffn1_quanted_weight_list.append(
-                ffn1_quanted_weight_list_i.reshape([self.transformer_block.config.embed_dim, -1])
+                ffn1_quanted_weight_list_i.reshape(
+                    [self.transformer_block.embed_dim, self.transformer_block.dim_feedforward * 2]
+                    if self.quant_type == "weight_only_int8"
+                    else [self.transformer_block.embed_dim, self.transformer_block.dim_feedforward]
+                )
             )
             ffn1_quanted_weight_scale.append(ffn1_quanted_weight_scale_i)
         ffn1_quanted_weight_tensor = paddle.to_tensor(ffn1_quanted_weight_list)
@@ -677,7 +681,11 @@ def set_state_dict(self, state_dict):
                 ffn2_weight_tensor[i], algo=self.quant_algo
             )
             ffn2_quanted_weight_list.append(
-                ffn2_quanted_weight_list_i.reshape([-1, self.transformer_block.config.embed_dim])
+                ffn2_quanted_weight_list_i.reshape(
+                    [self.transformer_block.dim_feedforward, self.transformer_block.embed_dim]
+                    if self.quant_type == "weight_only_int8"
+                    else [self.transformer_block.dim_feedforward, self.transformer_block.embed_dim // 2]
+                )
             )
             ffn2_quanted_weight_scale.append(ffn2_quanted_weight_scale_i)
         ffn2_quanted_weight_tensor = paddle.to_tensor(ffn2_quanted_weight_list)
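The reshape targets follow the same packing rule per expert: int8 keeps the logical column count, while int4 halves it ((dim_feedforward * 2) // 2 for ffn1, embed_dim // 2 for ffn2). A self-contained sketch of the branch using a NumPy stand-in for the quantized buffer (not PaddleNLP's actual weight_quantize output; sizes are illustrative):

    import numpy as np

    # Illustrative sizes, not the repo's actual config values.
    dim_feedforward, embed_dim = 14336, 4096
    quant_type = "weight_only_int4"

    # Stand-in for one expert's quantized ffn2 weight, flattened by the quantizer.
    cols = embed_dim if quant_type == "weight_only_int8" else embed_dim // 2
    flat = np.zeros(dim_feedforward * cols, dtype=np.int8)

    # Same shape selection as the patched set_state_dict branch.
    ffn2_quanted = flat.reshape(
        [dim_feedforward, embed_dim]
        if quant_type == "weight_only_int8"
        else [dim_feedforward, embed_dim // 2]
    )
    print(ffn2_quanted.shape)  # (14336, 2048)

Spelling out both shapes, rather than inferring one axis with -1 as before, makes the packed width explicit for each quant type.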
