Skip to content

Commit dde77c2

Browse files
committed
[xpu] add c++ test of fused_multi_transformer_xpu_quant_pass
1 parent 28c5b29 commit dde77c2

File tree

7 files changed

+229
-53
lines changed

7 files changed

+229
-53
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ if(WITH_XPU)
235235
pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
236236
pass_library(delete_isolated_node_pass inference DIR xpu DEPS
237237
${XPU_PASS_DEPS})
238-
pass_library(fused_multi_transformer_quant_pass inference DIR xpu DEPS
238+
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
239239
${XPU_PASS_DEPS})
240240
endif()
241241

@@ -495,4 +495,8 @@ if(WITH_XPU)
495495
test_delete_isolated_node_pass
496496
SRCS xpu/delete_isolated_node_pass_test.cc
497497
DEPS delete_isolated_node_pass)
498+
cc_test(
499+
test_fused_multi_transformer_xpu_quant_pass
500+
SRCS xpu/fused_multi_transformer_xpu_quant_pass_tester.cc
501+
DEPS fused_multi_transformer_xpu_quant_pass)
498502
endif()

paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass_tester.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
7575
1,
7676
{2, -1, 16, 1024, 64},
7777
0);
78-
auto* out = layers.fused_multi_transformer(x,
78+
auto outs = layers.fused_multi_transformer(x,
7979
cache_kv,
8080
src_mask,
8181
qkv_w,
@@ -93,7 +93,7 @@ TEST(FuseMultiTransformerLayerPass, encoder_fp) {
9393
0.1,
9494
1e-12);
9595

96-
x = out;
96+
x = outs[0];
9797
}
9898
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
9999
graph->Set("__param_scope__", CreateParamScope());
@@ -126,7 +126,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
126126
for (int i = 0; i < num_layers; ++i) {
127127
auto* shape_out = layers.shape(src_mask);
128128
auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4});
129-
auto* out = layers.fused_multi_transformer(x,
129+
auto outs = layers.fused_multi_transformer(x,
130130
cache_kv,
131131
src_mask,
132132
qkv_w,
@@ -145,7 +145,7 @@ TEST(FuseMultiTransformerLayerPass, decoder_fp) {
145145
1e-12,
146146
time_stamp);
147147

148-
x = out;
148+
x = outs[0];
149149
}
150150
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
151151
auto param_scope = CreateParamScope();

paddle/fluid/framework/ir/pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
4949
"fuse_multi_transformer_layer_pass",
5050
"delete_quant_dequant_linear_op_pass",
5151
"delete_weight_dequant_linear_op_pass",
52-
"fused_multi_transformer_quant_pass",
52+
"fused_multi_transformer_xpu_quant_pass",
5353
"fc_xpu_fuse_pass",
5454
"delete_op_device_pass"};
5555

paddle/fluid/framework/ir/pass_tester_helper.h

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -571,33 +571,35 @@ struct Layers {
571571
return out;
572572
}
573573

574-
VarDesc* fused_multi_transformer(VarDesc* x,
575-
VarDesc* cache_kv,
576-
VarDesc* src_mask,
577-
VarDesc* qkv_w,
578-
VarDesc* qkv_bias,
579-
VarDesc* out_linear_w,
580-
VarDesc* out_linear_bias,
581-
VarDesc* ffn1_w,
582-
VarDesc* ffn1_bias,
583-
VarDesc* ffn2_w,
584-
VarDesc* ffn2_bias,
585-
VarDesc* ln_scale,
586-
VarDesc* ln_bias,
587-
VarDesc* ffn_ln_scale,
588-
VarDesc* ffn_ln_bias,
589-
float epsilon,
590-
float dropout_rate,
591-
VarDesc* time_stamp = nullptr,
592-
VarDesc* qkv_out_scale = nullptr,
593-
VarDesc* out_linear_out_scale = nullptr,
594-
VarDesc* ffn1_out_scale = nullptr,
595-
VarDesc* ffn2_out_scale = nullptr,
596-
std::vector<float> qkv_in_scale = {},
597-
std::vector<float> out_linear_in_scale = {},
598-
std::vector<float> ffn1_in_scale = {},
599-
std::vector<float> ffn2_in_scale = {}) {
574+
std::vector<VarDesc*> fused_multi_transformer(
575+
VarDesc* x,
576+
VarDesc* cache_kv,
577+
VarDesc* src_mask,
578+
VarDesc* qkv_w,
579+
VarDesc* qkv_bias,
580+
VarDesc* out_linear_w,
581+
VarDesc* out_linear_bias,
582+
VarDesc* ffn1_w,
583+
VarDesc* ffn1_bias,
584+
VarDesc* ffn2_w,
585+
VarDesc* ffn2_bias,
586+
VarDesc* ln_scale,
587+
VarDesc* ln_bias,
588+
VarDesc* ffn_ln_scale,
589+
VarDesc* ffn_ln_bias,
590+
float epsilon,
591+
float dropout_rate,
592+
VarDesc* time_stamp = nullptr,
593+
VarDesc* qkv_out_scale = nullptr,
594+
VarDesc* out_linear_out_scale = nullptr,
595+
VarDesc* ffn1_out_scale = nullptr,
596+
VarDesc* ffn2_out_scale = nullptr,
597+
std::vector<float> qkv_in_scale = {},
598+
std::vector<float> out_linear_in_scale = {},
599+
std::vector<float> ffn1_in_scale = {},
600+
std::vector<float> ffn2_in_scale = {}) {
600601
VarDesc* out = lod_tensor(unique_name());
602+
VarDesc* cache_kv_out = lod_tensor(unique_name());
601603
OpDesc* op = program_.MutableBlock(0)->AppendOp();
602604
std::string op_type = qkv_out_scale ? "fused_multi_transformer_int8"
603605
: "fused_multi_transformer";
@@ -623,6 +625,7 @@ struct Layers {
623625
op->SetAttr("dropout_rate", dropout_rate);
624626
op->SetAttr("epsilon", epsilon);
625627
op->SetOutput("Out", {out->Name()});
628+
op->SetOutput("CacheKVOut", {cache_kv_out->Name()});
626629

627630
if (time_stamp) {
628631
op->SetInput("TimeStep", {time_stamp->Name()});
@@ -638,7 +641,8 @@ struct Layers {
638641
op->SetAttr("ffn1_in_scale", ffn1_in_scale);
639642
op->SetAttr("ffn2_in_scale", ffn2_in_scale);
640643
}
641-
return out;
644+
std::vector<VarDesc*> outs = {out, cache_kv_out};
645+
return outs;
642646
}
643647

644648
VarDesc* dequantize_linear(VarDesc* x,

paddle/fluid/framework/ir/xpu/fused_multi_transformer_quant_pass.cc renamed to paddle/fluid/framework/ir/xpu/fused_multi_transformer_xpu_quant_pass.cc

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ FusedMultiTransformerPattern::FusedMultiTransformerPattern(
250250
1. transpose and quantify the weights of fused_multi_transformer op from fp32 to
251251
int16
252252
*/
253-
class FusedMultiTransformerQuantPass : public FusePassBase {
253+
class FusedMultiTransformerXPUQuantPass : public FusePassBase {
254254
protected:
255255
void ApplyImpl(ir::Graph* graph) const override;
256256

@@ -263,32 +263,30 @@ class FusedMultiTransformerQuantPass : public FusePassBase {
263263
bool with_seq_lengths,
264264
bool with_src_mask) const;
265265

266-
const std::string name_scope_{"fused_multi_transformer_quant_pass"};
266+
const std::string name_scope_{"fused_multi_transformer_xpu_quant_pass"};
267267
};
268268

269-
void FusedMultiTransformerQuantPass::ApplyImpl(ir::Graph* graph) const {
269+
void FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph) const {
270270
PADDLE_ENFORCE_NOT_NULL(
271271
graph, platform::errors::PreconditionNotMet("graph should not be null."));
272272
Init(name_scope_, graph);
273-
VLOG(3) << "DEBUG: in FusedMultiTransformerQuantPass::ApplyImpl";
273+
VLOG(3) << "in FusedMultiTransformerXPUQuantPass::ApplyImpl";
274274

275275
int found_subgraph_count = 0;
276-
for (bool with_cache_kv : {true, false}) {
277-
for (bool with_time_step : {true, false}) {
278-
found_subgraph_count += ApplyImpl(
279-
graph, with_cache_kv, false, false, with_time_step, false, true);
280-
}
276+
for (bool with_time_step : {true, false}) {
277+
found_subgraph_count +=
278+
ApplyImpl(graph, true, false, false, with_time_step, false, true);
281279
}
282280
AddStatis(found_subgraph_count);
283281
}
284282

285-
int FusedMultiTransformerQuantPass::ApplyImpl(ir::Graph* graph,
286-
bool with_cache_kv,
287-
bool with_pre_caches,
288-
bool with_rotary_pos_emb,
289-
bool with_time_step,
290-
bool with_seq_lengths,
291-
bool with_src_mask) const {
283+
int FusedMultiTransformerXPUQuantPass::ApplyImpl(ir::Graph* graph,
284+
bool with_cache_kv,
285+
bool with_pre_caches,
286+
bool with_rotary_pos_emb,
287+
bool with_time_step,
288+
bool with_seq_lengths,
289+
bool with_src_mask) const {
292290
GraphPatternDetector gpd;
293291
patterns::FusedMultiTransformerPattern pattern(gpd.mutable_pattern(),
294292
name_scope_,
@@ -302,7 +300,7 @@ int FusedMultiTransformerQuantPass::ApplyImpl(ir::Graph* graph,
302300
int found_subgraph_count = 0;
303301
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
304302
Graph* graph) {
305-
VLOG(4) << "handle FusedMultiTransformerQuantPass fuse";
303+
VLOG(4) << "handle FusedMultiTransformerXPUQuantPass fuse";
306304

307305
GET_IR_NODE(x);
308306
GET_IR_NODE(ln_scale);
@@ -544,5 +542,5 @@ int FusedMultiTransformerQuantPass::ApplyImpl(ir::Graph* graph,
544542
} // namespace framework
545543
} // namespace paddle
546544

547-
REGISTER_PASS(fused_multi_transformer_quant_pass,
548-
paddle::framework::ir::FusedMultiTransformerQuantPass);
545+
REGISTER_PASS(fused_multi_transformer_xpu_quant_pass,
546+
paddle::framework::ir::FusedMultiTransformerXPUQuantPass);
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include <gtest/gtest.h>
13+
14+
#include "paddle/fluid/framework/ir/pass.h"
15+
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
16+
17+
#define DEF_INPUT_DATA \
18+
Layers layers; \
19+
auto* x = layers.data("x", {1, 128, 1024}); \
20+
auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); \
21+
auto* ln_scale = layers.data("ln_scale", {1024}, true); \
22+
auto* ln_bias = layers.data("ln_bias", {1024}, true); \
23+
auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); \
24+
auto* qkv_bias = layers.data("qkv_bias", {3, 16, 64}, true); \
25+
auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); \
26+
auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); \
27+
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); \
28+
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); \
29+
auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); \
30+
auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); \
31+
auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); \
32+
auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true);
33+
34+
namespace paddle {
35+
namespace framework {
36+
namespace ir {
37+
38+
void AddVarToScope(Scope* param_scope,
39+
const std::string& name,
40+
const DDim& dims) {
41+
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
42+
tensor->Resize(dims);
43+
tensor->mutable_data<float>(platform::CPUPlace());
44+
}
45+
46+
Scope* CreateParamScope() {
47+
auto param_scope = new Scope();
48+
AddVarToScope(param_scope, "ln_scale", {1024});
49+
AddVarToScope(param_scope, "ln_bias", {1024});
50+
AddVarToScope(param_scope, "ffn_ln_scale", {1024});
51+
AddVarToScope(param_scope, "ffn_ln_bias", {1024});
52+
53+
AddVarToScope(param_scope, "qkv_w", {3, 16, 64, 1024});
54+
AddVarToScope(param_scope, "out_linear_w", {1024, 1024});
55+
AddVarToScope(param_scope, "ffn1_w", {1024, 4096});
56+
AddVarToScope(param_scope, "ffn2_w", {4096, 1024});
57+
AddVarToScope(param_scope, "qkv_bias", {3072});
58+
AddVarToScope(param_scope, "out_linear_bias", {1024});
59+
AddVarToScope(param_scope, "ffn1_bias", {4096});
60+
AddVarToScope(param_scope, "ffn2_bias", {1024});
61+
62+
return param_scope;
63+
}
64+
65+
TEST(FusedMultiTransformerXPUQuantPass, context_stage) {
66+
DEF_INPUT_DATA
67+
68+
auto* cache_kv = layers.fill_constant_batch_size_like(
69+
x,
70+
static_cast<int>(proto::VarType::FP32),
71+
0,
72+
1,
73+
{2, -1, 16, 1024, 64},
74+
0);
75+
76+
layers.fused_multi_transformer(x,
77+
cache_kv,
78+
src_mask,
79+
qkv_w,
80+
qkv_bias,
81+
out_linear_w,
82+
out_linear_bias,
83+
ffn1_w,
84+
ffn1_bias,
85+
ffn2_w,
86+
ffn2_bias,
87+
ln_scale,
88+
ln_bias,
89+
ffn_ln_scale,
90+
ffn_ln_bias,
91+
0.1,
92+
1e-12);
93+
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
94+
graph->Set("__param_scope__", CreateParamScope());
95+
96+
auto pass =
97+
PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass");
98+
if (pass.get() == nullptr) {
99+
LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed";
100+
}
101+
102+
graph.reset(pass->Apply(graph.release()));
103+
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer_xpu");
104+
VLOG(3) << DebugString(graph);
105+
106+
PADDLE_ENFORCE_EQ(
107+
num_nodes_after,
108+
1,
109+
platform::errors::InvalidArgument(
110+
"After the fuse_multi_transformer_layer_pass, "
111+
"The node num in graph should be 1, but the result is %d",
112+
num_nodes_after));
113+
}
114+
115+
TEST(FusedMultiTransformerXPUQuantPass, decoder_stage) {
116+
DEF_INPUT_DATA
117+
118+
auto* cache_kv = layers.fill_constant_batch_size_like(
119+
x,
120+
static_cast<int>(proto::VarType::FP32),
121+
0,
122+
1,
123+
{2, -1, 16, 1024, 64},
124+
0);
125+
auto* time_step = layers.data("time_step", {1});
126+
layers.fused_multi_transformer(x,
127+
cache_kv,
128+
src_mask,
129+
qkv_w,
130+
qkv_bias,
131+
out_linear_w,
132+
out_linear_bias,
133+
ffn1_w,
134+
ffn1_bias,
135+
ffn2_w,
136+
ffn2_bias,
137+
ln_scale,
138+
ln_bias,
139+
ffn_ln_scale,
140+
ffn_ln_bias,
141+
0.1,
142+
1e-12,
143+
time_step);
144+
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
145+
graph->Set("__param_scope__", CreateParamScope());
146+
147+
auto pass =
148+
PassRegistry::Instance().Get("fused_multi_transformer_xpu_quant_pass");
149+
if (pass.get() == nullptr) {
150+
LOG(INFO) << "get fused_multi_transformer_xpu_quant_pass failed";
151+
}
152+
153+
graph.reset(pass->Apply(graph.release()));
154+
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer_xpu");
155+
VLOG(3) << DebugString(graph);
156+
157+
PADDLE_ENFORCE_EQ(
158+
num_nodes_after,
159+
1,
160+
platform::errors::InvalidArgument(
161+
"After the fuse_multi_transformer_layer_pass, "
162+
"The node num in graph should be 1, but the result is %d",
163+
num_nodes_after));
164+
}
165+
166+
} // namespace ir
167+
} // namespace framework
168+
} // namespace paddle
169+
170+
USE_PASS(fused_multi_transformer_xpu_quant_pass);

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
524524
"embedding_with_eltwise_add_xpu_fuse_pass",
525525
"multi_encoder_xpu_fuse_pass",
526526
"multi_encoder_xpu_slice_fuse_pass",
527-
"fused_multi_transformer_quant_pass",
527+
"fused_multi_transformer_xpu_quant_pass",
528528
"fc_xpu_fuse_pass",
529529
"link_xpu_op_max_pass",
530530
"delete_op_device_pass",

0 commit comments

Comments (0)