Skip to content

Commit f162e99

Browse files
authored
[CINN] Update compute at dependency graph (#71111)
1 parent 3967f44 commit f162e99

File tree

10 files changed

+111
-15
lines changed

10 files changed

+111
-15
lines changed

paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.cc

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
// limitations under the License.
1414

1515
#include "paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.h"
16+
#include "paddle/cinn/ir/ir_analyzer/data_dependency_graph.h"
1617
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
1718
#include "paddle/cinn/ir/utils/ir_compare.h"
19+
#include "paddle/cinn/ir/utils/stmt_converter.h"
1820
#include "paddle/cinn/optim/replace_var_with_expr.h"
1921

2022
namespace cinn {
@@ -104,7 +106,6 @@ class ComputeAtReductionTactic final : public ScheduleTactic {
104106
// A copy of the IRSchedule and ScheduleBlockGraph, with all loop vars
105107
// unifiedly rewritten to the form `$<loop_index>` (e.g. $0, $1).
106108
std::unique_ptr<ir::IRSchedule> sch_;
107-
std::unique_ptr<ir::ScheduleBlockGraph> graph_;
108109

109110
// Cache of the results of GetUnifiedControlFlow, GetLoopVariantLoads,
110111
// GetSerialLoopExtent and IsReductionSBlock because these functions are too
@@ -241,7 +242,6 @@ void ComputeAtReductionTactic::Init(ScheduleContext* context,
241242
serial_loop_extent_.clear();
242243

243244
sch_ = std::make_unique<ir::IRSchedule>(*sch);
244-
graph_ = std::make_unique<ir::ScheduleBlockGraph>(*sch_);
245245

246246
for (auto& block : sch_->GetAllBlocks()) {
247247
// Replace loop_vars to the unified form `$<loop_index>`
@@ -373,10 +373,8 @@ std::vector<std::string>
373373
ComputeAtReductionTactic::GetDependencyHarzardFreeBlocks(
374374
ir::IRSchedule* sch, const std::string& block_id) {
375375
std::vector<std::string> results;
376-
std::vector<ir::Expr> blocks = sch->GetAllBlocks();
377-
auto* graph_node = graph_->RetrieveNode(block_id);
378-
std::unordered_set<std::string> upstreams = graph_node->UpstreamNodes();
379-
std::unordered_set<std::string> downstreams = graph_node->DownstreamNodes();
376+
std::vector<stmt::StmtRef> stmts = sch->GetAllSchedules();
377+
analyzer::DataDependencyGraph dep_graph(stmts);
380378

381379
// Find the position of the current block in the graph, then search upwards
382380
// and downwards until a dependency hazard is met.
@@ -397,25 +395,25 @@ ComputeAtReductionTactic::GetDependencyHarzardFreeBlocks(
397395
// C has a dependency hazard with B because it directly depends on B. C also has
398396
// a dependency hazard with A, because if we move C to the position of A, we
399397
// will violate the dependency of B->C. C is only hazard-free with D and E.
400-
auto this_it =
401-
std::find_if(blocks.begin(), blocks.end(), [&](const ir::Expr& block) {
402-
return analyzer::GetBlockName(block) == block_id;
398+
auto this_it = std::find_if(
399+
stmts.begin(), stmts.end(), [&](const ir::stmt::StmtRef& stmt) {
400+
return stmt.as<stmt::Schedule>()->name() == block_id;
403401
});
404402

405403
// Search upwards
406404
auto this_it_rev = std::make_reverse_iterator(this_it);
407-
for (auto it = this_it_rev; it != blocks.rend(); ++it) {
408-
std::string other_id = analyzer::GetBlockName(*it);
405+
for (auto it = this_it_rev; it != stmts.rend(); ++it) {
406+
std::string other_id = (*it).as<stmt::Schedule>()->name();
409407
// As a special case, we can ignore the `reduce_init` in front of Reduce.
410408
if (IsReduceInitTensorName(other_id)) continue;
411-
if (upstreams.count(other_id) > 0) break;
409+
if (dep_graph.HasDependency(*it, *this_it) == analyzer::DepKind::DEP) break;
412410
results.push_back(other_id);
413411
}
414412

415413
// Search downwards
416-
for (auto it = this_it + 1; it != blocks.end(); ++it) {
417-
std::string other_id = analyzer::GetBlockName(*it);
418-
if (downstreams.count(other_id) > 0) break;
414+
for (auto it = this_it + 1; it != stmts.end(); ++it) {
415+
std::string other_id = (*it).as<stmt::Schedule>()->name();
416+
if (dep_graph.HasDependency(*this_it, *it) == analyzer::DepKind::DEP) break;
419417
results.push_back(other_id);
420418
}
421419

paddle/cinn/ir/ir_analyzer/ir_analyzer.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "paddle/cinn/ir/schedule/schedule_desc.h"
3535
#include "paddle/cinn/ir/tensor.h"
3636
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
37+
#include "paddle/cinn/ir/utils/stmt_converter.h"
3738
#include "paddle/cinn/utils/random_engine.h"
3839
#include "paddle/common/enforce.h"
3940
#include "paddle/fluid/platform/enforce.h"
@@ -110,6 +111,27 @@ std::vector<Expr> GetAllBlocks(const std::vector<Expr>& exprs) {
110111
return result;
111112
}
112113

114+
std::vector<stmt::StmtRef> GetAllSchedules(const std::vector<Expr>& exprs) {
115+
std::vector<stmt::StmtRef> result;
116+
for (auto& it_expr : exprs) {
117+
stmt::BlockRef stmt_block;
118+
if (it_expr.As<ir::Block>()) {
119+
stmt_block = ConvertExprBlockToStmtBlock(it_expr);
120+
} else {
121+
stmt_block = ConvertExprBlockToStmtBlock(
122+
ir::Block::Make(std::vector<ir::Expr>{it_expr}));
123+
}
124+
FindSchedulesVisitor visitor;
125+
auto find_blocks = visitor(stmt_block);
126+
result.insert(result.end(), find_blocks.begin(), find_blocks.end());
127+
}
128+
PADDLE_ENFORCE_EQ(
129+
result.empty(),
130+
false,
131+
::common::errors::InvalidArgument("Didn't find schedules in expr."));
132+
return result;
133+
}
134+
113135
std::vector<Expr> GetChildBlocks(const Expr& expr) {
114136
if (!expr.As<ir::ScheduleBlockRealize>()) {
115137
PADDLE_ENFORCE_NOT_NULL(expr.As<ir::For>(),

paddle/cinn/ir/ir_analyzer/ir_analyzer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ std::vector<Expr> GetLoops(const std::vector<Expr>& exprs, const Expr& block);
3535

3636
std::vector<Expr> GetAllBlocks(const std::vector<Expr>& exprs);
3737

38+
std::vector<stmt::StmtRef> GetAllSchedules(const std::vector<Expr>& exprs);
39+
3840
std::vector<Expr> GetChildBlocks(const Expr& expr);
3941

4042
Expr GetBlock(const std::vector<Expr>& exprs, const std::string& block_name);

paddle/cinn/ir/schedule/impl/base.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,15 @@ std::vector<Expr> DyScheduleImpl::GetAllBlocks() const {
167167
CINN_IR_SCHEDULE_END(this->err_msg_level_);
168168
}
169169

170+
std::vector<stmt::StmtRef> DyScheduleImpl::GetAllSchedules() const {
171+
CINN_IR_SCHEDULE_BEGIN();
172+
std::string primitive = "GetAllSchedules";
173+
std::ostringstream os;
174+
auto exprs = module_expr_.GetExprs();
175+
return analyzer::GetAllSchedules(exprs);
176+
CINN_IR_SCHEDULE_END(this->err_msg_level_);
177+
}
178+
170179
std::vector<Expr> DyScheduleImpl::GetChildBlocks(const Expr& expr) const {
171180
CINN_IR_SCHEDULE_BEGIN();
172181
std::string primitive = "GetChildBlocks";

paddle/cinn/ir/schedule/impl/ir_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class DyScheduleImpl : public ScheduleBase {
4646
std::vector<Expr> GetLoops(const Expr& block) const;
4747
std::vector<Expr> GetLoops(const std::string& block_name) const;
4848
std::vector<Expr> GetAllBlocks() const;
49+
std::vector<stmt::StmtRef> GetAllSchedules() const;
4950
std::vector<Expr> GetChildBlocks(const Expr& expr) const;
5051
Expr GetBlock(const std::string& block_name) const;
5152
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);

paddle/cinn/ir/schedule/ir_schedule.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,12 @@ std::vector<Expr> IRSchedule::GetAllBlocks() const {
354354
return results;
355355
}
356356

357+
std::vector<stmt::StmtRef> IRSchedule::GetAllSchedules() const {
358+
auto results = impl_->GetAllSchedules();
359+
trace_.Append(ScheduleDesc::Step("GetAllSchedules", {}, {}, {}, results));
360+
return results;
361+
}
362+
357363
std::vector<Expr> IRSchedule::GetChildBlocks(const Expr& expr) const {
358364
auto results = impl_->GetChildBlocks(expr);
359365
trace_.Append(ScheduleDesc::Step(

paddle/cinn/ir/schedule/ir_schedule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ class IRSchedule {
9090
//! Get all blocks stored in this ModuleExpr.
9191
std::vector<Expr> GetAllBlocks() const;
9292

93+
//! Get all schedules stored in this ModuleExpr.
94+
std::vector<stmt::StmtRef> GetAllSchedules() const;
95+
9396
//! Get a block with the specific name.
9497
Expr GetBlock(const std::string& block_name) const;
9598

paddle/cinn/ir/schedule/ir_schedule_util.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,48 @@ struct FindBlocksVisitor {
11801180
std::vector<Expr> result{};
11811181
};
11821182

1183+
struct FindSchedulesVisitor {
1184+
explicit FindSchedulesVisitor(const std::string& schedule_name = "")
1185+
: schedule_name_(schedule_name) {}
1186+
1187+
std::vector<stmt::StmtRef> operator()(const stmt::BlockRef& block) {
1188+
VisitBlock(block);
1189+
return result;
1190+
}
1191+
1192+
private:
1193+
void VisitStmt(const stmt::StmtRef& stmt) {
1194+
if (!stmt.defined()) return;
1195+
if (!schedule_name_.empty() && !result.empty()) return;
1196+
if (stmt.isa<stmt::For>()) {
1197+
VisitBlock(stmt.as<stmt::For>()->body());
1198+
} else if (stmt.isa<stmt::Schedule>()) {
1199+
if (stmt.as<stmt::Schedule>()->name().substr(0, 4) != "root") {
1200+
auto schedule_block = stmt.as<stmt::Schedule>();
1201+
if (schedule_name_.empty() ||
1202+
schedule_block->name() == schedule_name_) {
1203+
result.emplace_back(stmt);
1204+
}
1205+
} else {
1206+
VisitBlock(stmt.as<stmt::Schedule>()->body());
1207+
}
1208+
} else if (stmt.isa<stmt::IfThenElse>()) {
1209+
VisitBlock(stmt.as<stmt::IfThenElse>()->true_case());
1210+
if (stmt.as<stmt::IfThenElse>()->false_case().defined())
1211+
VisitBlock(stmt.as<stmt::IfThenElse>()->false_case());
1212+
}
1213+
}
1214+
1215+
void VisitBlock(const stmt::BlockRef& block) {
1216+
for (const stmt::StmtRef& inner_stmt : block->stmts()) {
1217+
VisitStmt(inner_stmt);
1218+
}
1219+
}
1220+
1221+
std::string schedule_name_;
1222+
std::vector<stmt::StmtRef> result{};
1223+
};
1224+
11831225
struct FindLoopsVisitor {
11841226
explicit FindLoopsVisitor(const Expr& block) : block_(block) {}
11851227

paddle/cinn/ir/schedule/schedule_base.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class ScheduleBase {
108108
virtual std::vector<Expr> GetLoops(const Expr& block) const = 0;
109109
virtual std::vector<Expr> GetLoops(const std::string& block_name) const = 0;
110110
virtual std::vector<Expr> GetAllBlocks() const = 0;
111+
virtual std::vector<stmt::StmtRef> GetAllSchedules() const = 0;
111112
virtual std::vector<Expr> GetChildBlocks(const Expr& expr) const = 0;
112113
virtual Expr GetBlock(const std::string& block_name) const = 0;
113114

paddle/cinn/ir/schedule/schedule_desc.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include "paddle/cinn/ir/ir.h"
2323
#include "paddle/cinn/ir/schedule/schedule_desc.pb.h"
24+
#include "paddle/cinn/ir/stmt.h"
2425
#include "paddle/cinn/utils/registry.h"
2526
#include "paddle/cinn/utils/type_defs.h"
2627

@@ -43,12 +44,23 @@ class ScheduleDesc {
4344
absl::flat_hash_map<std::string, std::vector<Expr>> inputs;
4445
utils::AttributeMap attrs;
4546
std::vector<Expr> outputs;
47+
std::vector<stmt::StmtRef> stmt_outputs;
4648
Step() = default;
4749
Step(std::string type_i,
4850
absl::flat_hash_map<std::string, std::vector<Expr>> inputs_i,
4951
utils::AttributeMap attrs_i,
5052
std::vector<Expr> outputs_i)
5153
: type(type_i), inputs(inputs_i), attrs(attrs_i), outputs(outputs_i) {}
54+
// Constructs a step that also carries statement-form outputs.
// Takes the same four arguments as the other non-default constructor, plus
// `stmt_outputs_i`, a list of stmt::StmtRef stored in `stmt_outputs`.
// Every argument is taken by value and copied into the member of the
// matching name.
Step(std::string type_i,
55+
absl::flat_hash_map<std::string, std::vector<Expr>> inputs_i,
56+
utils::AttributeMap attrs_i,
57+
std::vector<Expr> outputs_i,
58+
std::vector<stmt::StmtRef> stmt_outputs_i)
59+
: type(type_i),
60+
inputs(inputs_i),
61+
attrs(attrs_i),
62+
outputs(outputs_i),
63+
stmt_outputs(stmt_outputs_i) {}
5264
};
5365

5466
/**

0 commit comments

Comments
 (0)