1313// limitations under the License.
1414
1515#include " paddle/cinn/ir/group_schedule/tactic/compute_at_reduction_tactic.h"
16+ #include " paddle/cinn/ir/ir_analyzer/data_dependency_graph.h"
1617#include " paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
1718#include " paddle/cinn/ir/utils/ir_compare.h"
19+ #include " paddle/cinn/ir/utils/stmt_converter.h"
1820#include " paddle/cinn/optim/replace_var_with_expr.h"
1921
2022namespace cinn {
@@ -104,7 +106,6 @@ class ComputeAtReductionTactic final : public ScheduleTactic {
104106 // A copy of the IRSchedule and ScheduleBlockGraph, with all loop vars
105107 // unifiedly rewritten to the form `$<loop_index>` (e.g. $0, $1).
106108 std::unique_ptr<ir::IRSchedule> sch_;
107- std::unique_ptr<ir::ScheduleBlockGraph> graph_;
108109
109110 // Cache of the results of GetUnifiedControlFlow, GetLoopVariantLoads,
110111 // GetSerialLoopExtent and IsReductionSBlock because these functions are too
@@ -241,7 +242,6 @@ void ComputeAtReductionTactic::Init(ScheduleContext* context,
241242 serial_loop_extent_.clear ();
242243
243244 sch_ = std::make_unique<ir::IRSchedule>(*sch);
244- graph_ = std::make_unique<ir::ScheduleBlockGraph>(*sch_);
245245
246246 for (auto & block : sch_->GetAllBlocks ()) {
247247 // Replace loop_vars to the unified form `$<loop_index>`
@@ -373,10 +373,8 @@ std::vector<std::string>
373373ComputeAtReductionTactic::GetDependencyHarzardFreeBlocks (
374374 ir::IRSchedule* sch, const std::string& block_id) {
375375 std::vector<std::string> results;
376- std::vector<ir::Expr> blocks = sch->GetAllBlocks ();
377- auto * graph_node = graph_->RetrieveNode (block_id);
378- std::unordered_set<std::string> upstreams = graph_node->UpstreamNodes ();
379- std::unordered_set<std::string> downstreams = graph_node->DownstreamNodes ();
376+ std::vector<stmt::StmtRef> stmts = sch->GetAllSchedules ();
377+ analyzer::DataDependencyGraph dep_graph (stmts);
380378
381379 // Find the position of the current block in the graph, then search upwards
382380 // and downwards until a denepency harzard is met.
@@ -397,25 +395,25 @@ ComputeAtReductionTactic::GetDependencyHarzardFreeBlocks(
397395 // C has denepency harzard with B because it directly depends on B. C also has
398396 // dependency harzard with A, because if we move C to the position of A, we
399397 // will violate the dependency of B->C. C is only harzard-free with D and E.
400- auto this_it =
401- std::find_if (blocks .begin (), blocks .end (), [&](const ir::Expr& block ) {
402- return analyzer::GetBlockName (block ) == block_id;
398+ auto this_it = std::find_if (
399+ stmts .begin (), stmts .end (), [&](const ir::stmt::StmtRef& stmt ) {
400+ return stmt. as <stmt::Schedule>()-> name ( ) == block_id;
403401 });
404402
405403 // Search upwards
406404 auto this_it_rev = std::make_reverse_iterator (this_it);
407- for (auto it = this_it_rev; it != blocks .rend (); ++it) {
408- std::string other_id = analyzer::GetBlockName (*it);
405+ for (auto it = this_it_rev; it != stmts .rend (); ++it) {
406+ std::string other_id = (*it). as <stmt::Schedule>()-> name ( );
409407 // As a special case, we can ignore the `reduce_init` in front of Reduce.
410408 if (IsReduceInitTensorName (other_id)) continue ;
411- if (upstreams. count (other_id) > 0 ) break ;
409+ if (dep_graph. HasDependency (*it, *this_it) == analyzer::DepKind::DEP ) break ;
412410 results.push_back (other_id);
413411 }
414412
415413 // Search downwards
416- for (auto it = this_it + 1 ; it != blocks .end (); ++it) {
417- std::string other_id = analyzer::GetBlockName (*it);
418- if (downstreams. count (other_id) > 0 ) break ;
414+ for (auto it = this_it + 1 ; it != stmts .end (); ++it) {
415+ std::string other_id = (*it). as <stmt::Schedule>()-> name ( );
416+ if (dep_graph. HasDependency (*this_it, *it) == analyzer::DepKind::DEP ) break ;
419417 results.push_back (other_id);
420418 }
421419
0 commit comments