Skip to content

Commit c794247

Browse files
authored
[XPU] Add adaptive_seqlen_v2_fuse_pass and add mask_type (#9710)
1 parent 4b337d8 commit c794247

File tree

6 files changed

+234
-2
lines changed

6 files changed

+234
-2
lines changed

lite/api/paddle_use_passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ USE_MIR_PASS(__xpu__conv2d_affine_channel_fuse_pass);
8787
USE_MIR_PASS(__xpu__conv2d_fuse_pass);
8888
USE_MIR_PASS(__xpu__softmax_topk_fuse_pass);
8989
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_fuse_pass);
90+
USE_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v2_fuse_pass);
9091
USE_MIR_PASS(__xpu__roformer_relative_pos_fuse_pass);
9192
USE_MIR_PASS(__xpu__multi_encoder_slice_link_fuse_pass);
9293
USE_MIR_PASS(__xpu__generate_sequence_fuse_pass);
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <memory>
16+
#include <string>
17+
18+
#include "lite/backends/xpu/math.h"
19+
#include "lite/core/optimizer/mir/pass_registry.h"
20+
#include "lite/core/optimizer/mir/pattern_matcher_high_api.h"
21+
22+
namespace paddle {
23+
namespace lite {
24+
namespace mir {
25+
namespace fusion {
26+
27+
/* support adaptive seq len for bert/ernie */
28+
/* in_Input in_Mask fill_constant */
29+
/* | \ / */
30+
/* | \ / */
31+
/* | | */
32+
/* xpu_embedding equal */
33+
/* | | */
34+
/* | | */
35+
/* layer_norm cast */
36+
/* | | */
37+
/* | scale */
38+
/* | / */
39+
/* | unsqueeze2 */
40+
/* | | */
41+
/* | / */
42+
/* | / */
43+
/* xpu_encoder */
44+
/* | */
45+
/* | */
46+
/* out_Output */
47+
/*---------------------------------------------------*/
48+
/* After the pass apply: */
49+
/* in_Input in_Mask */
50+
/* | | */
51+
/* | | */
52+
/* | / */
53+
/* xpu_embedding */
54+
/* | \ */
55+
/* | SeqLod */
56+
/* | | */
57+
/* layer_norm | */
58+
/* | | */
59+
/* | / */
60+
/* xpu_encoder */
61+
/* | */
62+
/* | */
63+
/* out_Output */
64+
/*---------------------------------------------------*/
65+
66+
class XPUMultiEncoderAdaptiveSeqlenV2Fuser : public FuseBase {
67+
public:
68+
explicit XPUMultiEncoderAdaptiveSeqlenV2Fuser(bool pre_ln = false)
69+
: pre_ln_(pre_ln) {}
70+
71+
void BuildPattern() override {
72+
auto* mask = VarNode("mask")->assert_is_op_input("equal", "X")->AsInput();
73+
auto* fill_constant =
74+
OpNode("fill_constant", "fill_constant")->AsIntermediate();
75+
// delete fill_constant_out
76+
auto* fill_constant_out = VarNode("fill_constant_out")
77+
->assert_is_op_output("fill_constant", "Out")
78+
->assert_is_op_input("equal", "Y")
79+
->AsIntermediate();
80+
auto* equal = OpNode("equal", "equal")->AsIntermediate();
81+
auto* equal_out = VarNode("equal_out")
82+
->assert_is_op_output("equal", "Out")
83+
->assert_is_op_input("cast", "X")
84+
->AsIntermediate();
85+
auto* cast = OpNode("cast", "cast")->AsIntermediate();
86+
auto* cast_out = VarNode("cast_out")
87+
->assert_is_op_output("cast", "Out")
88+
->assert_is_op_input("scale", "X")
89+
->AsIntermediate();
90+
auto* scale = OpNode("scale", "scale")->AsIntermediate();
91+
auto* scale_out = VarNode("scale_out")
92+
->assert_is_op_output("scale", "Out")
93+
->assert_is_op_input("unsqueeze2", "X")
94+
->AsIntermediate();
95+
auto* unsqueeze2 = OpNode("unsqueeze2", "unsqueeze2")->AsIntermediate();
96+
auto* unsqueeze2_out =
97+
VarNode("unsqueeze2_out")
98+
->assert_is_op_output("unsqueeze2", "Out")
99+
->assert_is_op_input("__xpu__multi_encoder", "Mask")
100+
->AsIntermediate();
101+
// delete unsqueeze2_out_xshape
102+
auto* unsqueeze2_out_xshape =
103+
VarNode("unsqueeze2_out_xshape")
104+
->assert_is_op_output("unsqueeze2", "XShape")
105+
->AsIntermediate();
106+
auto* xpu_embedding =
107+
OpNode("xpu_embedding", "__xpu__embedding_with_eltwise_add");
108+
109+
PMNode* embedding_out = nullptr;
110+
PMNode* layer_norm = nullptr;
111+
PMNode* layer_norm_out = nullptr;
112+
113+
if (pre_ln_) {
114+
embedding_out = VarNode("embedding_out")
115+
->assert_is_op_output(
116+
"__xpu__embedding_with_eltwise_add", "Output")
117+
->assert_is_op_input("__xpu__multi_encoder", "Input");
118+
} else {
119+
embedding_out = VarNode("embedding_out")
120+
->assert_is_op_output(
121+
"__xpu__embedding_with_eltwise_add", "Output")
122+
->assert_is_op_input("layer_norm", "X");
123+
layer_norm = OpNode("layer_norm", "layer_norm");
124+
layer_norm_out =
125+
VarNode("layer_norm_out")
126+
->assert_is_op_output("layer_norm", "Y")
127+
->assert_is_op_input("__xpu__multi_encoder", "Input");
128+
}
129+
auto* xpu_encoder = OpNode("xpu_encoder", "__xpu__multi_encoder")
130+
->assert_op_attr<bool>("adaptive_seqlen", true);
131+
if (pre_ln_) {
132+
xpu_encoder->assert_op_attr<bool>("norm_before", true);
133+
*xpu_embedding >> *embedding_out >> *xpu_encoder;
134+
} else {
135+
*xpu_embedding >> *embedding_out >> *layer_norm >> *layer_norm_out >>
136+
*xpu_encoder;
137+
}
138+
*mask >> *equal;
139+
*fill_constant >> *fill_constant_out >> *equal;
140+
*equal >> *equal_out >> *cast >> *cast_out >> *scale >> *scale_out >>
141+
*unsqueeze2 >> *unsqueeze2_out >> *xpu_encoder;
142+
*unsqueeze2 >> *unsqueeze2_out_xshape;
143+
}
144+
145+
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
146+
auto* embedding_instruct = matched.at("xpu_embedding")->stmt();
147+
auto embedding_op_desc = *embedding_instruct->mutable_op_info();
148+
auto embedding_op = embedding_instruct->op();
149+
auto* scope = embedding_op->scope();
150+
auto* encoder_instruct = matched.at("xpu_encoder")->stmt();
151+
auto encoder_op_desc = *encoder_instruct->mutable_op_info();
152+
auto encoder_op = encoder_instruct->op();
153+
154+
// add new arg seq_lod
155+
std::string embedding_out_name = matched.at("embedding_out")->arg()->name;
156+
std::string embedding_seq_lod_name = embedding_out_name + "_seq_lod";
157+
auto* embedding_seq_lod_node =
158+
graph->NewArgumentNode(embedding_seq_lod_name);
159+
embedding_seq_lod_node->arg()->type = LiteType::GetTensorTy(
160+
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kNCHW));
161+
scope->NewTensor(embedding_seq_lod_name);
162+
// add new arg pad_seq_len
163+
std::string embedding_pad_seq_len_name =
164+
embedding_out_name + "_pad_seq_len";
165+
auto* embedding_pad_seq_len_node =
166+
graph->NewArgumentNode(embedding_pad_seq_len_name);
167+
embedding_pad_seq_len_node->arg()->type = LiteType::GetTensorTy(
168+
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kNCHW));
169+
scope->NewTensor(embedding_pad_seq_len_name);
170+
171+
embedding_op_desc.SetOutput("SeqLod", {embedding_seq_lod_name});
172+
embedding_op_desc.SetOutput("PadSeqLen", {embedding_pad_seq_len_name});
173+
encoder_op_desc.SetInput("SeqLod", {embedding_seq_lod_name});
174+
encoder_op_desc.SetInput("PadSeqLen", {embedding_pad_seq_len_name});
175+
embedding_op_desc.SetInput("Mask", {matched.at("mask")->arg()->name});
176+
// add mask dtype
177+
embedding_op_desc.SetAttr<int>(
178+
"mask_dtype", static_cast<int>(VarDescAPI::VarDataType::INT64));
179+
embedding_instruct->ResetOp(embedding_op_desc,
180+
embedding_op->valid_places());
181+
encoder_instruct->ResetOp(encoder_op_desc, encoder_op->valid_places());
182+
DirectedLink(matched.at("xpu_embedding"), embedding_seq_lod_node);
183+
DirectedLink(matched.at("xpu_embedding"), embedding_pad_seq_len_node);
184+
DirectedLink(matched.at("mask"), matched.at("xpu_embedding"));
185+
DirectedLink(embedding_seq_lod_node, matched.at("xpu_encoder"));
186+
DirectedLink(embedding_pad_seq_len_node, matched.at("xpu_encoder"));
187+
}
188+
189+
private:
190+
bool pre_ln_;
191+
};
192+
193+
} // namespace fusion
194+
195+
class XPUMultiEncoderAdaptiveSeqlenV2FusePass : public ProgramPass {
196+
public:
197+
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
198+
std::vector<bool> pre_lns{true, false};
199+
for (auto pre_ln : pre_lns) {
200+
fusion::XPUMultiEncoderAdaptiveSeqlenV2Fuser fuser(pre_ln);
201+
fuser(graph.get());
202+
}
203+
}
204+
};
205+
206+
} // namespace mir
207+
} // namespace lite
208+
} // namespace paddle
209+
210+
REGISTER_MIR_PASS(__xpu__multi_encoder_adaptive_seqlen_v2_fuse_pass,
211+
paddle::lite::mir::XPUMultiEncoderAdaptiveSeqlenV2FusePass)
212+
.BindTargets({TARGET(kXPU)});

lite/core/optimizer/optimizer.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ std::unique_ptr<RuntimeProgram> RunDefaultOptimizer(
205205
"__xpu__fc_fuse_pass",
206206
"__xpu__softmax_topk_fuse_pass",
207207
"__xpu__multi_encoder_adaptive_seqlen_fuse_pass",
208+
"__xpu__multi_encoder_adaptive_seqlen_v2_fuse_pass",
208209
"__xpu__multi_encoder_slice_link_fuse_pass",
209210
"__xpu__generate_sequence_fuse_pass",
210211
"__xpu__logit_fuse_pass",

lite/kernels/xpu/__xpu__embedding_with_eltwise_add_compute.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,23 @@ void XPUEmbeddingWithEltwiseAddCompute::Run() {
6161
auto* seq_lod = param.SeqLod;
6262
seq_lod->Resize({batch_size + 1});
6363
std::vector<int> cpu_seq_lod{0};
64-
auto* mask_ptr = param.Mask->data<float>();
64+
65+
const void* mask_ptr = nullptr;
66+
if (param.mask_dtype == static_cast<int>(VarDescAPI::VarDataType::INT64)) {
67+
mask_ptr = param.Mask->data<int64_t>();
68+
} else {
69+
mask_ptr = param.Mask->data<float>();
70+
}
71+
6572
for (auto batch_idx = 0; batch_idx < batch_size; batch_idx++) {
6673
int cur_batch_seq_len = 0;
6774
for (auto seq_idx = 0; seq_idx < pad_seq_len; seq_idx++) {
68-
if (mask_ptr[batch_idx * pad_seq_len + seq_idx] > 1e-7) {
75+
if ((param.mask_dtype ==
76+
static_cast<int>(VarDescAPI::VarDataType::INT64) &&
77+
reinterpret_cast<const int64_t*>(
78+
mask_ptr)[batch_idx * pad_seq_len + seq_idx] > 0) ||
79+
reinterpret_cast<const float*>(
80+
mask_ptr)[batch_idx * pad_seq_len + seq_idx] > 1e-7) {
6981
cur_batch_seq_len += 1;
7082
} else {
7183
break;

lite/operators/__xpu__embedding_with_eltwise_add_op.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ bool XPUEmbeddingWithEltwiseAddOp::AttachImpl(const cpp::OpDesc& op_desc,
9292
}
9393
}
9494
}
95+
// find optional mask dtype
96+
if (op_desc.HasAttr("mask_dtype")) {
97+
param_.mask_dtype = op_desc.GetAttr<int>("mask_dtype");
98+
}
99+
95100
std::vector<std::string> output_arg_names = op_desc.OutputArgumentNames();
96101
if (std::find(output_arg_names.begin(), output_arg_names.end(), "SeqLod") !=
97102
output_arg_names.end()) {

lite/operators/op_params.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,6 +1765,7 @@ struct XPUEmbeddingWithEltwiseAddParam : ParamBase {
17651765
lite::Tensor* PadSeqLen{nullptr};
17661766
lite::Tensor* Out{nullptr};
17671767
int64_t padding_idx{-1};
1768+
int mask_dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
17681769
};
17691770

17701771
struct XPUFcParam : ParamBase {

0 commit comments

Comments
 (0)