@@ -31,6 +31,8 @@ namespace framework {
 namespace ir {
 namespace patterns {
 
+static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
+
 PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* input0 = pattern->NewNode(input0_repr());
   input0->assert_is_op_input("layer_norm", "X");
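
Context for the hunks below: the pattern previously pinned the FFN activation to a single `gelu` op; the shared `FFN_ACTS` set widens it to `relu` or `gelu` via the `assert_is_ops*` helpers, which test membership in a set of op types rather than equality with one. A minimal sketch of that distinction, using illustrative stand-in types rather than Paddle's actual `PDNode` API:

```cpp
#include <cassert>
#include <string>
#include <unordered_set>

struct FakeNode {  // stand-in for an IR node
  std::string op_type;
};

// Single-op check: the node must be exactly this op type.
bool AssertIsOp(const FakeNode& n, const std::string& op) {
  return n.op_type == op;
}

// Multi-op check: the node may be any op type in the set.
bool AssertIsOps(const FakeNode& n,
                 const std::unordered_set<std::string>& ops) {
  return ops.count(n.op_type) > 0;
}

int main() {
  const std::unordered_set<std::string> ffn_acts{"relu", "gelu"};
  assert(AssertIsOps({"relu"}, ffn_acts));  // accepted by the widened check
  assert(!AssertIsOp({"relu"}, "gelu"));    // rejected by the old gelu-only check
  return 0;
}
```
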
@@ -359,13 +361,13 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
-                               ->AsIntermediate()
-                               ->assert_is_op_input("matmul_v2");
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2");
 
   auto* ffn_matmul1 =
       pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
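
The relinking keeps the same dataflow with the activation node generalized: `ffn_eltadd0_out` feeds `ffn_act`, whose output feeds `ffn_matmul1`. As a hedged scalar sketch of the FFN block this subgraph encodes (illustrative names, not Paddle code):

```cpp
#include <functional>

// Scalar stand-in for the matched FFN dataflow:
// matmul0 -> eltadd0 -> act -> matmul1 -> eltadd1.
float FfnBlock(float x, float w0, float b0, float w1, float b1,
               const std::function<float(float)>& act) {
  float h = x * w0 + b0;  // ffn_matmul0 + ffn_eltadd0
  h = act(h);             // ffn_act: relu or gelu under the widened pattern
  return h * w1 + b1;     // ffn_matmul1 + ffn_eltadd1
}
```

The same generalization is applied verbatim to the FuseQKV and multi-devices patterns in the next hunks.
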
@@ -678,13 +680,13 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
-                               ->AsIntermediate()
-                               ->assert_is_op_input("matmul_v2");
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2");
 
   auto* ffn_matmul1 =
       pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
@@ -1026,13 +1028,13 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
-                               ->AsIntermediate()
-                               ->assert_is_op_input("matmul_v2");
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2");
 
   auto* ffn_matmul1 =
       pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
       .LinksTo({ffn_c_allreduce_sum_out_var});
@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                           Node* ffn_matmul1_w,
                           Node* ffn_eltadd0_b,
                           Node* ffn_eltadd1_b,
+                          Node* ffn_act,
                           Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("is_test", true);
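
Recording the matched op's type in `act_method` lets the fused op recover which activation to apply at run time. A hedged sketch (not Paddle's actual kernel) of how a consumer of the attribute might dispatch:

```cpp
#include <cmath>
#include <stdexcept>
#include <string>

// Illustrative dispatch on the "act_method" attribute; the tanh-based GELU
// approximation here is a common implementation choice, not necessarily the
// one the fused kernel uses.
float ApplyFfnAct(const std::string& act_method, float x) {
  if (act_method == "relu") return x > 0.0f ? x : 0.0f;
  if (act_method == "gelu") {
    return 0.5f * x *
           (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
  }
  throw std::invalid_argument("unsupported act_method: " + act_method);
}
```
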
@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
         ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.
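
The renamed `ffn_act`/`ffn_act_out` nodes join `marked_nodes` so the entire matched subgraph is erased once the fused op is wired in. A hedged sketch of that erase step over a simplified node store (stand-in types; the pass itself delegates to a graph-safe removal helper):

```cpp
#include <algorithm>
#include <memory>
#include <unordered_set>
#include <vector>

struct FakeNode {};  // stand-in for ir::Node

// Erase every owned node that was marked during pattern matching.
void RemoveMarked(std::vector<std::unique_ptr<FakeNode>>* nodes,
                  const std::unordered_set<const FakeNode*>& marked) {
  nodes->erase(std::remove_if(nodes->begin(), nodes->end(),
                              [&](const std::unique_ptr<FakeNode>& n) {
                                return marked.count(n.get()) > 0;
                              }),
               nodes->end());
}
```

The two remaining passes repeat the same bookkeeping, as the hunks below show.
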
@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                           Node* ffn_matmul1_w,
                           Node* ffn_eltadd0_b,
                           Node* ffn_eltadd1_b,
+                          Node* ffn_act,
                           Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
         fused_multi_transformer_fuse_qkv_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.
@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                           Node* ffn_matmul1_w,
                           Node* ffn_eltadd0_b,
                           Node* ffn_eltadd1_b,
+                          Node* ffn_act,
                           Node* ffn_output) {
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_1_op = ffn_matmul1->Op();
@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
         fused_multi_transformer_fuse_qkv_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                                                   ffn_eltadd1,
                                                   ffn_eltadd0_out,
                                                   ffn_eltadd1_out,
-                                                  ffn_gelu,
-                                                  ffn_gelu_out,
+                                                  ffn_act,
+                                                  ffn_act_out,
                                                   ffn_eltadd_out});
 
     // Remove unneeded nodes.