Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 1 addition & 21 deletions paddle/cinn/optim/eliminate_invariant_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,26 +58,6 @@ bool HasVarInIndicesOrValue(const ir::Expr& block, const ir::Var& var) {
return var_use.size() > 0;
}

// Detect an inplace update (e.g. a[i] = a[i] + b[i]) inside a schedule block.
// Returns true when at least one buffer name occurs in both the block's
// read_buffers and write_buffers; returns false otherwise.
// NOTE: `block` is cast to ir::ScheduleBlockRealize and its schedule_block to
// ir::ScheduleBlock, and both results are dereferenced without a null check,
// so callers are expected to pass a ScheduleBlockRealize expression.
bool HasInplaceUpdate(const ir::Expr& block) {
  const auto* sblock = block.As<ir::ScheduleBlockRealize>()
                           ->schedule_block.As<ir::ScheduleBlock>();

  // First pass: remember the name of every buffer the block reads from.
  std::set<std::string> names_read;
  for (auto& range : sblock->read_buffers) {
    auto& read_name = range.As<ir::_BufferRange_>()->buffer.as_buffer()->name;
    names_read.insert(read_name);
  }

  // Second pass: a written buffer that was also read means an inplace update.
  for (auto& range : sblock->write_buffers) {
    auto& written_name =
        range.As<ir::_BufferRange_>()->buffer.as_buffer()->name;
    if (names_read.find(written_name) != names_read.end()) {
      return true;
    }
  }
  return false;
}

// Check whether the block is writing to a buffer whose scope is smaller than
// the For node's scope.
bool HasWriteToSmallerScope(const ir::Expr& block, const ir::For* for_node) {
Expand Down Expand Up @@ -112,7 +92,7 @@ struct InvariantLoopEliminator : public ir::IRMutator<> {
ir::Var loop_var = node->loop_var;
for (auto& block : child_blocks) {
if (HasVarInIndicesOrValue(block, loop_var)) return;
if (HasInplaceUpdate(block)) return;
if (ir::analyzer::IsReductionSBlock(block)) return;
if (node->is_binded()) {
if (HasWriteToSmallerScope(block, node)) return;
if (!ir::analyzer::GetConsumerSBlocks(block, *root_).empty()) return;
Expand Down
4 changes: 2 additions & 2 deletions paddle/cinn/optim/eliminate_invariant_loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ namespace optim {
* (1) The loop variable is not used in any load/store indices or computation
* within child schedule blocks. This ensures that the loop writes the same
* value to the same index in each iteration.
* (2) It doesn't contain any inplace operations, e.g. a[0] = a[0] + b[0]. In
* the presence of inplace operations, the loop count matters.
* (2) It is not a Reduce (e.g. a[0] = a[0] + b[k]). For a Reduce, the store
* indices stay the same across iterations, but the stored value changes.
*
* We can eliminate a bound loop if it also satisfies rule (3) and (4):
* (3) It doesn't write to the local buffer (for thread-bound loop) or shared
Expand Down
Loading