@@ -48,18 +48,22 @@ class XPUSingleEncoderFuser : public FuseBase {
48
48
const std::string& input_pos = " Y" ,
49
49
const std::string& qkv_ln_2_out_pos = " Y" ,
50
50
const std::string& matmul_type = " matmul" ,
51
+ const std::string& matmul2_type = " matmul_v2" ,
51
52
const std::string& mul_type = " mul" ,
52
53
bool with_q_scale = true ,
53
54
bool norm_before = false ,
54
- const std::string& relative_type = " " )
55
+ const std::string& relative_type = " " ,
56
+ bool with_mask = true )
55
57
: act_type_(act_type),
56
58
input_pos_(input_pos),
57
59
qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
58
60
matmul_type_(matmul_type),
61
+ matmul2_type_(matmul2_type),
59
62
mul_type_(mul_type),
60
63
with_q_scale_(with_q_scale),
61
64
norm_before_(norm_before),
62
- relative_emb_type_(relative_type) {}
65
+ relative_emb_type_(relative_type),
66
+ with_mask_(with_mask) {}
63
67
64
68
void BuildPattern () override {
65
69
auto * input = VarNode (" input" )
@@ -213,18 +217,25 @@ class XPUSingleEncoderFuser : public FuseBase {
213
217
->AsIntermediate ();
214
218
215
219
auto * qk_matmul = OpNode (" qk_matmul" , matmul_type_)->AsIntermediate ();
220
+ std::string op_after_qk_matmul = with_mask_ ? " elementwise_add" : " softmax" ;
216
221
auto * qk_matmul_out = VarNode (" qk_matmul_out" )
217
222
->assert_is_op_output (matmul_type_, " Out" )
218
- ->assert_is_op_input (" elementwise_add " , " X" )
223
+ ->assert_is_op_input (op_after_qk_matmul , " X" )
219
224
->AsIntermediate ();
220
- auto * qk_mask = VarNode (" qk_mask" )
221
- ->assert_is_op_input (" elementwise_add" , " Y" )
222
- ->AsInput ();
223
- auto * qk_add = OpNode (" qk_add" , " elementwise_add" )->AsIntermediate ();
224
- auto * qk_add_out = VarNode (" qk_add_out" )
225
- ->assert_is_op_output (" elementwise_add" , " Out" )
226
- ->assert_is_op_input (" softmax" , " X" )
227
- ->AsIntermediate ();
225
+ PMNode* qk_mask = nullptr ;
226
+ PMNode* qk_add = nullptr ;
227
+ PMNode* qk_add_out = nullptr ;
228
+ if (with_mask_) {
229
+ qk_mask = VarNode (" qk_mask" )
230
+ ->assert_is_op_input (" elementwise_add" , " Y" )
231
+ ->AsInput ();
232
+ qk_add = OpNode (" qk_add" , " elementwise_add" )->AsIntermediate ();
233
+ qk_add_out = VarNode (" qk_add_out" )
234
+ ->assert_is_op_output (" elementwise_add" , " Out" )
235
+ ->assert_is_op_input (" softmax" , " X" )
236
+ ->AsIntermediate ();
237
+ }
238
+
228
239
auto * qk_softmax = OpNode (" qk_softmax" , " softmax" )->AsIntermediate ();
229
240
auto * qk_softmax_out = VarNode (" qk_softmax_out" )
230
241
->assert_is_op_output (" softmax" , " Out" )
@@ -256,16 +267,16 @@ class XPUSingleEncoderFuser : public FuseBase {
256
267
auto * v_transpose2 = OpNode (" v_transpose2" , " transpose2" )->AsIntermediate ();
257
268
auto * v_transpose2_out = VarNode (" v_transpose2_out" )
258
269
->assert_is_op_output (" transpose2" , " Out" )
259
- ->assert_is_op_input (matmul_type_ , " Y" )
270
+ ->assert_is_op_input (matmul2_type_ , " Y" )
260
271
->AsIntermediate ();
261
272
auto * v_transpose2_xshape =
262
273
VarNode (" v_transpose2_xshape" )
263
274
->assert_is_op_output (" transpose2" , " XShape" )
264
275
->AsIntermediate ();
265
276
266
- auto * qkv_matmul = OpNode (" qkv_matmul" , matmul_type_ )->AsIntermediate ();
277
+ auto * qkv_matmul = OpNode (" qkv_matmul" , matmul2_type_ )->AsIntermediate ();
267
278
auto * qkv_matmul_out = VarNode (" qkv_matmul_out" )
268
- ->assert_is_op_output (matmul_type_ , " Out" )
279
+ ->assert_is_op_output (matmul2_type_ , " Out" )
269
280
->assert_is_op_input (" transpose2" , " X" )
270
281
->AsIntermediate ();
271
282
auto * qkv_transpose2 =
@@ -459,9 +470,14 @@ class XPUSingleEncoderFuser : public FuseBase {
459
470
*k_reshape2 >> *k_reshape2_xshape;
460
471
*k_transpose2 >> *k_transpose2_xshape;
461
472
462
- *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >>
463
- *qk_softmax_out >> *qkv_matmul;
464
- *qk_mask >> *qk_add;
473
+ if (with_mask_) {
474
+ *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >>
475
+ *qk_softmax_out >> *qkv_matmul;
476
+ *qk_mask >> *qk_add;
477
+ } else {
478
+ *qk_matmul >> *qk_matmul_out >> *qk_softmax >> *qk_softmax_out >>
479
+ *qkv_matmul;
480
+ }
465
481
466
482
if (norm_before_) {
467
483
*ln_before_out >> *v_mul;
@@ -513,7 +529,9 @@ class XPUSingleEncoderFuser : public FuseBase {
513
529
cpp::OpDesc op_desc;
514
530
op_desc.SetType (" single_encoder" );
515
531
op_desc.SetInput (" Inputs" , {matched.at (" input" )->arg ()->name });
516
- op_desc.SetInput (" Mask" , {matched.at (" qk_mask" )->arg ()->name });
532
+ if (with_mask_) {
533
+ op_desc.SetInput (" Mask" , {matched.at (" qk_mask" )->arg ()->name });
534
+ }
517
535
op_desc.SetInput (" FCWeight" ,
518
536
{
519
537
matched.at (" q_mul_y" )->arg ()->name ,
@@ -645,7 +663,6 @@ class XPUSingleEncoderFuser : public FuseBase {
645
663
single_encoder_stmt->SetOp (fake_subgraph_op);
646
664
647
665
std::vector<std::string> froms = {
648
- " qk_mask" ,
649
666
" k_mul_y" ,
650
667
" v_mul_y" ,
651
668
" qkv_mul_y" ,
@@ -660,6 +677,9 @@ class XPUSingleEncoderFuser : public FuseBase {
660
677
" qkv_ln_2_scale" ,
661
678
" qkv_ln_2_bias" ,
662
679
};
680
+ if (with_mask_) {
681
+ froms.push_back (" qk_mask" );
682
+ }
663
683
if (relative_emb_type_ == " __xpu__roformer_relative_embedding" ) {
664
684
froms.push_back (" q_cos_embedding" );
665
685
froms.push_back (" q_sin_embedding" );
@@ -687,10 +707,12 @@ class XPUSingleEncoderFuser : public FuseBase {
687
707
std::string input_pos_;
688
708
std::string qkv_ln_2_out_pos_;
689
709
std::string matmul_type_;
710
+ std::string matmul2_type_;
690
711
std::string mul_type_;
691
712
bool with_q_scale_;
692
713
bool norm_before_;
693
714
const std::string relative_emb_type_;
715
+ bool with_mask_;
694
716
// quant_info: mul input_max, output_max * 6 + matmul x_max:y_max, output_max
695
717
// * 2
696
718
void set_quant_info (Scope* scope,
@@ -955,7 +977,7 @@ class XPUMultiEncoderFuser {
955
977
std::string mask_name;
956
978
for (auto * encoder : all_encoders) {
957
979
auto * op_info = encoder->stmt ()->op_info ();
958
- if (mask_name.empty ()) {
980
+ if (mask_name.empty () && op_info-> HasInput ( " Mask " ) ) {
959
981
mask_name = op_info->Input (" Mask" ).front ();
960
982
} else {
961
983
// CHECK(mask_name == op_info->Input("Mask").front());
@@ -1026,13 +1048,11 @@ class XPUMultiEncoderFuser {
1026
1048
if (all_encoders.size () == 1 ) {
1027
1049
// take care of only one encoder
1028
1050
in_name = op_info->Input (" Inputs" ).front ();
1029
- mask_name = op_info->Input (" Mask" ).front ();
1030
1051
out_name = op_info->Output (" Outputs" ).front ();
1031
1052
} else if (i == 0 ) {
1032
1053
// first encoder
1033
1054
to_remove.insert (cur_out);
1034
1055
in_name = op_info->Input (" Inputs" ).front ();
1035
- mask_name = op_info->Input (" Mask" ).front ();
1036
1056
} else if (i == all_encoders.size () - 1 ) {
1037
1057
// last encoder
1038
1058
to_remove.insert (cur_encoder);
@@ -1059,7 +1079,9 @@ class XPUMultiEncoderFuser {
1059
1079
for (auto kv : arg_map) {
1060
1080
op_desc.SetInput (kv.first , kv.second );
1061
1081
}
1062
- op_desc.SetInput (" Mask" , {mask_name});
1082
+ if (!mask_name.empty ()) {
1083
+ op_desc.SetInput (" Mask" , {mask_name});
1084
+ }
1063
1085
op_desc.SetOutput (" Output" , {out_name});
1064
1086
op_desc.SetAttr <int >(" xpu" , 1 );
1065
1087
op_desc.SetAttr <int >(
@@ -1404,9 +1426,11 @@ class XPUMultiEncoderFusePass : public ProgramPass {
1404
1426
std::vector<std::string> input_poss{" X" , " Y" };
1405
1427
std::vector<std::string> qkv_ln_2_out_poss{" X" , " Y" };
1406
1428
std::vector<std::string> matmul_types{" matmul" , " matmul_v2" };
1429
+ std::vector<std::string> matmul2_types{" matmul" , " matmul_v2" };
1407
1430
std::vector<std::string> mul_types{" mul" , " matmul" , " matmul_v2" };
1408
1431
std::vector<bool > with_q_scales{true , false };
1409
1432
std::vector<bool > norm_befores{true , false };
1433
+ std::vector<bool > with_mask{true , false };
1410
1434
std::vector<std::string> relative_embedding_type{
1411
1435
" " , " __xpu__roformer_relative_embedding" };
1412
1436
@@ -1445,23 +1469,29 @@ class XPUMultiEncoderFusePass : public ProgramPass {
1445
1469
for (auto & input_pos : input_poss) {
1446
1470
for (auto & qkv_ln_2_out_pos : qkv_ln_2_out_poss) {
1447
1471
for (auto & matmul_type : matmul_types) {
1448
- for (auto & mul_type : mul_types) {
1449
- for (auto with_q_scale : with_q_scales) {
1450
- for (auto norm_before : norm_befores) {
1451
- for (auto relative_type : relative_embedding_type) {
1452
- fusion::XPUSingleEncoderFuser single_encoder_fuser (
1453
- act_type,
1454
- input_pos,
1455
- qkv_ln_2_out_pos,
1456
- matmul_type,
1457
- mul_type,
1458
- with_q_scale,
1459
- norm_before,
1460
- relative_type);
1461
- single_encoder_fuser (graph.get ());
1462
- fusion::XPUMultiEncoderFuser multi_encoder_fuser (
1463
- fc_precision, adaptive_seqlen);
1464
- multi_encoder_fuser (graph.get ());
1472
+ for (auto & matmul2_type : matmul2_types) {
1473
+ for (auto & mul_type : mul_types) {
1474
+ for (auto with_q_scale : with_q_scales) {
1475
+ for (auto norm_before : norm_befores) {
1476
+ for (auto relative_type : relative_embedding_type) {
1477
+ for (auto mask : with_mask) {
1478
+ fusion::XPUSingleEncoderFuser single_encoder_fuser (
1479
+ act_type,
1480
+ input_pos,
1481
+ qkv_ln_2_out_pos,
1482
+ matmul_type,
1483
+ matmul2_type,
1484
+ mul_type,
1485
+ with_q_scale,
1486
+ norm_before,
1487
+ relative_type,
1488
+ mask);
1489
+ single_encoder_fuser (graph.get ());
1490
+ fusion::XPUMultiEncoderFuser multi_encoder_fuser (
1491
+ fc_precision, adaptive_seqlen);
1492
+ multi_encoder_fuser (graph.get ());
1493
+ }
1494
+ }
1465
1495
}
1466
1496
}
1467
1497
}
0 commit comments