
Commit d65a2b5

linwei210qfyinbd authored and committed
[xpu] multi_encoder supports no mask input, such as VIT (PaddlePaddle#9712)
1 parent 8b0643f commit d65a2b5

File tree

3 files changed: +116 −41 lines


lite/core/optimizer/mir/fusion/__xpu__multi_encoder_fuse_pass.cc

Lines changed: 70 additions & 40 deletions
@@ -48,18 +48,22 @@ class XPUSingleEncoderFuser : public FuseBase {
       const std::string& input_pos = "Y",
       const std::string& qkv_ln_2_out_pos = "Y",
       const std::string& matmul_type = "matmul",
+      const std::string& matmul2_type = "matmul_v2",
       const std::string& mul_type = "mul",
       bool with_q_scale = true,
       bool norm_before = false,
-      const std::string& relative_type = "")
+      const std::string& relative_type = "",
+      bool with_mask = true)
       : act_type_(act_type),
         input_pos_(input_pos),
         qkv_ln_2_out_pos_(qkv_ln_2_out_pos),
         matmul_type_(matmul_type),
+        matmul2_type_(matmul2_type),
         mul_type_(mul_type),
         with_q_scale_(with_q_scale),
         norm_before_(norm_before),
-        relative_emb_type_(relative_type) {}
+        relative_emb_type_(relative_type),
+        with_mask_(with_mask) {}
 
   void BuildPattern() override {
     auto* input = VarNode("input")
@@ -213,18 +217,25 @@ class XPUSingleEncoderFuser : public FuseBase {
                             ->AsIntermediate();
 
     auto* qk_matmul = OpNode("qk_matmul", matmul_type_)->AsIntermediate();
+    std::string op_after_qk_matmul = with_mask_ ? "elementwise_add" : "softmax";
     auto* qk_matmul_out = VarNode("qk_matmul_out")
                               ->assert_is_op_output(matmul_type_, "Out")
-                              ->assert_is_op_input("elementwise_add", "X")
+                              ->assert_is_op_input(op_after_qk_matmul, "X")
                               ->AsIntermediate();
-    auto* qk_mask = VarNode("qk_mask")
-                        ->assert_is_op_input("elementwise_add", "Y")
-                        ->AsInput();
-    auto* qk_add = OpNode("qk_add", "elementwise_add")->AsIntermediate();
-    auto* qk_add_out = VarNode("qk_add_out")
-                           ->assert_is_op_output("elementwise_add", "Out")
-                           ->assert_is_op_input("softmax", "X")
-                           ->AsIntermediate();
+    PMNode* qk_mask = nullptr;
+    PMNode* qk_add = nullptr;
+    PMNode* qk_add_out = nullptr;
+    if (with_mask_) {
+      qk_mask = VarNode("qk_mask")
+                    ->assert_is_op_input("elementwise_add", "Y")
+                    ->AsInput();
+      qk_add = OpNode("qk_add", "elementwise_add")->AsIntermediate();
+      qk_add_out = VarNode("qk_add_out")
+                       ->assert_is_op_output("elementwise_add", "Out")
+                       ->assert_is_op_input("softmax", "X")
+                       ->AsIntermediate();
+    }
+
     auto* qk_softmax = OpNode("qk_softmax", "softmax")->AsIntermediate();
     auto* qk_softmax_out = VarNode("qk_softmax_out")
                                ->assert_is_op_output("softmax", "Out")
@@ -256,16 +267,16 @@ class XPUSingleEncoderFuser : public FuseBase {
     auto* v_transpose2 = OpNode("v_transpose2", "transpose2")->AsIntermediate();
     auto* v_transpose2_out = VarNode("v_transpose2_out")
                                  ->assert_is_op_output("transpose2", "Out")
-                                 ->assert_is_op_input(matmul_type_, "Y")
+                                 ->assert_is_op_input(matmul2_type_, "Y")
                                  ->AsIntermediate();
     auto* v_transpose2_xshape =
         VarNode("v_transpose2_xshape")
             ->assert_is_op_output("transpose2", "XShape")
             ->AsIntermediate();
 
-    auto* qkv_matmul = OpNode("qkv_matmul", matmul_type_)->AsIntermediate();
+    auto* qkv_matmul = OpNode("qkv_matmul", matmul2_type_)->AsIntermediate();
     auto* qkv_matmul_out = VarNode("qkv_matmul_out")
-                               ->assert_is_op_output(matmul_type_, "Out")
+                               ->assert_is_op_output(matmul2_type_, "Out")
                                ->assert_is_op_input("transpose2", "X")
                                ->AsIntermediate();
     auto* qkv_transpose2 =
@@ -459,9 +470,14 @@ class XPUSingleEncoderFuser : public FuseBase {
     *k_reshape2 >> *k_reshape2_xshape;
     *k_transpose2 >> *k_transpose2_xshape;
 
-    *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >>
-        *qk_softmax_out >> *qkv_matmul;
-    *qk_mask >> *qk_add;
+    if (with_mask_) {
+      *qk_matmul >> *qk_matmul_out >> *qk_add >> *qk_add_out >> *qk_softmax >>
+          *qk_softmax_out >> *qkv_matmul;
+      *qk_mask >> *qk_add;
+    } else {
+      *qk_matmul >> *qk_matmul_out >> *qk_softmax >> *qk_softmax_out >>
+          *qkv_matmul;
+    }
 
     if (norm_before_) {
       *ln_before_out >> *v_mul;
@@ -513,7 +529,9 @@ class XPUSingleEncoderFuser : public FuseBase {
     cpp::OpDesc op_desc;
     op_desc.SetType("single_encoder");
     op_desc.SetInput("Inputs", {matched.at("input")->arg()->name});
-    op_desc.SetInput("Mask", {matched.at("qk_mask")->arg()->name});
+    if (with_mask_) {
+      op_desc.SetInput("Mask", {matched.at("qk_mask")->arg()->name});
+    }
     op_desc.SetInput("FCWeight",
                      {
                          matched.at("q_mul_y")->arg()->name,
@@ -645,7 +663,6 @@ class XPUSingleEncoderFuser : public FuseBase {
     single_encoder_stmt->SetOp(fake_subgraph_op);
 
     std::vector<std::string> froms = {
-        "qk_mask",
         "k_mul_y",
         "v_mul_y",
         "qkv_mul_y",
@@ -660,6 +677,9 @@ class XPUSingleEncoderFuser : public FuseBase {
         "qkv_ln_2_scale",
         "qkv_ln_2_bias",
     };
+    if (with_mask_) {
+      froms.push_back("qk_mask");
+    }
     if (relative_emb_type_ == "__xpu__roformer_relative_embedding") {
       froms.push_back("q_cos_embedding");
       froms.push_back("q_sin_embedding");
@@ -687,10 +707,12 @@ class XPUSingleEncoderFuser : public FuseBase {
   std::string input_pos_;
   std::string qkv_ln_2_out_pos_;
   std::string matmul_type_;
+  std::string matmul2_type_;
   std::string mul_type_;
   bool with_q_scale_;
   bool norm_before_;
   const std::string relative_emb_type_;
+  bool with_mask_;
   // quant_info: mul input_max, output_max * 6 + matmul x_max:y_max, output_max
   // * 2
   void set_quant_info(Scope* scope,
@@ -955,7 +977,7 @@ class XPUMultiEncoderFuser {
     std::string mask_name;
     for (auto* encoder : all_encoders) {
       auto* op_info = encoder->stmt()->op_info();
-      if (mask_name.empty()) {
+      if (mask_name.empty() && op_info->HasInput("Mask")) {
         mask_name = op_info->Input("Mask").front();
       } else {
         // CHECK(mask_name == op_info->Input("Mask").front());
@@ -1026,13 +1048,11 @@ class XPUMultiEncoderFuser {
       if (all_encoders.size() == 1) {
         // take care of only one encoder
         in_name = op_info->Input("Inputs").front();
-        mask_name = op_info->Input("Mask").front();
        out_name = op_info->Output("Outputs").front();
       } else if (i == 0) {
         // first encoder
         to_remove.insert(cur_out);
         in_name = op_info->Input("Inputs").front();
-        mask_name = op_info->Input("Mask").front();
       } else if (i == all_encoders.size() - 1) {
         // last encoder
         to_remove.insert(cur_encoder);
@@ -1059,7 +1079,9 @@ class XPUMultiEncoderFuser {
     for (auto kv : arg_map) {
       op_desc.SetInput(kv.first, kv.second);
     }
-    op_desc.SetInput("Mask", {mask_name});
+    if (!mask_name.empty()) {
+      op_desc.SetInput("Mask", {mask_name});
+    }
     op_desc.SetOutput("Output", {out_name});
     op_desc.SetAttr<int>("xpu", 1);
     op_desc.SetAttr<int>(
@@ -1404,9 +1426,11 @@ class XPUMultiEncoderFusePass : public ProgramPass {
     std::vector<std::string> input_poss{"X", "Y"};
     std::vector<std::string> qkv_ln_2_out_poss{"X", "Y"};
     std::vector<std::string> matmul_types{"matmul", "matmul_v2"};
+    std::vector<std::string> matmul2_types{"matmul", "matmul_v2"};
     std::vector<std::string> mul_types{"mul", "matmul", "matmul_v2"};
     std::vector<bool> with_q_scales{true, false};
     std::vector<bool> norm_befores{true, false};
+    std::vector<bool> with_mask{true, false};
     std::vector<std::string> relative_embedding_type{
         "", "__xpu__roformer_relative_embedding"};
 
@@ -1445,23 +1469,29 @@ class XPUMultiEncoderFusePass : public ProgramPass {
     for (auto& input_pos : input_poss) {
       for (auto& qkv_ln_2_out_pos : qkv_ln_2_out_poss) {
         for (auto& matmul_type : matmul_types) {
-          for (auto& mul_type : mul_types) {
-            for (auto with_q_scale : with_q_scales) {
-              for (auto norm_before : norm_befores) {
-                for (auto relative_type : relative_embedding_type) {
-                  fusion::XPUSingleEncoderFuser single_encoder_fuser(
-                      act_type,
-                      input_pos,
-                      qkv_ln_2_out_pos,
-                      matmul_type,
-                      mul_type,
-                      with_q_scale,
-                      norm_before,
-                      relative_type);
-                  single_encoder_fuser(graph.get());
-                  fusion::XPUMultiEncoderFuser multi_encoder_fuser(
-                      fc_precision, adaptive_seqlen);
-                  multi_encoder_fuser(graph.get());
+          for (auto& matmul2_type : matmul2_types) {
+            for (auto& mul_type : mul_types) {
+              for (auto with_q_scale : with_q_scales) {
+                for (auto norm_before : norm_befores) {
+                  for (auto relative_type : relative_embedding_type) {
+                    for (auto mask : with_mask) {
+                      fusion::XPUSingleEncoderFuser single_encoder_fuser(
+                          act_type,
+                          input_pos,
+                          qkv_ln_2_out_pos,
+                          matmul_type,
+                          matmul2_type,
+                          mul_type,
+                          with_q_scale,
+                          norm_before,
+                          relative_type,
+                          mask);
+                      single_encoder_fuser(graph.get());
+                      fusion::XPUMultiEncoderFuser multi_encoder_fuser(
+                          fc_precision, adaptive_seqlen);
+                      multi_encoder_fuser(graph.get());
+                    }
+                  }
                 }
               }
             }
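
Taken together, the new matmul2_type and with_mask parameters let the pattern describe encoders whose second (QKV) matmul uses a different op type than the QK matmul, and whose attention has no elementwise_add mask branch before the softmax, which is exactly the ViT case. Below is a minimal sketch of one such instantiation, mirroring the nested loops above; the concrete argument values are illustrative assumptions, and only the parameter order comes from the updated constructor.

// Hypothetical ViT-style configuration; every value is illustrative.
fusion::XPUSingleEncoderFuser vit_encoder_fuser(
    "gelu",       /* act_type */
    "Y",          /* input_pos */
    "Y",          /* qkv_ln_2_out_pos */
    "matmul_v2",  /* matmul_type: QK attention matmul */
    "matmul_v2",  /* matmul2_type: QKV context matmul */
    "matmul_v2",  /* mul_type */
    true,         /* with_q_scale */
    true,         /* norm_before: ViT encoders are pre-LayerNorm */
    "",           /* relative_type: no relative embedding */
    false);       /* with_mask: no attention mask input */
vit_encoder_fuser(graph.get());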

lite/core/optimizer/mir/fusion/__xpu__multi_encoder_slice_link_fuse_pass.cc

Lines changed: 2 additions & 1 deletion
@@ -57,7 +57,8 @@ class XPUMultiEncoderSliceLinkFuser : public FuseBase {
       layer_norm = OpNode("layer_norm", "layer_norm");
       layer_norm_out = VarNode("layer_norm_out")
                            ->assert_is_op_output("layer_norm", "Y")
-                           ->assert_is_op_input("slice", "Input");
+                           ->assert_is_op_input("slice", "Input")
+                           ->assert_only_one_output();
     } else {
       xpu_encoder->assert_op_attr<bool>("norm_before", false);
       encoder_out->assert_is_op_input("slice", "Input")->AsIntermediate();
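
The added assert_only_one_output() tightens the pre-norm branch: layer_norm_out may only be fused through when the slice op is its sole consumer, since rewriting the subgraph would otherwise orphan any other reader of that tensor. A minimal sketch of the invariant being asserted, assuming a mir-style node that exposes its consumer edges (the struct and field below are illustrative, not Lite's actual types):

#include <list>

// Illustrative stand-in for a graph node with consumer edges.
struct FakeNode {
  std::list<FakeNode*> outlinks;  // ops that read this tensor
};

// Fusing through layer_norm_out is safe only if the slice op is the
// single consumer; any extra consumer must veto the match.
bool OnlyOneOutput(const FakeNode& layer_norm_out) {
  return layer_norm_out.outlinks.size() == 1;
}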

lite/kernels/xpu/__xpu__multi_encoder_compute.cc

Lines changed: 44 additions & 0 deletions
@@ -255,6 +255,50 @@ void XPUMultiEncoderCompute::run_encoder(const T* in, T* out) {
         arg_ln_bias_,
         qkv_attn_param);
     CHECK_EQ(r, 0);
+  } else if (param.mask == nullptr) {
+    // When there is no mask input (e.g. ViT), create an LOD to act as vsl.
+    int batch = static_cast<int>(param.input->dims()[0]);
+    int max_seqlen = static_cast<int>(param.input->dims()[1]);
+    std::vector<int> lod;
+    for (int i = 0; i < batch + 1; i++) {
+      lod.push_back(i * max_seqlen);
+    }
+    query_lod = {lod.data(), static_cast<int>(lod.size()), nullptr};
+    // No need to pad, no matter slice or not
+    int max_pad_seqlen = -1;
+    xdnn::QKVAttnParam qkv_attn_param(query_lod, /* lod */
+                                      param.head_num,
+                                      param.size_per_head,
+                                      qkv_act,
+                                      slice_idx,
+                                      true /* qkv fusion */,
+                                      max_pad_seqlen,
+                                      param.hidden_dim,
+                                      param.norm_before, /* is_pre_norm */
+                                      param.per_channel);
+    qkv_attn_param.quant_type_.assign(quant_types_.begin(), quant_types_.end());
+    if (relative_type_ == 1) {
+      qkv_attn_param.relative_type = relative_type_;
+      qkv_attn_param.max_pos_len = param.max_pos_len;
+      qkv_attn_param.relative_pos.assign(roformer_embedding_.begin(),
+                                         roformer_embedding_.end());
+    }
+    qkv_attn_param.scale_of_hidden_units = param.ffn_hidden_dim_scale;
+    if (std::is_same<TGEMM, int8_t>::value) {
+      CHECK_GT(fc_input_max_.size(), 0);
+    }
+    int r = xdnn::transformer_encoder<T, TW, TGEMM>(
+        ctx.GetRawContext(),
+        in,
+        *(XPUMultiEncoderCompute::get_weight<TW>()),
+        out,
+        fc_input_max_,
+        fc_weight_max_,
+        arg_fc_bias_,
+        arg_ln_scale_,
+        arg_ln_bias_,
+        qkv_attn_param);
+    CHECK_EQ(r, 0);
   } else {
     // no vsl
     int batch = static_cast<int>(param.input->dims()[0]);
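
Because a mask-free model such as ViT runs every sequence at full length, the synthetic LOD above is just the cumulative offsets i * max_seqlen for i = 0 .. batch. A self-contained sketch of what the loop produces, with illustrative sizes (ViT-B/16 on a 224x224 input yields 196 patch tokens plus one class token per image):

#include <iostream>
#include <vector>

int main() {
  // Mirrors the kernel's LOD construction on the no-mask path.
  int batch = 2;         // illustrative batch size
  int max_seqlen = 197;  // 196 patches + 1 [CLS] token
  std::vector<int> lod;
  for (int i = 0; i < batch + 1; i++) {
    lod.push_back(i * max_seqlen);
  }
  // Row i of the batch spans tokens [lod[i], lod[i+1]) in the
  // flattened sequence dimension, so every row is full length.
  for (int v : lod) std::cout << v << " ";  // prints: 0 197 394
  std::cout << std::endl;
  return 0;
}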
