Skip to content

Commit fba2e98

Browse files
authored
[CINN][Backend Pass Update No.12] Update transform_gpu_forloop pass (#70883)
* Update transform_gpu_forloop pass
* Update op_lowering_impl.cc
* Update CudaSyncThreadsDropIfThenElse pass
* Disable EliminateCommonGlobalMemoryRead pass
1 parent ab796ea commit fba2e98

File tree

6 files changed

+572
-258
lines changed

6 files changed

+572
-258
lines changed

paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "paddle/cinn/optim/eliminate_common_global_memory_read.h"
4141
#include "paddle/cinn/optim/schedule_block_dce.h"
4242
#include "paddle/cinn/optim/transform_gpu_forloop.h"
43+
#include "paddle/cinn/pass/pass_manager.h"
4344
#include "paddle/common/ddim.h"
4445
#include "paddle/common/enforce.h"
4546
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
@@ -393,12 +394,26 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
393394
[&](common::NVGPUArch) {
394395
#ifdef CINN_WITH_CUDA
395396
// optim::EliminateCommonGlobalMemoryRead(&(func_body));
396-
optim::OptimizeExprGPU(&(func_body));
397+
ir::stmt::BlockRef func_body_block =
398+
ir::ConvertExprBlockToStmtBlock(func_body);
399+
LOG(INFO) << "Before OptimizeExprGPU in op_lowering_impl: \n"
400+
<< func_body_block;
401+
optim::OptimizeExprGPU(func_body_block);
402+
LOG(INFO) << "After OptimizeExprGPU in op_lowering_impl: \n"
403+
<< func_body_block;
404+
func_body = ir::ConvertStmtBlockToExprBlock(func_body_block);
397405
#endif
398406
},
399407
[&](std::variant<common::HygonDCUArchHIP, common::HygonDCUArchSYCL>) {
400408
// optim::EliminateCommonGlobalMemoryRead(&(func_body));
401-
optim::OptimizeExprGPU(&(func_body));
409+
ir::stmt::BlockRef func_body_block =
410+
ir::ConvertExprBlockToStmtBlock(func_body);
411+
LOG(INFO) << "Before OptimizeExprGPU in op_lowering_impl: \n"
412+
<< func_body_block;
413+
optim::OptimizeExprGPU(func_body_block);
414+
LOG(INFO) << "After OptimizeExprGPU in op_lowering_impl: \n"
415+
<< func_body_block;
416+
func_body = ir::ConvertStmtBlockToExprBlock(func_body_block);
402417
});
403418
}
404419

paddle/cinn/optim/optimize.cc

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,17 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
8585
#ifdef CINN_WITH_CUDA
8686
ir::SetCudaAxisInfo(copied);
8787
if (remove_gpu_for_loops) {
88-
RemoveGpuForLoops(copied);
88+
LOG(INFO) << "Before removing GPU for loops:\n" << copied;
89+
FuncPassManager func_pass_manager;
90+
func_pass_manager.AddPass(CreateRemoveGpuForLoopsPass());
91+
func_pass_manager.Run(copied);
92+
LOG(INFO) << "After removing GPU for loops:\n" << copied;
8993
}
90-
CudaSyncThreadsDropIfThenElse(copied);
94+
VLOG(10) << "Before Optimize CudaSyncThreadsDropIfThenElse:" << copied;
95+
BlockPassManager blk_pass_manager;
96+
blk_pass_manager.AddPass(CreateCudaSyncThreadsDropIfThenElsePass());
97+
blk_pass_manager.Run(copied->body_block);
98+
VLOG(10) << "After Optimize CudaSyncThreadsDropIfThenElse:" << copied;
9199
FuncPassManager func_pass_manager;
92100
VLOG(10) << "Before Optimize TransBufferWithDynamicShape:" << copied;
93101
func_pass_manager.AddPass(CreateTransBufferWithDynamicShapePass());
@@ -99,10 +107,17 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
99107
#ifdef CINN_WITH_HIP
100108
ir::SetCudaAxisInfo(copied);
101109
if (remove_gpu_for_loops) {
102-
RemoveGpuForLoops(copied);
110+
LOG(INFO) << "Before removing GPU for loops:\n" << copied;
111+
FuncPassManager func_pass_manager;
112+
func_pass_manager.AddPass(CreateRemoveGpuForLoopsPass());
113+
func_pass_manager.Run(copied);
114+
LOG(INFO) << "After removing GPU for loops:\n" << copied;
103115
}
104-
CudaSyncThreadsDropIfThenElse(copied);
105-
// CudaTransBufferWithDynamicShape(&copied);
116+
VLOG(10) << "Before Optimize CudaSyncThreadsDropIfThenElse:" << copied;
117+
BlockPassManager blk_pass_manager;
118+
blk_pass_manager.AddPass(CreateCudaSyncThreadsDropIfThenElsePass());
119+
blk_pass_manager.Run(copied->body_block);
120+
VLOG(10) << "After Optimize CudaSyncThreadsDropIfThenElse:" << copied;
106121
#endif
107122
},
108123
[&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED },

paddle/cinn/optim/replace_var_with_expr.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,13 @@ struct ReplaceVarWithExprMutator : public ir::IRMutator<>,
118118
ir::IRMutator<>::Visit(&var->upper_bound, &var->upper_bound);
119119
}
120120
}
121+
122+
std::vector<Expr> iter_values = stmt->iter_values();
123+
for (ir::Expr& iter_value : iter_values) {
124+
ir::IRMutator<>::Visit(&iter_value, &iter_value);
125+
}
126+
stmt->set_iter_values(iter_values);
127+
121128
std::vector<Expr> new_read_buffers = stmt->read_buffers();
122129
for (Expr& read_buffer : new_read_buffers) {
123130
ir::IRMutator<>::Visit(&read_buffer, &read_buffer);

0 commit comments

Comments
 (0)