Skip to content

Commit 0c34b1f

Browse files
authored
[CINN] Fix ReduceTree Fusion (#67400)
* update * update * update * update
1 parent c2d2b66 commit 0c34b1f

File tree

8 files changed

+108
-85
lines changed

8 files changed

+108
-85
lines changed

paddle/cinn/hlir/framework/pir/trivial_op_impl.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ std::vector<ir::Var> GetOutputIters(const FusibleOp& op) {
181181
}
182182
};
183183
VLOG(4) << "GetOutputIters";
184-
VLOG(4) << "Before AppendBound:" << _GetRootExpr(op);
185184
return AppendBound(std::visit(Visitor(), op), _GetRootExpr(op));
186185
}
187186

paddle/cinn/operator_fusion/fusion_tracker/expr_utils.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@
1717

1818
namespace cinn::fusion {
1919

20+
std::vector<ir::Expr> GetFusibleOpsExpr(std::vector<FusibleOp> fusion_ops) {
21+
std::vector<ir::Expr> exprs;
22+
for (auto& fusion_op : fusion_ops) {
23+
auto expr = std::visit(FusibleOp2Expr(), fusion_op).front();
24+
exprs.push_back(expr);
25+
}
26+
return exprs;
27+
}
28+
2029
// tmp transform for reduce_tree and reduce_tree_trivial.
2130
std::vector<ir::Tensor> GetOutputTensors(const ir::Expr& op_expr) {
2231
using cinn::hlir::framework::pir::trivial_fusion_detail::ExprSetFinderUtils::

paddle/cinn/operator_fusion/fusion_tracker/expr_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ struct ApplyTransform {
6060
ir::Expr expr_;
6161
};
6262

63+
std::vector<ir::Expr> GetFusibleOpsExpr(std::vector<FusibleOp> fusion_ops);
6364
std::vector<ir::Expr> TopoSort(const std::vector<ir::Expr>& op_exprs);
6465
std::vector<FusibleOp> DoPadding(const FusibleOp& fusion_op,
6566
const std::vector<int>& padding_pos);

paddle/cinn/operator_fusion/fusion_tracker/interpreter.cc

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ void RunCombineInstr(const std::shared_ptr<CombineInstr>& instr,
3333
const auto& to_insert = interpreter->scope.at(name);
3434
new_pattern->Extend(to_insert->fusion_ops);
3535
}
36+
VLOG(4) << "After CombineInstr Pattern: \n"
37+
<< GetFusibleOpsExpr(new_pattern->fusion_ops);
3638
interpreter->scope[instr->result_] = new_pattern;
3739
}
3840

@@ -69,13 +71,22 @@ void RunTrivialInlineInstr(const std::shared_ptr<TrivialInlineInstr>& instr,
6971

7072
void RunTmpTransformInstr(const std::shared_ptr<TmpTransformInstr>& instr,
7173
FusionInterpreter* interpreter) {
72-
VLOG(4) << interpreter->scope[instr->upstream_]->fusion_ops.size();
73-
PADDLE_ENFORCE_EQ(interpreter->scope[instr->downstream_]->fusion_ops.size(),
74-
1,
75-
::common::errors::InvalidArgument(
76-
"Downstream op must have only one fusion_op."));
74+
PADDLE_ENFORCE_GT(
75+
interpreter->scope.count(instr->upstream_),
76+
0,
77+
::common::errors::NotFound("Can not find TmpTransformInstr uptream."));
78+
PADDLE_ENFORCE_GT(
79+
interpreter->scope.count(instr->downstream_),
80+
0,
81+
::common::errors::NotFound("Can not find TmpTransformInstr downstream."));
82+
83+
PADDLE_ENFORCE_EQ(
84+
interpreter->scope[instr->downstream_]->fusion_ops.size(),
85+
1,
86+
::common::errors::InvalidArgument(
87+
"Downstream %s must have only one fusion_op.", instr->downstream_));
7788
auto upstream_op = std::get<ReduceOp>(
78-
interpreter->scope[instr->upstream_]->fusion_ops.front());
89+
interpreter->scope[instr->upstream_]->fusion_ops.back());
7990
auto downstream_op =
8091
interpreter->scope[instr->downstream_]->fusion_ops.front();
8192
// inplace set the upstream

paddle/cinn/operator_fusion/pattern.h

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ struct ReduceTreePattern {
8888
const FusionTrackerPtr& tracker)
8989
: childs_(childs), root_(root), tracker_(tracker) {
9090
id_ = UniqueId();
91+
cur_id_ = id_;
9192
}
9293
const ReducePattern& GetRootPattern() const { return root_; }
9394
std::vector<pir::Operation*> ops() const {
@@ -118,44 +119,48 @@ struct ReduceTreePattern {
118119
std::string id() const { return id_; }
119120
std::string id_;
120121

122+
mutable std::string cur_id_;
123+
std::string cur_id() const { return cur_id_; }
124+
// Advances this pattern to its next temporary id and returns it.
//
// The first call moves cur_id_ from id_ to id_ + "_tmp_0"; every later call
// increments the numeric suffix (id_ + "_tmp_1", id_ + "_tmp_2", ...).
// The method is const because cur_id_ is declared mutable.
//
// The whole suffix after id_ + "_tmp_" is parsed, not just the final
// character: reading only the last digit turned "_tmp_10" back into
// "_tmp_1" and so handed out an id that had already been used.
// If cur_id_ does not start with id_ + "_tmp_" (which includes the initial
// cur_id_ == id_ state) numbering restarts at 0.
// std::stoi throws if the suffix is not a number.
std::string new_tmp_id() const {
  const std::string tmp_prefix = id_ + "_tmp_";
  int next_index = 0;
  if (cur_id_.compare(0, tmp_prefix.size(), tmp_prefix) == 0) {
    next_index = std::stoi(cur_id_.substr(tmp_prefix.size())) + 1;
  }
  cur_id_ = tmp_prefix + std::to_string(next_index);
  return cur_id_;
}
133+
121134
FusionTrackerPtr tracker_;
122135

123136
void update_tracker() const {
124-
int counter = 0;
125-
std::function<std::string()> gen_name = [&counter]() {
126-
return "tmp_" + std::to_string(counter++);
127-
};
128-
const std::string& root_name = id();
137+
const std::string& root_name = GetRootPattern().id();
129138
std::vector<std::string> names;
130-
UpdateTrackerImpl(root_name,
131-
*this,
132-
std::vector<size_t>(),
133-
gen_name,
134-
this->tracker_,
135-
&names);
136-
tracker_->append(std::make_shared<CombineInstr>(names, root_name));
139+
UpdateTrackerImpl(
140+
root_name, *this, std::vector<size_t>(), this->tracker_, &names);
141+
tracker_->append(std::make_shared<CombineInstr>(names, cur_id()));
137142
}
138143

139144
void UpdateTrackerImpl(const std::string root_name,
140145
const ReduceTreePattern& root,
141146
const std::vector<size_t>& fake_reduce_iter_idx,
142-
const std::function<std::string()>& unique_tmp_name_fn,
143147
FusionTrackerPtr tracker,
144148
std::vector<std::string>* names) const {
145-
// Apply a bunch of tracker to get a output_name of ReduceTreePattern.
149+
// Apply a bunch of tracker instructions to get the output_name of ReduceTreePattern.
146150
// names and trackers collect all the needed fusion nodes.
147-
for (const auto& child : childs_) {
148-
const std::string& tmp_name = unique_tmp_name_fn();
149-
tracker->append(std::make_shared<TmpTransformInstr>(
150-
tmp_name, root_name, tmp_name, root_name, fake_reduce_iter_idx));
151-
UpdateTrackerImpl(tmp_name,
152-
child,
153-
fake_reduce_iter_idx,
154-
unique_tmp_name_fn,
155-
tracker,
156-
names);
151+
for (const auto& child : root.childs()) {
152+
auto origin_child_id = child.cur_id();
153+
auto new_child_id = child.new_tmp_id();
154+
tracker->append(
155+
std::make_shared<TmpTransformInstr>(origin_child_id,
156+
root_name,
157+
new_child_id,
158+
root.cur_id(),
159+
fake_reduce_iter_idx));
160+
UpdateTrackerImpl(
161+
new_child_id, child, fake_reduce_iter_idx, tracker, names);
157162
}
158-
names->push_back(root_name);
163+
names->push_back(root.cur_id());
159164
}
160165

161166
private:
@@ -190,29 +195,21 @@ struct ReduceTreePlusTrivialPattern {
190195
FusionTrackerPtr tracker_;
191196

192197
void update_tracker() const {
193-
int counter = 0;
194-
std::function<std::string()> gen_name = [&counter]() {
195-
return "tmp_" + std::to_string(counter++);
196-
};
197198
const std::string& root_name = id();
198-
const std::string& tmp_name_for_tree = gen_name();
199+
const std::string& origin_tree_id = tree.cur_id();
200+
const std::string& new_tree_id = tree.new_tmp_id();
199201
std::vector<std::string> names;
200-
tracker_->append(
201-
std::make_shared<TmpTransformInstr>(tree.GetRootPattern().id(),
202-
sink_trivial.id(),
203-
tmp_name_for_tree,
204-
root_name,
205-
fake_reduce_iter_idx));
206-
tree.UpdateTrackerImpl(tmp_name_for_tree,
207-
tree,
208-
fake_reduce_iter_idx,
209-
gen_name,
210-
this->tracker_,
211-
&names);
202+
tracker_->append(std::make_shared<TmpTransformInstr>(origin_tree_id,
203+
sink_trivial.id(),
204+
new_tree_id,
205+
root_name,
206+
fake_reduce_iter_idx));
207+
tree.UpdateTrackerImpl(
208+
new_tree_id, tree, fake_reduce_iter_idx, this->tracker_, &names);
212209
names.push_back(root_name);
213210
// optimize the loop range of R + T for speed up.
214211
tracker_->append(std::make_shared<TrivialLoopAlignInstr>(
215-
tmp_name_for_tree, root_name, root_name, fake_reduce_iter_idx));
212+
new_tree_id, root_name, root_name, fake_reduce_iter_idx));
216213
// collect all the Expr and represent the root_name.
217214
tracker_->append(std::make_shared<CombineInstr>(names, root_name));
218215
}

paddle/cinn/operator_fusion/pattern_fuser.h

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,23 @@ static StmtPattern MergePatternImpl(const TrivialPattern& first,
132132

133133
// RR & RT
134134

135-
static int InsertDownstreamIntoTree(const ReduceTreePattern& upstream,
136-
ReduceTreePattern& downstream) { // NOLINT
137-
if (IsDirectUpstream(upstream.GetRootPattern().GetReduceOp(),
138-
downstream.GetRootPattern().GetReduceOp())) {
135+
static int InsertUpstreamIntoTree(const ReduceTreePattern& upstream,
136+
ReduceTreePattern& downstream) { // NOLINT
137+
auto is_direct_upstream = [&](const ReducePattern& upstream,
138+
const ReducePattern& downstream) -> bool {
139+
auto upstream_result = upstream.GetReduceOp()->result(0);
140+
auto user_ops = FindUserOp(downstream.ops(), upstream_result);
141+
return !user_ops.empty();
142+
};
143+
144+
if (is_direct_upstream(upstream.GetRootPattern(),
145+
downstream.GetRootPattern())) {
139146
downstream.InsertChild(upstream);
140147
return 1;
141148
}
142149
int insert_num = 0;
143150
for (auto& child : downstream.childs()) {
144-
insert_num += InsertDownstreamIntoTree(upstream, child);
151+
insert_num += InsertUpstreamIntoTree(upstream, child);
145152
}
146153
return insert_num;
147154
}
@@ -153,7 +160,7 @@ static StmtPattern MergePatternImpl(const ReduceTreePattern& upstream,
153160
downstream.GetRootPattern(),
154161
std::make_shared<FusionTracker>(upstream.tracker_,
155162
downstream.tracker_)); // copy first.
156-
int insert_num = InsertDownstreamIntoTree(upstream, result);
163+
int insert_num = InsertUpstreamIntoTree(upstream, result);
157164
PADDLE_ENFORCE_EQ(insert_num,
158165
1,
159166
phi::errors::PreconditionNotMet(

paddle/cinn/operator_fusion/policy/relative_judge_policy.cc

Lines changed: 12 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -174,37 +174,17 @@ bool IsProductSmallerOrEqual(const std::vector<DimUsage>& first,
174174
return shape_analysis.IsEqual(first_product, second_product);
175175
}
176176

177-
pir::Operation* FindUserOp(const std::vector<pir::Operation*>& candidates,
178-
const pir::Value& value) {
179-
std::vector<pir::Operation*> results;
180-
for (auto consumer_it = value.use_begin(); consumer_it != value.use_end();
181-
++consumer_it) {
182-
pir::Operation* user_op = consumer_it.owner();
183-
auto iter = std::find(candidates.begin(), candidates.end(), user_op);
184-
if (iter != candidates.end()) {
185-
results.emplace_back(*iter);
186-
}
187-
}
188-
PADDLE_ENFORCE_EQ(results.size(),
189-
1,
190-
::common::errors::InvalidArgument(
191-
"Zero or multiple user operations found in candidates! "
192-
"Expected exactly one, but found %d.",
193-
results.size()));
194-
return results.front();
195-
}
196-
197177
bool RelativeJudgePolicy::ReduceTreeGrownCanMerge(
198178
const PatternNodePtr& upstream, const PatternNodePtr& downstream) {
199179
const auto& upstream_tree =
200180
std::get<ReduceTreePattern>(upstream->stmt_pattern());
201181
const auto& downstream_tree =
202182
std::get<ReduceTreePattern>(downstream->stmt_pattern());
203183

204-
VLOG(4) << "upstream->stmt_pattern():"
205-
<< OpsDebugStr(GetOpsInPattern(upstream_tree));
206-
VLOG(4) << "downstream->stmt_pattern()"
207-
<< OpsDebugStr(GetOpsInPattern(downstream_tree));
184+
VLOG(4) << "upstream: \n" << OpsDebugStr(GetOpsInPattern(upstream_tree));
185+
VLOG(4) << "upstream->childs_num: " << upstream_tree.childs().size();
186+
VLOG(4) << "downstream: \n" << OpsDebugStr(GetOpsInPattern(downstream_tree));
187+
VLOG(4) << "downstream->childs_num: " << downstream_tree.childs().size();
208188

209189
const auto& maybe_downstream_op = GetDownstreamFromCandidate(
210190
upstream_tree.GetRootPattern(), downstream_tree.FlattenReducePattern());
@@ -220,7 +200,7 @@ bool RelativeJudgePolicy::ReduceTreeGrownCanMerge(
220200
}
221201
const pir::Value& reduce_out_value =
222202
upstream_tree.GetRootPattern().GetReduceOp()->result(0);
223-
auto downstream_connect_op =
203+
auto downstream_connect_ops =
224204
FindUserOp(downstream_tree.ops(), reduce_out_value);
225205
pir::Operation* downstream_reduce_op =
226206
maybe_downstream_op.value().GetReduceOp();
@@ -229,8 +209,13 @@ bool RelativeJudgePolicy::ReduceTreeGrownCanMerge(
229209
SplitReduceDims(axes_info_.GetSignature(downstream_reduce_op),
230210
downstream_reduce_op);
231211

232-
const auto& upstream_output_dims = GetValueUsage(
233-
reduce_out_value, GetUsageIdx(reduce_out_value, downstream_connect_op));
212+
std::vector<DimUsage> upstream_output_dims;
213+
for (const auto& op : downstream_connect_ops) {
214+
auto dim_usages =
215+
GetValueUsage(reduce_out_value, GetUsageIdx(reduce_out_value, op));
216+
upstream_output_dims.insert(
217+
upstream_output_dims.end(), dim_usages.begin(), dim_usages.end());
218+
}
234219
const auto& [related, _UNUSED] =
235220
SplitFirstIfRelatedBySecond(downstream_reduce_dims, upstream_output_dims);
236221
auto res = (related.size() == 0);

paddle/cinn/operator_fusion/utils.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,10 +453,24 @@ static const size_t GetResultIdx(const pir::Value& v, pir::Operation* op) {
453453
"Can not find the value %s as result of op %s", v.impl(), op->name()));
454454
}
455455

456+
static std::vector<pir::Operation*> FindUserOp(
457+
const std::vector<pir::Operation*>& candidates, const pir::Value& value) {
458+
std::vector<pir::Operation*> results;
459+
for (auto consumer_it = value.use_begin(); consumer_it != value.use_end();
460+
++consumer_it) {
461+
pir::Operation* user_op = consumer_it.owner();
462+
auto iter = std::find(candidates.begin(), candidates.end(), user_op);
463+
if (iter != candidates.end()) {
464+
results.emplace_back(*iter);
465+
}
466+
}
467+
return results;
468+
}
469+
456470
static bool IsDirectUpstream(const pir::Operation* upstream,
457471
const pir::Operation* downstream) {
458-
for (const auto& value : downstream->results()) {
459-
for (const auto& operand : upstream->operands()) {
472+
for (const auto& value : upstream->results()) {
473+
for (const auto& operand : downstream->operands()) {
460474
if (value == operand.source()) {
461475
return true;
462476
}

0 commit comments

Comments
 (0)