Skip to content

Commit 7779094

Browse files
committed
Fix
2 parents 6542244 + c2d2b66 commit 7779094

File tree

349 files changed

+8226
-7198
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

349 files changed

+8226
-7198
lines changed

CMakeLists.txt

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -593,13 +593,6 @@ if(WITH_RPC)
593593
OFF
594594
CACHE BOOL "Disable WITH_RPC when compiling with XPU" FORCE)
595595
endif()
596-
if(WITH_CINN AND WITH_RPC)
597-
message(
598-
WARNING "Disable WITH_RPC when compiling with CINN. Force WITH_RPC=OFF.")
599-
set(WITH_RPC
600-
OFF
601-
CACHE BOOL "Disable WITH_RPC when compiling with CINN" FORCE)
602-
endif()
603596
endif()
604597

605598
if(WITH_MPI)

paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414

1515
#include "paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.h"
1616

17+
#include <chrono>
1718
#include "paddle/common/errors.h"
1819
#include "paddle/common/flags.h"
1920
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
@@ -243,6 +244,7 @@ int64_t GetOpCount(const ::pir::Operation* op) {
243244
void ApplyCinnPass(::pir::Program* program,
244245
const std::function<std::shared_ptr<pir::PassManager>()>&
245246
CreatePassManager) {
247+
const uint32_t origin_num_ops = program->num_ops();
246248
PirToPyCodeConverter(program)
247249
.file_name("original_programs.py")
248250
.dump_symbolic_shape(FLAGS_logging_pir_py_code_dump_symbolic_dims)
@@ -268,7 +270,19 @@ void ApplyCinnPass(::pir::Program* program,
268270
<< pir::CustomPrintHelper(*program, shape_analysis.PrintHook())
269271
<< std::endl;
270272
}
273+
274+
auto start = std::chrono::high_resolution_clock::now();
271275
ApplyCinnLowerPass(program, CreatePassManager);
276+
auto end = std::chrono::high_resolution_clock::now();
277+
auto duration = std::chrono::duration_cast<std::chrono::seconds>(end - start);
278+
LOG(INFO) << "Time of lowering and compiling program: ***** [ "
279+
<< duration.count() << " ] ***** seconds.";
280+
281+
const uint32_t new_num_ops = program->num_ops();
282+
LOG(INFO) << "Number of ops in the original program is: " << origin_num_ops
283+
<< ", after lowering it becomes: " << new_num_ops
284+
<< ". (compression ratio: " << new_num_ops << "/" << origin_num_ops
285+
<< " = " << static_cast<float>(new_num_ops) / origin_num_ops << ")";
272286
}
273287

274288
} // namespace cinn::dialect::ir

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -915,8 +915,9 @@ class SqueezeOpPattern
915915
in_shape[i]));
916916
}
917917
}
918-
919-
ReplaceWithCinnReshapeOp(op, rewriter, output_shape);
918+
auto cinn_reshape = rewriter.Build<cinn::dialect::ReshapeOp>(
919+
op->operand_source(0), output_shape);
920+
rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0));
920921
rewriter.EraseOp(op);
921922

922923
return true;
@@ -956,7 +957,6 @@ class UnsqueezeOpPattern
956957
output_shape.push_back(1);
957958
}
958959
}
959-
960960
ReplaceWithCinnReshapeOp(op, rewriter, output_shape);
961961
rewriter.EraseOp(op);
962962

paddle/cinn/hlir/pe/ir_schedule_pe.cc

Lines changed: 84 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,10 @@ std::vector<cinn::common::CINNValue> IRGpuScheduleMatMul(
218218
vec_ast.emplace_back(temp);
219219
}
220220
}
221-
CHECK(!vec_ast.empty());
221+
PADDLE_ENFORCE_EQ(vec_ast.empty(),
222+
false,
223+
phi::errors::InvalidArgument(
224+
"The vector 'vec_ast' should not be empty."));
222225
ir::ModuleExpr mod_expr(vec_ast);
223226
ir::IRSchedule ir_sch(mod_expr);
224227
ir_sch.MergeExprs();
@@ -311,9 +314,15 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, // NOLINT
311314

312315
// collect block names
313316
auto get_block_name = [](ir::Expr expr) {
314-
CHECK(expr.As<ir::ScheduleBlockRealize>());
315-
CHECK(expr.As<ir::ScheduleBlockRealize>()
316-
->schedule_block.As<ir::ScheduleBlock>());
317+
PADDLE_ENFORCE_NOT_NULL(
318+
expr.As<ir::ScheduleBlockRealize>(),
319+
phi::errors::InvalidArgument(
320+
"The expression must be convertible to ir::ScheduleBlockRealize."));
321+
PADDLE_ENFORCE_NOT_NULL(expr.As<ir::ScheduleBlockRealize>()
322+
->schedule_block.As<ir::ScheduleBlock>(),
323+
phi::errors::InvalidArgument(
324+
"Failed to convert ir::ScheduleBlockRealize to "
325+
"ir::ScheduleBlock."));
317326
return expr.As<ir::ScheduleBlockRealize>()
318327
->schedule_block.As<ir::ScheduleBlock>()
319328
->name;
@@ -488,9 +497,16 @@ void IRGpuScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, // NOLINT
488497
auto out_block = ir_sch.GetBlock(out->name);
489498
auto root_block = ir_sch.GetRootBlock(out_block);
490499

491-
CHECK(out_block->as<ir::ScheduleBlockRealize>());
492-
CHECK(out_block->as<ir::ScheduleBlockRealize>()
493-
->schedule_block->as<ir::ScheduleBlock>());
500+
PADDLE_ENFORCE_NOT_NULL(
501+
out_block->as<ir::ScheduleBlockRealize>(),
502+
phi::errors::InvalidArgument(
503+
"The out_block must be convertible to ir::ScheduleBlockRealize."));
504+
PADDLE_ENFORCE_NOT_NULL(
505+
out_block->as<ir::ScheduleBlockRealize>()
506+
->schedule_block->as<ir::ScheduleBlock>(),
507+
phi::errors::InvalidArgument(
508+
"The schedule_block within ir::ScheduleBlockRealize must be "
509+
"convertible to ir::ScheduleBlock."));
494510

495511
// create var
496512
auto var = ir::Var(ir::Expr(0), ir::Expr(1), cinn::common::UniqName("i"));
@@ -499,9 +515,16 @@ void IRGpuScheduleBlockReduceInternal(ir::IRSchedule &ir_sch, // NOLINT
499515
->schedule_block->as<ir::ScheduleBlock>()
500516
->iter_vars.push_back(var);
501517

502-
CHECK(root_block->as<ir::ScheduleBlockRealize>());
503-
CHECK(root_block->as<ir::ScheduleBlockRealize>()
504-
->schedule_block->as<ir::ScheduleBlock>());
518+
PADDLE_ENFORCE_NOT_NULL(
519+
root_block->as<ir::ScheduleBlockRealize>(),
520+
phi::errors::InvalidArgument(
521+
"The root_block must be convertible to ir::ScheduleBlockRealize."));
522+
PADDLE_ENFORCE_NOT_NULL(
523+
root_block->as<ir::ScheduleBlockRealize>()
524+
->schedule_block->as<ir::ScheduleBlock>(),
525+
phi::errors::InvalidArgument(
526+
"The schedule_block within ir::ScheduleBlockRealize must be "
527+
"convertible to ir::ScheduleBlock."));
505528

506529
// create for and block node
507530
auto for_node = ir::For::Make(var,
@@ -572,14 +595,20 @@ void IRGpuScheduleBlockReduce(ir::IRSchedule &ir_sch, // NOLINT
572595
<< ir_sch.GetModule().GetExprs().at(0);
573596
int tmp_put_shape_size_without_reduce = 0;
574597
for (auto i : tmp_out->shape) {
575-
CHECK(i.is_constant());
598+
PADDLE_ENFORCE_EQ(i.is_constant(),
599+
true,
600+
phi::errors::InvalidArgument(
601+
"The value must be a constant but it is not."));
576602
if (i.as_int32() != 1) tmp_put_shape_size_without_reduce++;
577603
}
578604
tmp_put_shape_size_without_reduce--;
579605
// fuse last parallel dimension
580606
int reduce_temp_out_shape_size = 0;
581607
for (auto i : reduce_tmp_out->shape) {
582-
CHECK(i.is_constant());
608+
PADDLE_ENFORCE_EQ(i.is_constant(),
609+
true,
610+
phi::errors::InvalidArgument(
611+
"The value must be a constant but it is not."));
583612
if (i.as_int32() != 1) reduce_temp_out_shape_size++;
584613
}
585614

@@ -623,9 +652,16 @@ void IRGpuScheduleBlockReduce(ir::IRSchedule &ir_sch, // NOLINT
623652
auto out_block = ir_sch.GetBlock(out->name);
624653
auto root_block = ir_sch.GetRootBlock(out_block);
625654

626-
CHECK(out_block->as<ir::ScheduleBlockRealize>());
627-
CHECK(out_block->as<ir::ScheduleBlockRealize>()
628-
->schedule_block->as<ir::ScheduleBlock>());
655+
PADDLE_ENFORCE_NOT_NULL(
656+
out_block->as<ir::ScheduleBlockRealize>(),
657+
phi::errors::InvalidArgument(
658+
"The out_block must be convertible to ir::ScheduleBlockRealize."));
659+
PADDLE_ENFORCE_NOT_NULL(
660+
out_block->as<ir::ScheduleBlockRealize>()
661+
->schedule_block->as<ir::ScheduleBlock>(),
662+
phi::errors::InvalidArgument(
663+
"The schedule_block within ir::ScheduleBlockRealize must be "
664+
"convertible to ir::ScheduleBlock."));
629665

630666
// create var
631667
auto var = ir::Var(ir::Expr(0), ir::Expr(1), cinn::UniqName("i"));
@@ -634,9 +670,16 @@ void IRGpuScheduleBlockReduce(ir::IRSchedule &ir_sch, // NOLINT
634670
->schedule_block->as<ir::ScheduleBlock>()
635671
->iter_vars.push_back(var);
636672

637-
CHECK(root_block->as<ir::ScheduleBlockRealize>());
638-
CHECK(root_block->as<ir::ScheduleBlockRealize>()
639-
->schedule_block->as<ir::ScheduleBlock>());
673+
PADDLE_ENFORCE_NOT_NULL(
674+
root_block->as<ir::ScheduleBlockRealize>(),
675+
phi::errors::InvalidArgument(
676+
"The root_block must be convertible to ir::ScheduleBlockRealize."));
677+
PADDLE_ENFORCE_NOT_NULL(
678+
root_block->as<ir::ScheduleBlockRealize>()
679+
->schedule_block->as<ir::ScheduleBlock>(),
680+
phi::errors::InvalidArgument(
681+
"The schedule_block within ir::ScheduleBlockRealize must be "
682+
"convertible to ir::ScheduleBlock."));
640683

641684
// create for and block node
642685
auto for_node = ir::For::Make(var,
@@ -1010,9 +1053,16 @@ void IRGpuTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // NOLINT
10101053
auto out_block = ir_sch.GetBlock(out->name);
10111054
auto root_block = ir_sch.GetRootBlock(out_block);
10121055

1013-
CHECK(out_block->as<ir::ScheduleBlockRealize>());
1014-
CHECK(out_block->as<ir::ScheduleBlockRealize>()
1015-
->schedule_block->as<ir::ScheduleBlock>());
1056+
PADDLE_ENFORCE_NOT_NULL(
1057+
out_block->as<ir::ScheduleBlockRealize>(),
1058+
phi::errors::InvalidArgument(
1059+
"The out_block must be convertible to ir::ScheduleBlockRealize."));
1060+
PADDLE_ENFORCE_NOT_NULL(
1061+
out_block->as<ir::ScheduleBlockRealize>()
1062+
->schedule_block->as<ir::ScheduleBlock>(),
1063+
phi::errors::InvalidArgument(
1064+
"The schedule_block within ir::ScheduleBlockRealize must be "
1065+
"convertible to ir::ScheduleBlock."));
10161066

10171067
// create var
10181068
// auto var = ir::Var(ir::Expr(0), ir::Expr(1), "i_0");
@@ -1022,9 +1072,16 @@ void IRGpuTwoStepReduceSchedule(ir::IRSchedule &ir_sch, // NOLINT
10221072
->schedule_block->as<ir::ScheduleBlock>()
10231073
->iter_vars.push_back(var);
10241074

1025-
CHECK(root_block->as<ir::ScheduleBlockRealize>());
1026-
CHECK(root_block->as<ir::ScheduleBlockRealize>()
1027-
->schedule_block->as<ir::ScheduleBlock>());
1075+
PADDLE_ENFORCE_NOT_NULL(
1076+
root_block->as<ir::ScheduleBlockRealize>(),
1077+
phi::errors::InvalidArgument(
1078+
"The root_block must be convertible to ir::ScheduleBlockRealize."));
1079+
PADDLE_ENFORCE_NOT_NULL(
1080+
root_block->as<ir::ScheduleBlockRealize>()
1081+
->schedule_block->as<ir::ScheduleBlock>(),
1082+
phi::errors::InvalidArgument(
1083+
"The schedule_block within ir::ScheduleBlockRealize must be "
1084+
"convertible to ir::ScheduleBlock."));
10281085

10291086
// create for and block node
10301087
auto for_node = ir::For::Make(var,
@@ -1205,7 +1262,9 @@ void IRGlobalPoolScheduleGPU(ir::IRSchedule &ir_sch, // NOLINT
12051262
void IRCudaScheduleDepthwiseConv(ir::IRSchedule &ir_sch, // NOLINT
12061263
const std::vector<ir::Expr> &tensors) {
12071264
if (tensors.size() == 3U) {
1208-
CHECK(tensors[1].as_tensor());
1265+
PADDLE_ENFORCE_NOT_NULL(tensors[1].as_tensor(),
1266+
phi::errors::InvalidArgument(
1267+
"The tensor at index 1 must not be null."));
12091268
auto input_pad = ir_sch.GetBlock(tensors[1].as_tensor_ref()->name);
12101269
ir_sch.ComputeInline(input_pad);
12111270
}

paddle/cinn/hlir/pe/reduction.cc

Lines changed: 0 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1205,40 +1205,6 @@ std::string DiscreteReduceExternalFuncName(const ir::Expr& op,
12051205
return "";
12061206
}
12071207

1208-
std::string IntervalReduceExternalFuncName(const ir::Expr& op,
1209-
const ir::Expr& tensor) {
1210-
CHECK_NOTNULL(tensor.as_tensor());
1211-
if (op.As<ir::Add>()) {
1212-
if (tensor.as_tensor()->type().is_bool()) {
1213-
return "cinn_interval_reduce_any_internal_shm";
1214-
}
1215-
return "cinn_interval_reduce_sum" +
1216-
Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm";
1217-
} else if (op.As<ir::Mul>()) {
1218-
if (tensor.as_tensor()->type().is_bool()) {
1219-
return "cinn_interval_reduce_all_internal_shm";
1220-
}
1221-
return "cinn_interval_reduce_prod" +
1222-
Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm";
1223-
} else if (op.As<ir::Max>()) {
1224-
return "cinn_interval_reduce_max" +
1225-
Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm";
1226-
} else if (op.As<ir::Min>()) {
1227-
return "cinn_interval_reduce_min" +
1228-
Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm";
1229-
} else if (op.As<ir::And>()) {
1230-
return "cinn_interval_reduce_all_internal_shm";
1231-
} else if (op.As<ir::Or>()) {
1232-
return "cinn_interval_reduce_any_internal_shm";
1233-
} else {
1234-
std::stringstream ss;
1235-
ss << op;
1236-
PADDLE_THROW(::common::errors::InvalidArgument(
1237-
"Reduce type %s is not supported yet!", ss.str()));
1238-
}
1239-
return "";
1240-
}
1241-
12421208
} // namespace pe
12431209
} // namespace hlir
12441210
} // namespace cinn

paddle/cinn/hlir/pe/reduction.h

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -474,9 +474,6 @@ std::string CrossThreadReduceExternalFuncName(const ir::Expr& op,
474474
std::string DiscreteReduceExternalFuncName(const ir::Expr& op,
475475
const ir::Expr& tensor);
476476

477-
std::string IntervalReduceExternalFuncName(const ir::Expr& op,
478-
const ir::Expr& tensor);
479-
480477
std::string Type2StrForReduce(cinn::common::Type type);
481478
} // namespace pe
482479
} // namespace hlir

paddle/cinn/ir/group_schedule/config/group_tile_config.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -96,21 +96,21 @@ std::shared_ptr<ScheduleConfig::BaseInfo> InitBasicInfo(
9696
base_info->reduce_tensor_names = group_info->reduce_var_names;
9797
base_info->shared_var_names = group_info->shared_var_names;
9898
base_info->direct_output_var_names = group_info->direct_output_var_names;
99-
base_info->data_space = group_info->data_space;
99+
base_info->data_rank = group_info->data_space.size();
100100
base_info->loop_strides = group_info->loop_strides;
101101

102102
std::set<int64_t> reduce_dim_loc;
103103
for (int64_t dim : group_info->reduce_axis) {
104104
if (dim < 0) {
105-
dim += base_info->data_space.size();
105+
dim += base_info->data_rank;
106106
}
107107
base_info->reduce_axis.push_back(dim);
108108
reduce_dim_loc.insert(dim);
109109
}
110110

111111
base_info->spatial_numel = 1;
112112
base_info->reduce_numel = 1;
113-
for (int64_t i = 0; i < base_info->data_space.size(); ++i) {
113+
for (int64_t i = 0; i < base_info->data_rank; ++i) {
114114
if (reduce_dim_loc.count(i)) {
115115
if (group_info->data_space[i] == -1) base_info->has_dynamic_reduce = true;
116116
base_info->reduce_numel *= group_info->data_space[i];
@@ -121,7 +121,7 @@ std::shared_ptr<ScheduleConfig::BaseInfo> InitBasicInfo(
121121
}
122122
}
123123
base_info->is_reduce_all =
124-
(base_info->reduce_axis.size() == base_info->data_space.size());
124+
(base_info->reduce_axis.size() == base_info->data_rank);
125125

126126
for (int64_t i = 0; i < group_info->data_space.size(); ++i) {
127127
if (group_info->data_space[i] == 1) continue;

paddle/cinn/ir/group_schedule/config/group_tile_config.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ struct ScheduleConfig {
3232
struct BaseInfo {
3333
std::vector<int64_t> reduce_axis;
3434
std::vector<int64_t> loop_strides;
35-
std::vector<int64_t> data_space;
35+
int64_t data_rank;
3636
int64_t reduce_numel;
3737
int64_t spatial_numel;
3838
bool has_dynamic_spatial{false};

paddle/cinn/ir/group_schedule/search/config_searcher.cc

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -198,10 +198,10 @@ bool CandidateGenerator::IsValid(const CandidateType& candidate) const {
198198
}
199199

200200
ScheduleConfigSearcher::ScheduleConfigSearcher(
201-
std::unique_ptr<BaseObjectiveFunc> objective_func,
201+
std::vector<std::unique_ptr<BaseObjectiveFunc>> objective_funcs,
202202
const std::vector<std::pair<int, int>>& candidate_range,
203203
const std::vector<ConstraintFunc>& contraints)
204-
: objective_func_(std::move(objective_func)),
204+
: objective_funcs_(std::move(objective_funcs)),
205205
candidate_range_(candidate_range),
206206
contraints_(contraints) {}
207207

@@ -212,7 +212,10 @@ std::pair<ScoreType, CandidateType> ScheduleConfigSearcher::Search(
212212
std::vector<CandidateType> candidates = candidate_generator.Candidates();
213213
VLOG(6) << "Candidate num = " << candidates.size();
214214
for (const auto& candidate : candidates) {
215-
ScoreType score = (*objective_func_)(candidate);
215+
ScoreType score = 0;
216+
for (auto& objective_func_ : objective_funcs_) {
217+
score += (*objective_func_)(candidate);
218+
}
216219
VLOG(6) << "Candidate: [" << utils::Join<int64_t>(candidate, ", ") << "]";
217220
VLOG(6) << "Score = " << score;
218221
records_[score] = candidate;

paddle/cinn/ir/group_schedule/search/config_searcher.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -83,14 +83,14 @@ class CandidateGenerator {
8383
class ScheduleConfigSearcher {
8484
public:
8585
ScheduleConfigSearcher(
86-
std::unique_ptr<BaseObjectiveFunc> objective_func,
86+
std::vector<std::unique_ptr<BaseObjectiveFunc>> objective_funcs,
8787
const std::vector<std::pair<int, int>>& candidate_range,
8888
const std::vector<ConstraintFunc>& contraints = {});
8989

9090
std::pair<ScoreType, CandidateType> Search(bool is_search_minimun = true);
9191

9292
private:
93-
std::unique_ptr<BaseObjectiveFunc> objective_func_;
93+
std::vector<std::unique_ptr<BaseObjectiveFunc>> objective_funcs_;
9494
std::vector<ConstraintFunc> contraints_;
9595
std::vector<std::pair<int, int>> candidate_range_;
9696

0 commit comments

Comments
 (0)