Commit 29eec2d

add multi_devices_fused_multi_transformer_encoder_pass and cherry-pick from 48349 (#49383)
1 parent: a2d7e1d
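
Summary (from the diff below): the decoder fusion passes in fused_multi_transformer_decoder_pass.cc previously matched only a gelu activation between the two FFN matmuls. This change introduces FFN_ACTS = {"relu", "gelu"}, generalizes the pattern nodes from ffn_gelu/ffn_gelu_out to ffn_act/ffn_act_out, threads the matched activation node into each BuildFusion, and records its op type on the fused op as the "act_method" attribute. The companion header renames the corresponding PATTERN_DECL_NODE declarations. Only two of the six changed files appear in this section.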

6 files changed: +2776 -1241 lines

paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.cc (50 additions, 36 deletions)
@@ -31,6 +31,8 @@ namespace framework {
 namespace ir {
 namespace patterns {
 
+static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};
+
 PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* input0 = pattern->NewNode(input0_repr());
   input0->assert_is_op_input("layer_norm", "X");
@@ -359,13 +361,13 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
-                               ->AsIntermediate()
-                               ->assert_is_op_input("matmul_v2");
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2");
 
   auto* ffn_matmul1 =
       pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
@@ -397,8 +399,8 @@ PDNode* FusedMultiTransformerDecoderPattern::operator()() {
       .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
@@ -678,13 +680,13 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
-                               ->AsIntermediate()
-                               ->assert_is_op_input("matmul_v2");
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2");
 
   auto* ffn_matmul1 =
       pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
@@ -716,8 +718,8 @@ PDNode* FusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_eltadd1->LinksFrom({ffn_matmul1_out_var, ffn_eltadd1_b_var})
       .LinksTo({ffn_eltadd1_out_var});
@@ -1026,13 +1028,13 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
   auto* ffn_eltadd0_out_var = pattern->NewNode(ffn_eltadd0_out_repr())
                                   ->assert_is_op_output("elementwise_add")
                                   ->AsIntermediate()
-                                  ->assert_is_op_input("gelu");
+                                  ->assert_is_ops_input(FFN_ACTS);
 
-  auto* ffn_gelu = pattern->NewNode(ffn_gelu_repr())->assert_is_op("gelu");
-  auto* ffn_gelu_out_var = pattern->NewNode(ffn_gelu_out_repr())
-                               ->assert_is_op_output("gelu")
-                               ->AsIntermediate()
-                               ->assert_is_op_input("matmul_v2");
+  auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
+  auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
+                              ->assert_is_ops_output(FFN_ACTS)
+                              ->AsIntermediate()
+                              ->assert_is_op_input("matmul_v2");
 
   auto* ffn_matmul1 =
       pattern->NewNode(ffn_matmul1_repr())->assert_is_op("matmul_v2");
@@ -1073,8 +1075,8 @@ PDNode* MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern::operator()() {
      .LinksTo({ffn_matmul0_out_var});
   ffn_eltadd0->LinksFrom({ffn_matmul0_out_var, ffn_eltadd0_b_var})
       .LinksTo({ffn_eltadd0_out_var});
-  ffn_gelu->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_gelu_out_var});
-  ffn_matmul1->LinksFrom({ffn_gelu_out_var, ffn_matmul1_w_var})
+  ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
+  ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
       .LinksTo({ffn_matmul1_out_var});
   ffn_c_allreduce_sum->LinksFrom({ffn_matmul1_out_var})
       .LinksTo({ffn_c_allreduce_sum_out_var});
@@ -1147,6 +1149,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                          Node* ffn_matmul1_w,
                          Node* ffn_eltadd0_b,
                          Node* ffn_eltadd1_b,
+                         Node* ffn_act,
                          Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
@@ -1215,6 +1218,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("is_test", true);
@@ -1455,9 +1460,9 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
        ffn_eltadd0_out, ffn_eltadd0_out, fused_multi_transformer_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_pattern);
@@ -1578,6 +1583,7 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -1644,8 +1650,8 @@ int FusedMultiTransformerDecoderPass::BuildFusion(Graph* graph,
                    ffn_eltadd1,
                    ffn_eltadd0_out,
                    ffn_eltadd1_out,
-                   ffn_gelu,
-                   ffn_gelu_out,
+                   ffn_act,
+                   ffn_act_out,
                    ffn_eltadd_out});
 
     // Remove unneeded nodes.
@@ -1871,6 +1877,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                          Node* ffn_matmul1_w,
                          Node* ffn_eltadd0_b,
                          Node* ffn_eltadd1_b,
+                         Node* ffn_act,
                          Node* ffn_output) {
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
@@ -1939,6 +1946,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
@@ -2168,9 +2177,9 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                               fused_multi_transformer_fuse_qkv_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
@@ -2287,6 +2296,7 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -2345,8 +2355,8 @@ int FusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                    ffn_eltadd1,
                    ffn_eltadd0_out,
                    ffn_eltadd1_out,
-                   ffn_gelu,
-                   ffn_gelu_out,
+                   ffn_act,
+                   ffn_act_out,
                    ffn_eltadd_out});
 
     // Remove unneeded nodes.
@@ -2592,6 +2602,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                          Node* ffn_matmul1_w,
                          Node* ffn_eltadd0_b,
                          Node* ffn_eltadd1_b,
+                         Node* ffn_act,
                          Node* ffn_output) {
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_1_op = ffn_matmul1->Op();
@@ -2658,6 +2669,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
     fused_multi_transformer_op_desc.SetAttr("pre_layer_norm", true);
     fused_multi_transformer_op_desc.SetAttr(
         "epsilon", layer_norm->Op()->GetAttr("epsilon"));
+    fused_multi_transformer_op_desc.SetAttr("act_method",
+                                            ffn_act->Op()->Type());
 
     // output dropout attribute
     fused_multi_transformer_op_desc.SetAttr("dropout_rate", 0.0f);
@@ -2911,9 +2924,9 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                               fused_multi_transformer_fuse_qkv_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu, ffn_gelu, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act, ffn_act, fused_multi_transformer_fuse_qkv_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(
-        ffn_gelu_out, ffn_gelu_out, fused_multi_transformer_fuse_qkv_pattern);
+        ffn_act_out, ffn_act_out, fused_multi_transformer_fuse_qkv_pattern);
 
     GET_IR_NODE_FROM_SUBGRAPH(
         ffn_matmul1, ffn_matmul1, fused_multi_transformer_fuse_qkv_pattern);
@@ -3044,6 +3057,7 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                  ffn_matmul1_w,
                  ffn_eltadd0_b,
                  ffn_eltadd1_b,
+                 ffn_act,
                  ffn_output);
 
     std::unordered_set<const Node*> marked_nodes({layer_norm,
@@ -3110,8 +3124,8 @@ int MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::BuildFusion(
                    ffn_eltadd1,
                    ffn_eltadd0_out,
                    ffn_eltadd1_out,
-                   ffn_gelu,
-                   ffn_gelu_out,
+                   ffn_act,
+                   ffn_act_out,
                    ffn_eltadd_out});
 
     // Remove unneeded nodes.
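
Taken together, the pattern-side hunks in this file reduce to the following sketch, assembled from the lines above (not a compilable unit on its own; indentation and surrounding context are approximate):

// Allow the FFN activation to be either relu or gelu instead of gelu only.
static const std::unordered_set<std::string> FFN_ACTS{"relu", "gelu"};

// The bias-add output must now feed some op in FFN_ACTS, not just "gelu".
auto* ffn_act = pattern->NewNode(ffn_act_repr())->assert_is_ops(FFN_ACTS);
auto* ffn_act_out_var = pattern->NewNode(ffn_act_out_repr())
                            ->assert_is_ops_output(FFN_ACTS)
                            ->AsIntermediate()
                            ->assert_is_op_input("matmul_v2");

// Wire the activation between the first FFN matmul+bias and the second matmul.
ffn_act->LinksFrom({ffn_eltadd0_out_var}).LinksTo({ffn_act_out_var});
ffn_matmul1->LinksFrom({ffn_act_out_var, ffn_matmul1_w_var})
    .LinksTo({ffn_matmul1_out_var});

// On the fusion side, record which activation actually matched so the fused
// kernel can dispatch on it at run time.
fused_multi_transformer_op_desc.SetAttr("act_method", ffn_act->Op()->Type());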

paddle/fluid/framework/ir/fused_multi_transformer_decoder_pass.h (6 additions, 6 deletions)
@@ -125,8 +125,8 @@ struct FusedMultiTransformerDecoderPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
@@ -223,8 +223,8 @@ struct FusedMultiTransformerDecoderFuseQKVPattern : public PatternBase {
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
@@ -329,8 +329,8 @@ struct MultiDevicesFusedMultiTransformerDecoderFuseQKVPattern
   PATTERN_DECL_NODE(ffn_eltadd0);    // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_b);  // ELEMENTWISE_ADD
   PATTERN_DECL_NODE(ffn_eltadd0_out);
-  PATTERN_DECL_NODE(ffn_gelu);
-  PATTERN_DECL_NODE(ffn_gelu_out);
+  PATTERN_DECL_NODE(ffn_act);
+  PATTERN_DECL_NODE(ffn_act_out);
   PATTERN_DECL_NODE(ffn_matmul1);
   PATTERN_DECL_NODE(ffn_matmul1_w);
   PATTERN_DECL_NODE(ffn_matmul1_out);
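
PATTERN_DECL_NODE appears to be the helper macro that generates the <name>_repr() accessor a pattern uses when creating its nodes, so the renames here must stay in sync with the ffn_act_repr()/ffn_act_out_repr() calls in the .cc file. A minimal sketch of the pairing, using only macros that appear in this commit's hunks:

// Header: declare the renamed pattern nodes (generates ffn_act_repr(), etc.).
PATTERN_DECL_NODE(ffn_act);
PATTERN_DECL_NODE(ffn_act_out);

// .cc: bind the matched IR nodes once a subgraph is found.
GET_IR_NODE_FROM_SUBGRAPH(ffn_act, ffn_act, fused_multi_transformer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
    ffn_act_out, ffn_act_out, fused_multi_transformer_pattern);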
