|
16 | 16 |
|
17 | 17 | #include <cmath> |
18 | 18 | #include <string> |
| 19 | +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" |
19 | 20 | #include "paddle/fluid/framework/op_proto_maker.h" |
20 | 21 |
|
21 | 22 | #include "paddle/fluid/framework/op_version_registry.h" |
@@ -67,6 +68,42 @@ MapMatmul2MulPass::MapMatmul2MulPass() { |
67 | 68 | .End(); |
68 | 69 | } |
69 | 70 |
|
| 71 | +MapMatmulv2ToMulPass::MapMatmulv2ToMulPass() { |
| 72 | + AddOpCompat(OpCompat("matmul_v2")) |
| 73 | + .AddInput("X") |
| 74 | + .IsTensor() |
| 75 | + .End() |
| 76 | + .AddInput("Y") |
| 77 | + .IsTensor() |
| 78 | + .End() |
| 79 | + .AddOutput("Out") |
| 80 | + .IsTensor() |
| 81 | + .End() |
| 82 | + .AddAttr("trans_x") |
| 83 | + .IsBoolEQ(false) |
| 84 | + .End() |
| 85 | + .AddAttr("trans_y") |
| 86 | + .IsBoolEQ(false) |
| 87 | + .End(); |
| 88 | + |
| 89 | + AddOpCompat(OpCompat("mul")) |
| 90 | + .AddInput("X") |
| 91 | + .IsTensor() |
| 92 | + .End() |
| 93 | + .AddInput("Y") |
| 94 | + .IsTensor() |
| 95 | + .End() |
| 96 | + .AddOutput("Out") |
| 97 | + .IsTensor() |
| 98 | + .End() |
| 99 | + .AddAttr("x_num_col_dims") |
| 100 | + .IsNumGE(1) |
| 101 | + .End() |
| 102 | + .AddAttr("y_num_col_dims") |
| 103 | + .IsNumEQ(1) |
| 104 | + .End(); |
| 105 | +} |
| 106 | + |
70 | 107 | Flatten2MatmulFusePass::Flatten2MatmulFusePass() { |
71 | 108 | AddOpCompat(OpCompat("matmul")) |
72 | 109 | .AddInput("X") |
@@ -250,6 +287,75 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { |
250 | 287 | AddStatis(found_count); |
251 | 288 | } |
252 | 289 |
|
| 290 | +void MapMatmulv2ToMulPass::ApplyImpl(ir::Graph* graph) const { |
| 291 | + PADDLE_ENFORCE_NOT_NULL( |
| 292 | + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); |
| 293 | + std::string name_scope = "map_matmul_v2_to_mul_pass"; |
| 294 | + FusePassBase::Init(name_scope, graph); |
| 295 | + |
| 296 | + GraphPatternDetector gpd; |
| 297 | + patterns::MatmulV2 matmul_pattern(gpd.mutable_pattern(), name_scope); |
| 298 | + matmul_pattern(); |
| 299 | + |
| 300 | + int found_count = 0; |
| 301 | + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, |
| 302 | + Graph* g) { |
| 303 | + VLOG(4) << "map matmul_v2 to mul"; |
| 304 | + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern); |
| 305 | + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); |
| 306 | + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern); |
| 307 | + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern); |
| 308 | + bool flag = true; |
| 309 | + |
| 310 | + bool trans_x = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_x")); |
| 311 | + bool trans_y = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("trans_y")); |
| 312 | + flag = flag && !trans_x && !trans_y; |
| 313 | + |
| 314 | + std::vector<int64_t> x_shape = matmul_in_x->Var()->GetShape(); |
| 315 | + std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape(); |
| 316 | + size_t x_rank = x_shape.size(); |
| 317 | + size_t y_rank = y_shape.size(); |
| 318 | + flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2; |
| 319 | + |
| 320 | + std::vector<Node*>& next_ops = matmul_out->outputs; |
| 321 | + flag = flag && next_ops.size() == 1 && |
| 322 | + next_ops[0]->Name() == "elementwise_add"; |
| 323 | + |
| 324 | + if (flag) { |
| 325 | + if (!IsCompat(subgraph, g)) { |
| 326 | + LOG(WARNING) << "Pass in op compat failed."; |
| 327 | + return; |
| 328 | + } |
| 329 | + OpDesc desc(matmul_op->Op()->Block()); |
| 330 | + desc.SetType("mul"); |
| 331 | + desc.SetInput("X", {matmul_in_x->Name()}); |
| 332 | + desc.SetInput("Y", {matmul_in_y->Name()}); |
| 333 | + desc.SetOutput("Out", {matmul_out->Name()}); |
| 334 | + desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1)); |
| 335 | + desc.SetAttr("y_num_col_dims", 1); |
| 336 | + if (matmul_op->Op()->HasAttr("enable_int8")) { |
| 337 | + desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); |
| 338 | + desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); |
| 339 | + desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); |
| 340 | + } |
| 341 | + auto mul_node = g->CreateOpNode(&desc); |
| 342 | + IR_NODE_LINK_TO(matmul_in_x, mul_node); |
| 343 | + IR_NODE_LINK_TO(matmul_in_y, mul_node); |
| 344 | + IR_NODE_LINK_TO(mul_node, matmul_out); |
| 345 | + GraphSafeRemoveNodes(graph, {matmul_op}); |
| 346 | + ++found_count; |
| 347 | + |
| 348 | + if (!IsCompat(desc)) { |
| 349 | + LOG(WARNING) << "MapMatmulv2ToMulPass in out mul op compat failed."; |
| 350 | + return; |
| 351 | + } |
| 352 | + } |
| 353 | + }; |
| 354 | + |
| 355 | + gpd(graph, handler); |
| 356 | + AddStatis(found_count); |
| 357 | +} |
| 358 | + |
253 | 359 | void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { |
254 | 360 | PADDLE_ENFORCE_NOT_NULL( |
255 | 361 | graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); |
@@ -567,6 +673,14 @@ REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass) |
567 | 673 | .LE("matmul", 1) |
568 | 674 | .EQ("mul", 0)); |
569 | 675 |
|
| 676 | +REGISTER_PASS(map_matmul_v2_to_mul_pass, |
| 677 | + paddle::framework::ir::MapMatmulv2ToMulPass); |
| 678 | +REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass) |
| 679 | + .AddCombination( |
| 680 | + paddle::framework::compatible::OpVersionComparatorCombination() |
| 681 | + .EQ("matmul_v2", 0) |
| 682 | + .EQ("mul", 0)); |
| 683 | + |
570 | 684 | REGISTER_PASS(squeeze2_matmul_fuse_pass, |
571 | 685 | paddle::framework::ir::Squeeze2MatmulFusePass); |
572 | 686 | REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass) |
|
0 commit comments