// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>

#include "lite/backends/xpu/math.h"
#include "lite/core/optimizer/mir/pass_registry.h"
#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"

22
+ namespace paddle {
23
+ namespace lite {
24
+ namespace mir {
25
+ namespace fusion {
26
+
27
+ /* support adaptive seq len for bert/ernie */
28
+ /* in_Input in_Mask fill_constant */
29
+ /* | \ / */
30
+ /* | \ / */
31
+ /* | | */
32
+ /* xpu_embedding equal */
33
+ /* | | */
34
+ /* | | */
35
+ /* layer_norm cast */
36
+ /* | | */
37
+ /* | scale */
38
+ /* | / */
39
+ /* | unsqueeze2 */
40
+ /* | | */
41
+ /* | / */
42
+ /* | / */
43
+ /* xpu_encoder */
44
+ /* | */
45
+ /* | */
46
+ /* out_Output */
47
+ /* ---------------------------------------------------*/
48
+ /* After the pass apply: */
49
+ /* in_Input in_Mask */
50
+ /* | | */
51
+ /* | | */
52
+ /* | / */
53
+ /* xpu_embedding */
54
+ /* | \ */
55
+ /* | SeqLod */
56
+ /* | | */
57
+ /* layer_norm | */
58
+ /* | | */
59
+ /* | / */
60
+ /* xpu_encoder */
61
+ /* | */
62
+ /* | */
63
+ /* out_Output */
64
+ /* ---------------------------------------------------*/
65
+
66
+ class XPUMultiEncoderAdaptiveSeqlenV2Fuser : public FuseBase {
67
+ public:
68
+ explicit XPUMultiEncoderAdaptiveSeqlenV2Fuser (bool pre_ln = false )
69
+ : pre_ln_(pre_ln) {}
70
+
71
+ void BuildPattern () override {
72
+ auto * mask = VarNode (" mask" )->assert_is_op_input (" equal" , " X" )->AsInput ();
73
+ auto * fill_constant =
74
+ OpNode (" fill_constant" , " fill_constant" )->AsIntermediate ();
75
+ auto * fill_constant_out = VarNode (" fill_constant_out" )
76
+ ->assert_is_op_output (" fill_constant" , " Out" )
77
+ ->assert_is_op_input (" equal" , " Y" )
78
+ ->AsIntermediate ();
79
+
80
+ // auto* fill_constant = VarNode("fill_constant")
81
+ // ->assert_is_op_input("equal", "Y")->AsInput();
82
+
83
+ auto * equal = OpNode (" equal" , " equal" )->AsIntermediate ();
84
+ auto * equal_out = VarNode (" equal_out" )
85
+ ->assert_is_op_output (" equal" , " Out" )
86
+ ->assert_is_op_input (" cast" , " X" )
87
+ ->AsIntermediate ();
88
+ auto * cast = OpNode (" cast" , " cast" )->AsIntermediate ();
89
+ auto * cast_out = VarNode (" cast_out" )
90
+ ->assert_is_op_output (" cast" , " Out" )
91
+ ->assert_is_op_input (" scale" , " X" )
92
+ ->AsIntermediate ();
93
+ auto * scale = OpNode (" scale" , " scale" )->AsIntermediate ();
94
+ auto * scale_out = VarNode (" scale_out" )
95
+ ->assert_is_op_output (" scale" , " Out" )
96
+ ->assert_is_op_input (" unsqueeze2" , " X" )
97
+ ->AsIntermediate ();
98
+ auto * unsqueeze2 = OpNode (" unsqueeze2" , " unsqueeze2" )->AsIntermediate ();
99
+ auto * unsqueeze2_out =
100
+ VarNode (" unsqueeze2_out" )
101
+ ->assert_is_op_output (" unsqueeze2" , " Out" )
102
+ ->assert_is_op_input (" __xpu__multi_encoder" , " Mask" )
103
+ ->AsIntermediate ();
104
+ auto * unsqueeze2_out_xshape =
105
+ VarNode (" unsqueeze2_out_xshape" )
106
+ ->assert_is_op_output (" unsqueeze2" , " XShape" )
107
+ ->AsIntermediate ();
108
+ auto * xpu_embedding =
109
+ OpNode (" xpu_embedding" , " __xpu__embedding_with_eltwise_add" );
110
+
111
+ PMNode* embedding_out = nullptr ;
112
+ PMNode* layer_norm = nullptr ;
113
+ PMNode* layer_norm_out = nullptr ;
114
+
115
+ if (pre_ln_) {
116
+ embedding_out = VarNode (" embedding_out" )
117
+ ->assert_is_op_output (
118
+ " __xpu__embedding_with_eltwise_add" , " Output" )
119
+ ->assert_is_op_input (" __xpu__multi_encoder" , " Input" );
120
+ } else {
121
+ embedding_out = VarNode (" embedding_out" )
122
+ ->assert_is_op_output (
123
+ " __xpu__embedding_with_eltwise_add" , " Output" )
124
+ ->assert_is_op_input (" layer_norm" , " X" );
125
+ layer_norm = OpNode (" layer_norm" , " layer_norm" );
126
+ layer_norm_out =
127
+ VarNode (" layer_norm_out" )
128
+ ->assert_is_op_output (" layer_norm" , " Y" )
129
+ ->assert_is_op_input (" __xpu__multi_encoder" , " Input" );
130
+ }
131
+ auto * xpu_encoder = OpNode (" xpu_encoder" , " __xpu__multi_encoder" )
132
+ ->assert_op_attr <bool >(" adaptive_seqlen" , true );
133
+ if (pre_ln_) {
134
+ xpu_encoder->assert_op_attr <bool >(" norm_before" , true );
135
+ *xpu_embedding >> *embedding_out >> *xpu_encoder;
136
+ } else {
137
+ *xpu_embedding >> *embedding_out >> *layer_norm >> *layer_norm_out >>
138
+ *xpu_encoder;
139
+ }
140
+ *mask >> *equal;
141
+ *fill_constant >> *fill_constant_out >> *equal;
142
+ *equal >> *equal_out >> *cast >> *cast_out >> *scale >> *scale_out >>
143
+ *unsqueeze2 >> *unsqueeze2_out >> *xpu_encoder;
144
+ *unsqueeze2 >> *unsqueeze2_out_xshape;
145
+ }
146
+
147
+ void InsertNewNode (SSAGraph* graph, const key2nodes_t & matched) override {
148
+ auto * embedding_instruct = matched.at (" xpu_embedding" )->stmt ();
149
+ auto embedding_op_desc = *embedding_instruct->mutable_op_info ();
150
+ auto embedding_op = embedding_instruct->op ();
151
+ auto * scope = embedding_op->scope ();
152
+ auto * encoder_instruct = matched.at (" xpu_encoder" )->stmt ();
153
+ auto encoder_op_desc = *encoder_instruct->mutable_op_info ();
154
+ auto encoder_op = encoder_instruct->op ();
155
+
156
+ // add new arg seq_lod
157
+ std::string embedding_out_name = matched.at (" embedding_out" )->arg ()->name ;
158
+ std::string embedding_seq_lod_name = embedding_out_name + " _seq_lod" ;
159
+ auto * embedding_seq_lod_node =
160
+ graph->NewArgumentNode (embedding_seq_lod_name);
161
+ embedding_seq_lod_node->arg ()->type = LiteType::GetTensorTy (
162
+ TARGET (kHost ), PRECISION (kInt32 ), DATALAYOUT (kNCHW ));
163
+ scope->NewTensor (embedding_seq_lod_name);
164
+ // add new arg pad_seq_len
165
+ std::string embedding_pad_seq_len_name =
166
+ embedding_out_name + " _pad_seq_len" ;
167
+ auto * embedding_pad_seq_len_node =
168
+ graph->NewArgumentNode (embedding_pad_seq_len_name);
169
+ embedding_pad_seq_len_node->arg ()->type = LiteType::GetTensorTy (
170
+ TARGET (kHost ), PRECISION (kInt32 ), DATALAYOUT (kNCHW ));
171
+ scope->NewTensor (embedding_pad_seq_len_name);
172
+
173
+ embedding_op_desc.SetOutput (" SeqLod" , {embedding_seq_lod_name});
174
+ embedding_op_desc.SetOutput (" PadSeqLen" , {embedding_pad_seq_len_name});
175
+ encoder_op_desc.SetInput (" SeqLod" , {embedding_seq_lod_name});
176
+ encoder_op_desc.SetInput (" PadSeqLen" , {embedding_pad_seq_len_name});
177
+
178
+ embedding_op_desc.SetInput (" Mask" , {matched.at (" mask" )->arg ()->name });
179
+ // add dtype
180
+ embedding_op_desc.SetAttr <int >(
181
+ " mask_dtype" , static_cast <int >(VarDescAPI::VarDataType::INT64));
182
+ embedding_instruct->ResetOp (embedding_op_desc,
183
+ embedding_op->valid_places ());
184
+ encoder_instruct->ResetOp (encoder_op_desc, encoder_op->valid_places ());
185
+ DirectedLink (matched.at (" xpu_embedding" ), embedding_seq_lod_node);
186
+ DirectedLink (matched.at (" xpu_embedding" ), embedding_pad_seq_len_node);
187
+ DirectedLink (matched.at (" mask" ), matched.at (" xpu_embedding" ));
188
+ DirectedLink (embedding_seq_lod_node, matched.at (" xpu_encoder" ));
189
+ DirectedLink (embedding_pad_seq_len_node, matched.at (" xpu_encoder" ));
190
+ }
191
+
192
+ private:
193
+ bool pre_ln_;
194
+ };
195
+
196
+ } // namespace fusion
197
+
198
+ class XPUMultiEncoderAdaptiveSeqlenV2FusePass : public ProgramPass {
199
+ public:
200
+ void Apply (const std::unique_ptr<SSAGraph>& graph) override {
201
+ std::vector<bool > pre_lns{true , false };
202
+ for (auto pre_ln : pre_lns) {
203
+ fusion::XPUMultiEncoderAdaptiveSeqlenV2Fuser fuser (pre_ln);
204
+ fuser (graph.get ());
205
+ }
206
+ }
207
+ };
208
+
209
+ } // namespace mir
210
+ } // namespace lite
211
+ } // namespace paddle
212
+
213
+ REGISTER_MIR_PASS (__xpu__multi_encoder_adaptive_seqlen_v2_fuse_pass,
214
+ paddle::lite::mir::XPUMultiEncoderAdaptiveSeqlenV2FusePass)
215
+ .BindTargets({TARGET (kXPU )});
0 commit comments