Skip to content

Commit 2306ba6

Browse files
authored
[CINN] Eliminate loops for inplace operation (#70581)
1 parent 0c5a5a2 commit 2306ba6

File tree

2 files changed

+3
-23
lines changed

2 files changed

+3
-23
lines changed

paddle/cinn/optim/eliminate_invariant_loop.cc

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -58,26 +58,6 @@ bool HasVarInIndicesOrValue(const ir::Expr& block, const ir::Var& var) {
5858
return var_use.size() > 0;
5959
}
6060

61-
// Check whether the block has an inplace update (e.g. a[i] = a[i] + b[i])
62-
// by comparing between the block's read_buffers and write_buffers.
63-
bool HasInplaceUpdate(const ir::Expr& block) {
64-
auto* schedule_block = block.As<ir::ScheduleBlockRealize>()
65-
->schedule_block.As<ir::ScheduleBlock>();
66-
std::set<std::string> read_buffer_names;
67-
for (auto& buffer_range : schedule_block->read_buffers) {
68-
read_buffer_names.insert(
69-
buffer_range.As<ir::_BufferRange_>()->buffer.as_buffer()->name);
70-
}
71-
for (auto& buffer_range : schedule_block->write_buffers) {
72-
auto& write_buffer_name =
73-
buffer_range.As<ir::_BufferRange_>()->buffer.as_buffer()->name;
74-
if (read_buffer_names.count(write_buffer_name) > 0) {
75-
return true;
76-
}
77-
}
78-
return false;
79-
}
80-
8161
// Check whether the block is writing to a buffer whose scope is smaller than
8262
// the For node's scope.
8363
bool HasWriteToSmallerScope(const ir::Expr& block, const ir::For* for_node) {
@@ -112,7 +92,7 @@ struct InvariantLoopEliminator : public ir::IRMutator<> {
11292
ir::Var loop_var = node->loop_var;
11393
for (auto& block : child_blocks) {
11494
if (HasVarInIndicesOrValue(block, loop_var)) return;
115-
if (HasInplaceUpdate(block)) return;
95+
if (ir::analyzer::IsReductionSBlock(block)) return;
11696
if (node->is_binded()) {
11797
if (HasWriteToSmallerScope(block, node)) return;
11898
if (!ir::analyzer::GetConsumerSBlocks(block, *root_).empty()) return;

paddle/cinn/optim/eliminate_invariant_loop.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ namespace optim {
2626
* (1) The loop variable is not used in any load/store indices or computation
2727
* within child schedule blocks. This ensures that the loop writes the same
2828
* value to the same index in each iteration.
29-
* (2) It doesn't contain any inplace operations, e.g. a[0] = a[0] + b[0]. In
30-
* the presence of inplace operations, the loop count matters.
29+
* (2) It is not a Reduce (e.g. a[0] = a[0] + b[k]), because for a Reduce, even
30+
* though its indices don't change in each iteration, its value changes.
3131
*
3232
* We can eliminate a bound loop if it also satisfies rule (3) and (4):
3333
* (3) It doesn't write to the local buffer (for thread-bound loop) or shared

0 commit comments

Comments
 (0)