Skip to content

Commit 0708454

Browse files
authored
Merge branch 'develop' into develop
2 parents 42c442f + 4a071e2 commit 0708454

File tree

761 files changed

+20999
-9203
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

761 files changed

+20999
-9203
lines changed

cmake/external/xpu.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ endif()
3232
if(NOT DEFINED XPU_XHPC_BASE_DATE)
3333
set(XPU_XHPC_BASE_DATE "eb35/20241015")
3434
endif()
35-
set(XPU_XCCL_BASE_VERSION "1.2.11d")
35+
set(XPU_XCCL_BASE_VERSION "1.2.11e")
3636
if(NOT DEFINED XPU_XFT_BASE_VERSION)
3737
set(XPU_XFT_BASE_VERSION "20230602")
3838
endif()

cmake/inference_lib.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ endif()
304304

305305
copy(
306306
inference_lib_dist
307-
SRCS ${CMAKE_BINARY_DIR}/paddle/fluid/framework/framework.pb.h
307+
SRCS ${CMAKE_BINARY_DIR}/paddle/phi/core/framework/framework.pb.h
308308
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/internal)
309309
copy(
310310
inference_lib_dist

paddle/cinn/adt/equation_solver.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ std::unordered_map<Variable, Value> InferValuesImpl(
3737
PADDLE_ENFORCE_EQ(
3838
ctx->HasValue(in_variable),
3939
true,
40-
phi::errors::NotFound("The param id's out_iter must contain "
41-
"its in_iter's value"));
40+
::common::errors::NotFound("The param id's out_iter must contain "
41+
"its in_iter's value"));
4242
return {{out_iter.value(), ctx->GetValue(in_variable)}};
4343
}
4444

@@ -49,8 +49,8 @@ std::unordered_map<Variable, Value> InferValuesImpl(
4949
PADDLE_ENFORCE_EQ(
5050
ctx->HasValue(in_variable),
5151
true,
52-
phi::errors::NotFound("The param id's out_iter must contain "
53-
"its in_iter's value"));
52+
::common::errors::NotFound("The param id's out_iter must contain "
53+
"its in_iter's value"));
5454
return {{out_index.value(), ctx->GetValue(in_variable)}};
5555
}
5656

@@ -215,7 +215,7 @@ std::unordered_map<Variable, Value> InferValuesImpl(
215215
PADDLE_ENFORCE_EQ(
216216
ret.emplace(out_msg_in_indexes.value()->at(i), value).second,
217217
true,
218-
phi::errors::AlreadyExists([&]() {
218+
::common::errors::AlreadyExists([&]() {
219219
std::ostringstream oss;
220220
oss << "Failed to insert the variable '"
221221
<< "out_msg_in_indexes.value()->at(" << i
@@ -229,7 +229,7 @@ std::unordered_map<Variable, Value> InferValuesImpl(
229229
if (out_index.has_value()) {
230230
PADDLE_ENFORCE_EQ(ret.emplace(out_index.value(), value).second,
231231
true,
232-
phi::errors::AlreadyExists([&]() {
232+
::common::errors::AlreadyExists([&]() {
233233
std::ostringstream oss;
234234
oss << "Failed to insert the variable '"
235235
<< "out_index.value()"
@@ -306,7 +306,9 @@ void SolveEquations(
306306
tValueInferSuccess<bool> has_unique_value =
307307
MergeInferedValuesIntoCtx(function, ctx);
308308
PADDLE_ENFORCE_EQ(
309-
has_unique_value.value(), true, phi::errors::InvalidArgument([&]() {
309+
has_unique_value.value(),
310+
true,
311+
::common::errors::InvalidArgument([&]() {
310312
std::ostringstream oss;
311313
oss << "Failed to merge inferred values into the context for "
312314
"function '"

paddle/cinn/backends/codegen_device_util.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,24 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
262262
ir::CallType::Extern,
263263
ir::FunctionRef(),
264264
0);
265+
266+
// create memset calls for temp_spaces if needed
267+
std::vector<ir::Expr> call_kernel_stmts;
268+
for (auto &temp_space : func_node->temp_spaces) {
269+
if (temp_space.need_zero_init()) {
270+
ir::Expr size = common::cast(temp_space.size(), common::UInt(64));
271+
ir::Expr call_get_arg =
272+
lang::CallExtern(runtime::intrinsic::get_item_in_cuda_kernel_args,
273+
{kernel_args_, ir::Expr(temp_space.arg_idx())});
274+
ir::Expr call_memset = lang::CallExtern(
275+
runtime::intrinsic::call_cuda_memset,
276+
{call_get_arg, ir::Expr(1), ir::Expr(0), size, kernel_stream_});
277+
call_kernel_stmts.push_back(call_memset);
278+
}
279+
}
280+
call_kernel_stmts.push_back(call_extern_api);
281+
call_extern_api = ir::Block::Make(call_kernel_stmts);
282+
265283
if (buckets_.empty()) {
266284
buckets_.emplace_back(ir::IfThenElse::Make(predicate, call_extern_api));
267285
} else {
@@ -270,6 +288,26 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
270288
buckets_.emplace_back(
271289
ir::IfThenElse::Make(predicate, call_extern_api, false_expr));
272290
}
291+
292+
// create infer shape calls for temp_spaces
293+
std::vector<ir::Expr> temp_space_infer_shape_stmts;
294+
for (int i = 0; i < func_node->temp_spaces.size(); ++i) {
295+
ir::Var tensor_shape_args(TENSOR_SHAPE_ARGS, type_of<int64_t **>());
296+
ir::Expr size =
297+
common::cast(func_node->temp_spaces[i].size(), common::Int(64));
298+
ir::Expr call_set_value =
299+
lang::CallExtern(runtime::intrinsic::infer_shape_set_value,
300+
{ir::Expr(func_node->num_output_tensors + i),
301+
ir::Expr(0),
302+
size,
303+
tensor_shape_args});
304+
temp_space_infer_shape_stmts.push_back(call_set_value);
305+
}
306+
if (!temp_space_infer_shape_stmts.empty()) {
307+
ir::Expr if_body = ir::Block::Make(temp_space_infer_shape_stmts);
308+
temp_space_infer_shape_body_ =
309+
ir::IfThenElse::Make(predicate, if_body, temp_space_infer_shape_body_);
310+
}
273311
}
274312

275313
void detail::CollectBucketStrategyHostFunctionVisitor::ProcessArgs(

paddle/cinn/backends/codegen_device_util.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,9 @@ struct CollectBucketStrategyHostFunctionVisitor
280280
infer_shape_func_body_stmts.insert(
281281
infer_shape_func_body_stmts.end(),
282282
op->infer_shape_func.as_lowered_func()->body);
283+
if (temp_space_infer_shape_body_.defined()) {
284+
infer_shape_func_body_stmts.push_back(temp_space_infer_shape_body_);
285+
}
283286

284287
std::vector<ir::Argument> infer_shape_arguments = {
285288
ir::Argument(kernel_args_, ir::Argument::IO::kOutput),
@@ -307,6 +310,7 @@ struct CollectBucketStrategyHostFunctionVisitor
307310
private:
308311
std::vector<ir::Expr> buckets_;
309312
std::vector<ir::Expr> arg_defs_;
313+
ir::Expr temp_space_infer_shape_body_;
310314

311315
ir::Var kernel_args_;
312316
ir::Var kernel_args_num_;

paddle/cinn/backends/codegen_invoke_module.cc

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -62,38 +62,6 @@ llvm::Value* CodeGenInvokeModule::LowerInvokeFunc(
6262
return f_;
6363
}
6464

65-
llvm::Value* CodeGenInvokeModule::LowerParseArgsValueCall(
66-
const ir::Call* call_ir) {
67-
auto ret_type = CinnTypeToLLVMType(Int(64), m_);
68-
std::vector<llvm::Type*> args_type;
69-
PADDLE_ENFORCE_EQ(
70-
call_ir->read_args.size(),
71-
2,
72-
::common::errors::InvalidArgument(
73-
"The number of arguments of ParseArgsValue should be 2"));
74-
PADDLE_ENFORCE_EQ(call_ir->read_args[0].is_var() &&
75-
call_ir->read_args[0].as_var()->type().is_cpp_handle(),
76-
true,
77-
::common::errors::InvalidArgument(
78-
"The first read argument must be a variable "
79-
"with a C++ handle type."));
80-
81-
PADDLE_ENFORCE_EQ(call_ir->read_args[1].type().is_int(32),
82-
true,
83-
::common::errors::InvalidArgument(
84-
"The second read argument must be of type int32."));
85-
args_type.push_back(CinnTypeToLLVMType(type_of<void*>(), m_));
86-
args_type.push_back(CinnTypeToLLVMType(type_of<int32_t>(), m_));
87-
88-
auto func_type = llvm::FunctionType::get(ret_type, args_type, false);
89-
auto call_func = m_->getOrInsertFunction(call_ir->name, func_type);
90-
91-
std::vector<llvm::Value*> call_args;
92-
call_args.push_back(std::addressof(*f_->arg_begin()));
93-
call_args.push_back(b_->getInt32(call_ir->read_args[1].as_int32()));
94-
return b_->CreateCall(call_func, call_args);
95-
}
96-
9765
llvm::Value* CodeGenSwitchHost::LowerInnerCaseCall(const ir::Call* op) {
9866
std::vector<llvm::Value*> ll_function_args;
9967
std::transform(f_->arg_begin(),

paddle/cinn/backends/codegen_invoke_module.h

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,8 @@ class CodeGenInvokeModule : public CodeGenLLVM {
4343
return LowerInvokeFunc(func);
4444
}
4545

46-
llvm::Value *Visit(const ir::Call *op) override {
47-
// TODO(Hongqing-work): change intrinsic name to get_value_in_kernel_args
48-
if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
49-
return LowerParseArgsValueCall(op);
50-
} else {
51-
return CodeGenLLVM::Visit(op);
52-
}
53-
}
54-
5546
protected:
5647
llvm::Value *LowerInvokeFunc(const ir::_LoweredFunc_ *func);
57-
58-
llvm::Value *LowerParseArgsValueCall(const ir::Call *call_ir);
5948
};
6049

6150
class CodeGenHost : public CodeGenInvokeModule {
@@ -80,7 +69,7 @@ class CodeGenSwitchHost : public CodeGenInvokeModule {
8069
// only support call of args get function and inner case host function call
8170
llvm::Value *Visit(const ir::Call *op) override {
8271
if (op->name == runtime::intrinsic::get_value_in_cuda_kernel_args) {
83-
return CodeGenInvokeModule::LowerParseArgsValueCall(op);
72+
return CodeGenLLVM::Visit(op);
8473
} else {
8574
return LowerInnerCaseCall(op);
8675
}

paddle/cinn/backends/llvm/codegen_llvm.cc

Lines changed: 6 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -511,53 +511,6 @@ llvm::Value *CodeGenLLVM::Visit(const ir::Cast *op) {
511511
llvm::Value *CodeGenLLVM::CreateSerialFor(const ir::For *op, int stride) {
512512
SymbolTableGuard symbol_table_guard(*symbol_table_);
513513

514-
do {
515-
break;
516-
llvm::BasicBlock *preheader_bb = b_->GetInsertBlock();
517-
auto *for_begin = llvm::BasicBlock::Create(
518-
b_->getContext(), "for_begin", b_->GetInsertBlock()->getParent());
519-
auto *for_body = llvm::BasicBlock::Create(
520-
b_->getContext(), "for_body", b_->GetInsertBlock()->getParent());
521-
auto *for_end = llvm::BasicBlock::Create(
522-
b_->getContext(), "for_end", b_->GetInsertBlock()->getParent());
523-
524-
Br(for_begin);
525-
b_->SetInsertPoint(for_begin);
526-
527-
auto *begin = Visit(&op->min);
528-
auto *loop_value = PHI(begin->getType(), 2);
529-
loop_value->addIncoming(begin, preheader_bb);
530-
531-
llvm::Value *old_var = GetVar(op->loop_var->name);
532-
SetVar(op->loop_var->name, loop_value);
533-
auto *end = Visit(&op->extent);
534-
CondBr(ICmpSLT(loop_value, end), for_body, for_end);
535-
b_->SetInsertPoint(for_body);
536-
Visit(&op->body);
537-
538-
if (old_var) {
539-
SetVar(op->loop_var->name, old_var);
540-
} else {
541-
symbol_table_->Erase(op->loop_var->name);
542-
}
543-
544-
auto loop_next = Add(loop_value,
545-
llvm::ConstantInt::get(b_->getInt32Ty(), stride),
546-
"indvar.inc",
547-
true,
548-
true);
549-
loop_value->addIncoming(loop_next, b_->GetInsertBlock());
550-
551-
Br(for_begin);
552-
b_->SetInsertPoint(for_end);
553-
554-
return nullptr;
555-
// llvm::AllocaInst *loop_var = Alloca(b_->getInt32Ty(), nullptr,
556-
// op->loop_var->name); loop_var->setAlignment(llvm::Align(4));
557-
// SetVar(op->loop_var->name, loop_var);
558-
} while (false);
559-
560-
////////////////////////////////////
561514
llvm::BasicBlock *preheader_bb = b_->GetInsertBlock();
562515
llvm::BasicBlock *exit_bb = nullptr;
563516

@@ -814,20 +767,13 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Module_ *op) {
814767
}
815768

816769
llvm::Value *CodeGenLLVM::Visit(const ir::_Var_ *op) {
817-
llvm::Value *value = GetVar(op->name, false);
818-
llvm::Value *result{};
819-
CHECK(value) << "ir::_Var_[" << op->name << "]: value is null";
820-
// TODO(fc500110) hard coding
821-
if (LLVM_WillVarLowerAsPointer(op->name)) {
822-
result = value;
823-
} else if (value->getType()->isPointerTy() &&
824-
!value->getType()->getPointerElementType()->isPointerTy()) {
825-
result = Load(value, op->name + "_load");
826-
} else {
827-
result = value;
770+
llvm::Value *value = GetVar(op->name, /* lazy= */ false);
771+
// When visiting a Var that is allocated on the stack, we are actually
772+
// reading its value instead of its address.
773+
if (llvm::AllocaInst::classof(value)) {
774+
return Load(value, op->name + "_load");
828775
}
829-
830-
return result;
776+
return value;
831777
}
832778

833779
void CodeGenLLVM::Scalarize(
@@ -1043,12 +989,6 @@ llvm::Value *CodeGenLLVM::Visit(const ir::_Buffer_ *op) {
1043989

1044990
llvm::Value *CodeGenLLVM::Visit(const ir::_Tensor_ *op) {
1045991
return GetVar(op->name);
1046-
auto *buffer_op = op->buffer.As<ir::_Buffer_>();
1047-
if (symbol_table_->Lookup(buffer_op->name)) {
1048-
return Visit(buffer_op);
1049-
}
1050-
1051-
return SetVar(buffer_op->name, Visit(buffer_op));
1052992
}
1053993

1054994
template <typename T,
@@ -1437,10 +1377,6 @@ void CodeGenLLVM::InitTarget(const Target &target) {
14371377
naive_vec_alignment_ = GetNaiveVecAlignment(target);
14381378
}
14391379

1440-
bool LLVM_WillVarLowerAsPointer(const std::string &var_name) {
1441-
return var_name == "_args" || utils::EndsWith(var_name, "__ptr");
1442-
}
1443-
14441380
void CodeGenLLVM::AddTbaaMetadata(llvm::Instruction *inst,
14451381
absl::string_view buffer,
14461382
Expr index) {

paddle/cinn/backends/llvm/codegen_llvm.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,6 @@ class LLVMIRVisitor : public ir::IRVisitorRequireReImpl<llvm::Value *> {
4949
#undef __m
5050
};
5151

52-
/**
53-
* Tell whether a variable called \p \var_name will lowered to a pointer type in
54-
* LLVM.
55-
* @param var_name name of the variable.
56-
* @return a boolean.
57-
*/
58-
bool LLVM_WillVarLowerAsPointer(const std::string &var_name);
59-
6052
class SymbolTable {
6153
public:
6254
SymbolTable() = default;

paddle/cinn/common/const_fold.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,5 +71,41 @@ inline std::optional<ir::Expr> TryConstFold<ir::Mul>(ir::Expr a, ir::Expr b) {
7171
return std::nullopt;
7272
}
7373

74+
template <>
75+
inline std::optional<ir::Expr> TryConstFold<ir::Div>(ir::Expr a, ir::Expr b) {
76+
const ir::IntImm* pa = a.As<ir::IntImm>();
77+
const ir::IntImm* pb = b.As<ir::IntImm>();
78+
const auto& rtype = a.type();
79+
if (pa && pb) {
80+
int64_t res = pa->value / pb->value;
81+
return cinn::common::make_shared<ir::IntImm>(rtype, res);
82+
}
83+
if (pa) {
84+
if (pa->value == 0) return a;
85+
}
86+
if (pb) {
87+
if (pb->value == 1) return a;
88+
}
89+
return std::nullopt;
90+
}
91+
92+
template <>
93+
inline std::optional<ir::Expr> TryConstFold<ir::Mod>(ir::Expr a, ir::Expr b) {
94+
const ir::IntImm* pa = a.As<ir::IntImm>();
95+
const ir::IntImm* pb = b.As<ir::IntImm>();
96+
const auto& rtype = a.type();
97+
if (pa && pb) {
98+
int64_t res = pa->value % pb->value;
99+
return cinn::common::make_shared<ir::IntImm>(rtype, res);
100+
}
101+
if (pa) {
102+
if (pa->value == 0) return a;
103+
}
104+
if (pb) {
105+
if (pb->value == 1) return ir::Zero(rtype);
106+
}
107+
return std::nullopt;
108+
}
109+
74110
} // namespace common
75111
} // namespace cinn

0 commit comments

Comments
 (0)