Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
// limitations under the License.

#include "paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.h"
#include "paddle/cinn/ir/ir_analyzer/data_dependency_graph.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/utils/ir_compare.h"
#include "paddle/cinn/ir/utils/stmt_converter.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"

namespace cinn {
Expand Down Expand Up @@ -104,7 +106,6 @@ class ComputeAtReductionTactic final : public ScheduleTactic {
// A copy of the IRSchedule and ScheduleBlockGraph, with all loop vars
// unifiedly rewritten to the form `$<loop_index>` (e.g. $0, $1).
std::unique_ptr<ir::IRSchedule> sch_;
std::unique_ptr<ir::ScheduleBlockGraph> graph_;

// Cache of the results of GetUnifiedControlFlow, GetLoopVariantLoads,
// GetSerialLoopExtent and IsReductionSBlock because these functions are too
Expand Down Expand Up @@ -241,7 +242,6 @@ void ComputeAtReductionTactic::Init(ScheduleContext* context,
serial_loop_extent_.clear();

sch_ = std::make_unique<ir::IRSchedule>(*sch);
graph_ = std::make_unique<ir::ScheduleBlockGraph>(*sch_);

for (auto& block : sch_->GetAllBlocks()) {
// Replace loop_vars to the unified form `$<loop_index>`
Expand Down Expand Up @@ -373,10 +373,8 @@ std::vector<std::string>
ComputeAtReductionTactic::GetDependencyHarzardFreeBlocks(
ir::IRSchedule* sch, const std::string& block_id) {
std::vector<std::string> results;
std::vector<ir::Expr> blocks = sch->GetAllBlocks();
auto* graph_node = graph_->RetrieveNode(block_id);
std::unordered_set<std::string> upstreams = graph_node->UpstreamNodes();
std::unordered_set<std::string> downstreams = graph_node->DownstreamNodes();
std::vector<stmt::StmtRef> stmts = sch->GetAllSchedules();
analyzer::DataDependencyGraph dep_graph(stmts);

// Find the position of the current block in the graph, then search upwards
// and downwards until a denepency harzard is met.
Expand All @@ -397,25 +395,25 @@ ComputeAtReductionTactic::GetDependencyHarzardFreeBlocks(
// C has denepency harzard with B because it directly depends on B. C also has
// dependency harzard with A, because if we move C to the position of A, we
// will violate the dependency of B->C. C is only harzard-free with D and E.
auto this_it =
std::find_if(blocks.begin(), blocks.end(), [&](const ir::Expr& block) {
return analyzer::GetBlockName(block) == block_id;
auto this_it = std::find_if(
stmts.begin(), stmts.end(), [&](const ir::stmt::StmtRef& stmt) {
return stmt.as<stmt::Schedule>()->name() == block_id;
});

// Search upwards
auto this_it_rev = std::make_reverse_iterator(this_it);
for (auto it = this_it_rev; it != blocks.rend(); ++it) {
std::string other_id = analyzer::GetBlockName(*it);
for (auto it = this_it_rev; it != stmts.rend(); ++it) {
std::string other_id = (*it).as<stmt::Schedule>()->name();
// As a special case, we can ignore the `reduce_init` in front of Reduce.
if (IsReduceInitTensorName(other_id)) continue;
if (upstreams.count(other_id) > 0) break;
if (dep_graph.HasDependency(*it, *this_it) == analyzer::DepKind::DEP) break;
results.push_back(other_id);
}

// Search downwards
for (auto it = this_it + 1; it != blocks.end(); ++it) {
std::string other_id = analyzer::GetBlockName(*it);
if (downstreams.count(other_id) > 0) break;
for (auto it = this_it + 1; it != stmts.end(); ++it) {
std::string other_id = (*it).as<stmt::Schedule>()->name();
if (dep_graph.HasDependency(*this_it, *it) == analyzer::DepKind::DEP) break;
results.push_back(other_id);
}

Expand Down
22 changes: 22 additions & 0 deletions paddle/cinn/ir/ir_analyzer/ir_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/stmt_converter.h"
#include "paddle/cinn/utils/random_engine.h"
#include "paddle/common/enforce.h"
#include "paddle/fluid/platform/enforce.h"
Expand Down Expand Up @@ -110,6 +111,27 @@ std::vector<Expr> GetAllBlocks(const std::vector<Expr>& exprs) {
return result;
}

std::vector<stmt::StmtRef> GetAllSchedules(const std::vector<Expr>& exprs) {
std::vector<stmt::StmtRef> result;
for (auto& it_expr : exprs) {
stmt::BlockRef stmt_block;
if (it_expr.As<ir::Block>()) {
stmt_block = ConvertExprBlockToStmtBlock(it_expr);
} else {
stmt_block = ConvertExprBlockToStmtBlock(
ir::Block::Make(std::vector<ir::Expr>{it_expr}));
}
FindSchedulesVisitor visitor;
auto find_blocks = visitor(stmt_block);
result.insert(result.end(), find_blocks.begin(), find_blocks.end());
}
PADDLE_ENFORCE_EQ(
result.empty(),
false,
::common::errors::InvalidArgument("Didn't find schedules in expr."));
return result;
}

std::vector<Expr> GetChildBlocks(const Expr& expr) {
if (!expr.As<ir::ScheduleBlockRealize>()) {
PADDLE_ENFORCE_NOT_NULL(expr.As<ir::For>(),
Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/ir/ir_analyzer/ir_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ std::vector<Expr> GetLoops(const std::vector<Expr>& exprs, const Expr& block);

std::vector<Expr> GetAllBlocks(const std::vector<Expr>& exprs);

std::vector<stmt::StmtRef> GetAllSchedules(const std::vector<Expr>& exprs);

std::vector<Expr> GetChildBlocks(const Expr& expr);

Expr GetBlock(const std::vector<Expr>& exprs, const std::string& block_name);
Expand Down
9 changes: 9 additions & 0 deletions paddle/cinn/ir/schedule/impl/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,15 @@ std::vector<Expr> DyScheduleImpl::GetAllBlocks() const {
CINN_IR_SCHEDULE_END(this->err_msg_level_);
}

std::vector<stmt::StmtRef> DyScheduleImpl::GetAllSchedules() const {
CINN_IR_SCHEDULE_BEGIN();
std::string primitive = "GetAllSchedules";
std::ostringstream os;
auto exprs = module_expr_.GetExprs();
return analyzer::GetAllSchedules(exprs);
CINN_IR_SCHEDULE_END(this->err_msg_level_);
}

std::vector<Expr> DyScheduleImpl::GetChildBlocks(const Expr& expr) const {
CINN_IR_SCHEDULE_BEGIN();
std::string primitive = "GetChildBlocks";
Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/ir/schedule/impl/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class DyScheduleImpl : public ScheduleBase {
std::vector<Expr> GetLoops(const Expr& block) const;
std::vector<Expr> GetLoops(const std::string& block_name) const;
std::vector<Expr> GetAllBlocks() const;
std::vector<stmt::StmtRef> GetAllSchedules() const;
std::vector<Expr> GetChildBlocks(const Expr& expr) const;
Expr GetBlock(const std::string& block_name) const;
std::vector<Expr> Split(const Expr& loop, const std::vector<int>& factors);
Expand Down
6 changes: 6 additions & 0 deletions paddle/cinn/ir/schedule/ir_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,12 @@ std::vector<Expr> IRSchedule::GetAllBlocks() const {
return results;
}

std::vector<stmt::StmtRef> IRSchedule::GetAllSchedules() const {
auto results = impl_->GetAllSchedules();
trace_.Append(ScheduleDesc::Step("GetAllSchedules", {}, {}, {}, results));
return results;
}

std::vector<Expr> IRSchedule::GetChildBlocks(const Expr& expr) const {
auto results = impl_->GetChildBlocks(expr);
trace_.Append(ScheduleDesc::Step(
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/ir/schedule/ir_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ class IRSchedule {
//! Get all blocks stored in this ModuleExpr.
std::vector<Expr> GetAllBlocks() const;

//! Get all schedules stored in this ModuleExpr.
std::vector<stmt::StmtRef> GetAllSchedules() const;

//! Get a block with the specific name.
Expr GetBlock(const std::string& block_name) const;

Expand Down
42 changes: 42 additions & 0 deletions paddle/cinn/ir/schedule/ir_schedule_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,48 @@ struct FindBlocksVisitor {
std::vector<Expr> result{};
};

struct FindSchedulesVisitor {
explicit FindSchedulesVisitor(const std::string& schedule_name = "")
: schedule_name_(schedule_name) {}

std::vector<stmt::StmtRef> operator()(const stmt::BlockRef& block) {
VisitBlock(block);
return result;
}

private:
void VisitStmt(const stmt::StmtRef& stmt) {
if (!stmt.defined()) return;
if (!schedule_name_.empty() && !result.empty()) return;
if (stmt.isa<stmt::For>()) {
VisitBlock(stmt.as<stmt::For>()->body());
} else if (stmt.isa<stmt::Schedule>()) {
if (stmt.as<stmt::Schedule>()->name().substr(0, 4) != "root") {
auto schedule_block = stmt.as<stmt::Schedule>();
if (schedule_name_.empty() ||
schedule_block->name() == schedule_name_) {
result.emplace_back(stmt);
}
} else {
VisitBlock(stmt.as<stmt::Schedule>()->body());
}
} else if (stmt.isa<stmt::IfThenElse>()) {
VisitBlock(stmt.as<stmt::IfThenElse>()->true_case());
if (stmt.as<stmt::IfThenElse>()->false_case().defined())
VisitBlock(stmt.as<stmt::IfThenElse>()->false_case());
}
}

void VisitBlock(const stmt::BlockRef& block) {
for (const stmt::StmtRef& inner_stmt : block->stmts()) {
VisitStmt(inner_stmt);
}
}

std::string schedule_name_;
std::vector<stmt::StmtRef> result{};
};

struct FindLoopsVisitor {
explicit FindLoopsVisitor(const Expr& block) : block_(block) {}

Expand Down
1 change: 1 addition & 0 deletions paddle/cinn/ir/schedule/schedule_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class ScheduleBase {
virtual std::vector<Expr> GetLoops(const Expr& block) const = 0;
virtual std::vector<Expr> GetLoops(const std::string& block_name) const = 0;
virtual std::vector<Expr> GetAllBlocks() const = 0;
virtual std::vector<stmt::StmtRef> GetAllSchedules() const = 0;
virtual std::vector<Expr> GetChildBlocks(const Expr& expr) const = 0;
virtual Expr GetBlock(const std::string& block_name) const = 0;

Expand Down
12 changes: 12 additions & 0 deletions paddle/cinn/ir/schedule/schedule_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/schedule_desc.pb.h"
#include "paddle/cinn/ir/stmt.h"
#include "paddle/cinn/utils/registry.h"
#include "paddle/cinn/utils/type_defs.h"

Expand All @@ -43,12 +44,23 @@ class ScheduleDesc {
absl::flat_hash_map<std::string, std::vector<Expr>> inputs;
utils::AttributeMap attrs;
std::vector<Expr> outputs;
std::vector<stmt::StmtRef> stmt_outputs;
Step() = default;
Step(std::string type_i,
absl::flat_hash_map<std::string, std::vector<Expr>> inputs_i,
utils::AttributeMap attrs_i,
std::vector<Expr> outputs_i)
: type(type_i), inputs(inputs_i), attrs(attrs_i), outputs(outputs_i) {}
Step(std::string type_i,
absl::flat_hash_map<std::string, std::vector<Expr>> inputs_i,
utils::AttributeMap attrs_i,
std::vector<Expr> outputs_i,
std::vector<stmt::StmtRef> stmt_outputs_i)
: type(type_i),
inputs(inputs_i),
attrs(attrs_i),
outputs(outputs_i),
stmt_outputs(stmt_outputs_i) {}
};

/**
Expand Down
Loading