Skip to content

Commit 1794f42

Browse files
authored
[CINN] Implement FusionIters for multi-downstream fusion (#67829)
* refine shardable_axes * add axes info in pattern node * implement FusionIters and FuseItersForTrivialSink * implement SingleDownstreamItersFusion * update
1 parent 01be154 commit 1794f42

File tree

13 files changed

+511
-120
lines changed

13 files changed

+511
-120
lines changed

paddle/cinn/operator_fusion/graph_transformer/operation.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ struct MergeTrivialPatternOperation {
4444

4545
if (can_fuse) {
4646
auto merged_node = graph->MergeNode(upstream, downstream, MergePattern);
47+
merged_node->set_fusion_iters(
48+
SingleDownstreamItersFusion(upstream, downstream, true));
4749
graph->RemoveNode(downstream);
4850
VLOG(4) << "Spliting trivial pattern: \nupstream "
4951
<< upstream->DebugStr() << "\ndownstream "
@@ -71,6 +73,8 @@ struct MergeReduceTreeOperation {
7173
node->downstream().size()));
7274
auto downstream = node->downstream().at(0);
7375
auto merged_node = graph->MergeNode(node, downstream, MergePattern);
76+
merged_node->set_fusion_iters(
77+
SingleDownstreamItersFusion(node, downstream, true));
7478
graph->RemoveNode(downstream);
7579
graph->RemoveNode(node);
7680
VLOG(4) << "MergeReduceTreeOperation: \nupstream " << node->DebugStr()
@@ -103,6 +107,8 @@ struct MergeReduceTreeAndTrivialOperation {
103107
};
104108
PatternNodePtr merged_node =
105109
graph->MergeNode(node, downstream, merge_pattern_fn);
110+
merged_node->set_fusion_iters(
111+
SingleDownstreamItersFusion(node, downstream, false));
106112
graph->RemoveNode(downstream);
107113
graph->RemoveNode(node);
108114
VLOG(4) << "MergeReduceTreeAndTrivialOperation: \nupstream "

paddle/cinn/operator_fusion/pattern_graph.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ std::vector<PatternNodePtr> PatternGraph::ReturnFusionResults() {
7575
return sorted_nodes;
7676
}
7777

78-
std::vector<PatternNodePtr> PatternGraph::SortByTopoOrder() {
78+
std::vector<PatternNodePtr> PatternGraph::SortByTopoOrder() const {
7979
// sort all_pattern_nodes_ by topo order.
8080
std::vector<PatternNodePtr> res;
8181
std::list<PatternNodePtr> topo_queue;
@@ -100,7 +100,7 @@ std::vector<PatternNodePtr> PatternGraph::SortByTopoOrder() {
100100
return res;
101101
}
102102

103-
std::vector<PatternNodePtr> PatternGraph::SortByReverseTopoOrder() {
103+
std::vector<PatternNodePtr> PatternGraph::SortByReverseTopoOrder() const {
104104
// sort all_pattern_nodes_ by reverse topo order.
105105
std::vector<PatternNodePtr> res;
106106
std::list<PatternNodePtr> reverse_topo_queue;
@@ -215,7 +215,11 @@ PatternGraph::PatternGraph(const std::vector<PatternContent>& contents,
215215
}
216216

217217
for (const auto& content : contents) {
218-
PatternNodePtr node = std::make_shared<PatternNode>(content);
218+
const auto& axes =
219+
policy_manager_.template GetPolicy<RelativeJudgePolicy>()
220+
->GetAxesInfoManager()
221+
.GetModifiedSignature(content.op);
222+
PatternNodePtr node = std::make_shared<PatternNode>(content, axes);
219223
op_to_node_map[content.op] = node;
220224
all_pattern_nodes_.emplace(node);
221225
}
@@ -279,10 +283,12 @@ void PatternGraph::AppendNode(const PatternNodePtr& node) {
279283
std::string PatternGraph::GraphInfo() const {
280284
std::stringstream ss;
281285
ss << "\n========= GraphInfo ===========";
282-
for (const auto& v : all_pattern_nodes_) {
286+
for (const auto& v : SortByTopoOrder()) {
287+
ss << "\n##############################";
283288
ss << "\n" << v->DebugStr();
284289
ss << " IsOutput: " << IsOutputNodeMatcher()(*this, v);
285290
ss << "\n Loop Framework is: " << GetLoopFramework(v->stmt_pattern());
291+
ss << std::endl;
286292
}
287293
ss << "\n===============================";
288294
return ss.str();

paddle/cinn/operator_fusion/pattern_graph.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ class PatternGraph {
4848
PatternNodePtr MergeNode(const PatternNodePtr& upstream,
4949
const PatternNodePtr& downstream,
5050
MergePatternFn merge_pattern_fn);
51-
std::vector<PatternNodePtr> SortByTopoOrder();
52-
std::vector<PatternNodePtr> SortByReverseTopoOrder();
51+
std::vector<PatternNodePtr> SortByTopoOrder() const;
52+
std::vector<PatternNodePtr> SortByReverseTopoOrder() const;
5353

5454
const PatternNodePtrSet& all_pattern_nodes() const {
5555
return all_pattern_nodes_;

paddle/cinn/operator_fusion/pattern_node.h

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/cinn/operator_fusion/pattern.h"
1818
#include "paddle/cinn/operator_fusion/pattern_fuser.h"
19+
#include "paddle/cinn/operator_fusion/pir_graph_analyzing/fusion_iters.h"
1920
#include "paddle/cinn/operator_fusion/utils.h"
2021

2122
namespace cinn::fusion {
@@ -25,8 +26,11 @@ struct PatternNode {
2526
using MergePatternFn =
2627
std::function<StmtPattern(const StmtPattern&, const StmtPattern&)>;
2728

28-
explicit PatternNode(const PatternContent& content)
29-
: sink_op_(content.op), stmt_pattern_(ConvertToStmtPattern(content)) {}
29+
explicit PatternNode(const PatternContent& content,
30+
const ShardableAxesSignature& axes)
31+
: sink_op_(content.op),
32+
stmt_pattern_(ConvertToStmtPattern(content)),
33+
fusion_iters_(FusionItersSignature(content.op, axes)) {}
3034

3135
explicit PatternNode(PatternNodePtr fused_up_node,
3236
PatternNodePtr fused_down_node,
@@ -47,24 +51,25 @@ struct PatternNode {
4751
std::string DebugStr() const {
4852
std::stringstream ss;
4953
ss << "Node: " << this << ", Pattern: " << GetPatternName(stmt_pattern())
50-
<< ", ID: " << GetPatternId(stmt_pattern()) << "\n -u>: ";
54+
<< ", ID: " << GetPatternId(stmt_pattern());
55+
ss << "\n -u>: ";
5156
for (const auto& u : upstream_) {
52-
ss << u << ", ";
57+
ss << GetPatternId(u->stmt_pattern()) << "(" << u << "), ";
5358
}
5459
ss << "\n <d-: ";
5560
for (const auto& d : downstream_) {
56-
ss << d << ", ";
61+
ss << GetPatternId(d->stmt_pattern()) << "(" << d << "), ";
5762
}
63+
ss << "\n" << fusion_iters_.DebugStr();
5864
pir::IrPrinter printer(ss);
5965
if (GetPatternName(stmt_pattern_) == AnchorPattern::name()) {
6066
ss << "\n anchor: ";
6167
auto anchor_op =
6268
std::get<AnchorPattern>(stmt_pattern_).anchor().defining_op();
6369
printer.PrintOperation(const_cast<pir::Operation*>(anchor_op));
6470
}
65-
ss << "\nOps in pattern: \n################" << std::endl;
71+
ss << "\nOps in pattern:" << std::endl;
6672
ss << OpsDebugStr(GetOpsInPattern(this->stmt_pattern()));
67-
ss << "################" << std::endl;
6873
return ss.str();
6974
}
7075

@@ -92,13 +97,20 @@ struct PatternNode {
9297
GetFusionTracker(stmt_pattern_)->append(instr);
9398
}
9499
void UpdateTracker() { PatternUpdateTracker(stmt_pattern_); }
100+
FusionItersSignature fusion_iters() const { return fusion_iters_; }
101+
void set_fusion_iters(const FusionItersSignature& fusion_iters) {
102+
fusion_iters_ = fusion_iters;
103+
VLOG(4) << "set_fusion_iters";
104+
}
95105

96106
private:
97107
StmtPattern stmt_pattern_;
98108
pir::Operation* sink_op_;
99109

100110
std::vector<PatternNodePtr> upstream_;
101111
std::vector<PatternNodePtr> downstream_;
112+
113+
FusionItersSignature fusion_iters_;
102114
};
103115

104116
using PatternNodePtr = std::shared_ptr<PatternNode>;
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
gather_srcs(fusion_graph_analyzing SRCS shardable_axes_base.cc dim_relation.cc
2-
anchor_transform.cc)
2+
anchor_transform.cc fusion_iters.cc)
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/cinn/operator_fusion/pir_graph_analyzing/fusion_iters.h"
16+
#include "paddle/cinn/operator_fusion/pattern_node.h"
17+
18+
namespace cinn::fusion {
19+
20+
// Builds the iteration signature of a single op from its shardable-axes
// signature.
//
// @param op    Operation whose operand sources / results are recorded as the
//              signature's input / output values.
// @param axes  Axes signature of `op`; its loop, input and output axis names
//              are copied into loop_iters, input_iters and output_iters.
//
// Enforces (PADDLE_ENFORCE_EQ) that `axes` has one entry per operand and one
// entry per result of `op`.
FusionItersSignature::FusionItersSignature(pir::Operation* op,
                                           const ShardableAxesSignature& axes) {
  PADDLE_ENFORCE_EQ(
      axes.inputs.size(),
      op->num_operands(),
      ::common::errors::InvalidArgument("The number of input_iters should be "
                                        "equal to the number of operands."));
  PADDLE_ENFORCE_EQ(
      axes.outputs.size(),
      op->num_results(),
      ::common::errors::InvalidArgument("The number of output_iters should be "
                                        "equal to the number of results."));

  // Appends the axis names of every entry in `axes_list` to `dst`, in order.
  const auto append_axis_names = [](const auto& axes_list,
                                    std::vector<FusionIters>* dst) {
    for (const auto& entry : axes_list) {
      dst->push_back(entry.axis_names);
    }
  };

  input_values = op->operands_source();
  output_values = op->results();

  loop_iters = axes.loop.axis_names;
  append_axis_names(axes.inputs, &input_iters);
  append_axis_names(axes.outputs, &output_iters);
}
42+
43+
std::string PrintFusionIters(const FusionIters& iters) {
44+
std::stringstream ss;
45+
ss << "[ ";
46+
for (const auto& iter : iters) {
47+
ss << iter << ",";
48+
}
49+
return ss.str().substr(0, ss.str().size() - 1) + " ]";
50+
}
51+
52+
std::string FusionItersSignature::DebugStr() const {
53+
std::stringstream ss;
54+
ss << "FusionIters Signature:";
55+
ss << "\n loop : " << PrintFusionIters(loop_iters);
56+
for (size_t i = 0; i < input_iters.size(); ++i) {
57+
ss << "\n input " << i << ": " << PrintFusionIters(input_iters[i]);
58+
}
59+
for (size_t i = 0; i < output_iters.size(); ++i) {
60+
ss << "\n output " << i << ": " << PrintFusionIters(output_iters[i]);
61+
}
62+
return ss.str();
63+
}
64+
65+
// Computes the iteration signature of the node obtained by merging `upstream`
// with `downstream`.
//
// @param upstream    Producer node; enforced to have exactly one output.
// @param downstream  Consumer node.
// @param is_sink     When true the fused loop iters are taken from
//                    `downstream`, otherwise from `upstream`.
// @return A signature whose
//   - outputs (values and iters) are those of `downstream`;
//   - inputs are `downstream`'s inputs in their original order, except that
//     every input equal to `upstream`'s output value is replaced by the full
//     list of `upstream`'s inputs.
FusionItersSignature SingleDownstreamItersFusion(PatternNodePtr upstream,
                                                 PatternNodePtr downstream,
                                                 bool is_sink) {
  VLOG(4) << "[ItersFusion] Start SingleDownstreamItersFusion.";
  const auto upstream_iters = upstream->fusion_iters();
  const auto downstream_iters = downstream->fusion_iters();
  PADDLE_ENFORCE_EQ(upstream_iters.output_iters.size(),
                    1,
                    ::common::errors::InvalidArgument(
                        "The number of upstream outputs should be 1."));
  // output_values[0] is read below, so the value list must be validated as
  // well, not only the iters list.
  PADDLE_ENFORCE_EQ(upstream_iters.output_values.size(),
                    upstream_iters.output_iters.size(),
                    ::common::errors::InvalidArgument(
                        "The number of upstream output values should be equal "
                        "to the number of upstream output iters."));

  FusionItersSignature fused_iters;
  fused_iters.loop_iters =
      is_sink ? downstream_iters.loop_iters : upstream_iters.loop_iters;
  fused_iters.output_values = downstream_iters.output_values;
  fused_iters.output_iters = downstream_iters.output_iters;

  const auto& upstream_output_value = upstream_iters.output_values[0];
  for (size_t i = 0; i < downstream_iters.input_values.size(); ++i) {
    if (downstream_iters.input_values[i] == upstream_output_value) {
      // The edge upstream -> downstream becomes internal to the fused node:
      // splice in upstream's own inputs in place of it.
      fused_iters.input_iters.insert(fused_iters.input_iters.end(),
                                     upstream_iters.input_iters.begin(),
                                     upstream_iters.input_iters.end());
      fused_iters.input_values.insert(fused_iters.input_values.end(),
                                      upstream_iters.input_values.begin(),
                                      upstream_iters.input_values.end());
    } else {
      fused_iters.input_iters.push_back(downstream_iters.input_iters[i]);
      fused_iters.input_values.push_back(downstream_iters.input_values[i]);
    }
  }
  VLOG(4) << "[ItersFusion] End SingleDownstreamItersFusion.";
  return fused_iters;
}
97+
98+
} // namespace cinn::fusion
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/cinn/operator_fusion/pir_graph_analyzing/shardable_axes_base.h"
18+
19+
namespace cinn::fusion {
20+
21+
using FusionIters = std::vector<std::string>;
22+
struct FusionItersSignature {
23+
FusionItersSignature() = default;
24+
FusionItersSignature(pir::Operation* op, const ShardableAxesSignature& axes);
25+
std::string DebugStr() const;
26+
27+
FusionIters loop_iters = {};
28+
std::vector<FusionIters> input_iters = {};
29+
std::vector<FusionIters> output_iters = {};
30+
std::vector<pir::Value> input_values = {};
31+
std::vector<pir::Value> output_values = {};
32+
};
33+
34+
class PatternNode;
35+
using PatternNodePtr = std::shared_ptr<PatternNode>;
36+
FusionItersSignature SingleDownstreamItersFusion(PatternNodePtr upstream,
37+
PatternNodePtr downstream,
38+
bool is_sink);
39+
40+
} // namespace cinn::fusion

0 commit comments

Comments
 (0)