Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions paddle/cinn/adt/simplify_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ struct SimplifyDotUndot {
pre_index_undot = index_undot_value;
}
}
CHECK(pre_index_undot.has_value());
PADDLE_ENFORCE_EQ(
pre_index_undot.has_value(),
true,
phi::errors::InvalidArgument("pre_index_undot should not be null"));
const auto& [index_value, undot_dims] =
pre_index_undot.value()
.Get<IndexUnDotValue<Value, List<DimExpr>>>()
Expand Down Expand Up @@ -195,9 +198,14 @@ struct SimplifyGcdShape {
const auto& iter_values = index_dot_values.Get<List<Value>>();
const auto& undot_dim_values = undot_dims;
const auto& dot_dim_values = dot_dims;
CHECK(IsConstantListAllPositiveInt64(undot_dim_values));
CHECK(IsConstantListAllPositiveInt64(dot_dim_values));

PADDLE_ENFORCE_EQ(IsConstantListAllPositiveInt64(undot_dim_values),
true,
phi::errors::InvalidArgument(
"The undot_dim_values should be all positive int64"));
PADDLE_ENFORCE_EQ(IsConstantListAllPositiveInt64(dot_dim_values),
true,
phi::errors::InvalidArgument(
"The dot_dim_values should be all positive int64"));
const auto& sub_reshape_dim_ranges =
GetSubReshapeDimRanges(undot_dim_values, dot_dim_values);
if (!sub_reshape_dim_ranges.has_value()) {
Expand Down Expand Up @@ -321,7 +329,10 @@ struct SimplifyDotDot {
std::int64_t Product(const List<DimExpr>& dims) {
std::int64_t ret = 1;
for (const auto& dim : *dims) {
CHECK(dim.Has<std::int64_t>());
PADDLE_ENFORCE_EQ(
dim.Has<std::int64_t>(),
true,
phi::errors::InvalidArgument("dim should have std::int64_t"));
ret *= dim.Get<std::int64_t>();
}
return ret;
Expand Down Expand Up @@ -400,7 +411,10 @@ struct SymbolicDim_SimplifyDotUndot {
pre_index_undot = index_undot_value;
}
}
CHECK(pre_index_undot.has_value());
PADDLE_ENFORCE_EQ(
pre_index_undot.has_value(),
true,
phi::errors::InvalidArgument("pre_index_undot should not be null"));
const auto& [index_value, undot_dims] =
pre_index_undot.value()
.Get<IndexUnDotValue<Value, List<DimExpr>>>()
Expand Down Expand Up @@ -447,7 +461,10 @@ struct SymbolicDim_SimplifyDotUndot_DimExpr {
pre_index_undot = index_undot_value;
}
}
CHECK(pre_index_undot.has_value());
PADDLE_ENFORCE_EQ(
pre_index_undot.has_value(),
true,
phi::errors::InvalidArgument("pre_index_undot should not be null"));
const auto& [index_value, undot_dims] =
pre_index_undot.value()
.Get<IndexUnDotValue<Value, List<DimExpr>>>()
Expand Down
20 changes: 13 additions & 7 deletions paddle/cinn/backends/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,10 @@ void Compiler::RegisterCudaModuleSymbol() {
std::string source_code =
CodeGenCUDA_Dev::GetSourceHeader() + device_fn_code_;
auto ptx = compiler(source_code);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n"
<< source_code;
PADDLE_ENFORCE_EQ(
!ptx.empty(),
true,
phi::errors::InvalidArgument("Compile PTX failed from source code\n"));
using runtime::cuda::CUDAModule;
cuda_module_.reset(new CUDAModule(ptx,
compiler.compile_to_cubin()
Expand All @@ -328,7 +330,9 @@ void Compiler::RegisterCudaModuleSymbol() {
RuntimeSymbols symbols;
for (const auto& kernel_fn_name : device_fn_name_) {
auto fn_kernel = cuda_module_->GetFunction(kernel_fn_name);
CHECK(fn_kernel) << "Fail to get CUfunction kernel_fn_name";
PADDLE_ENFORCE_NOT_NULL(
fn_kernel,
phi::errors::InvalidArgument("Fail to get CUfunction kernel_fn_name"));
fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel));
symbols.RegisterVar(kernel_fn_name + "_ptr_",
reinterpret_cast<void*>(fn_kernel));
Expand Down Expand Up @@ -361,9 +365,10 @@ void Compiler::CompileCudaModule(const Module& module,
source_code = code;
}

CHECK(!source_code.empty())
<< "Compile CUDA C code failed from device module:\n"
<< device_module;
PADDLE_ENFORCE_EQ(!source_code.empty(),
true,
phi::errors::InvalidArgument(
"Compile CUDA C code failed from device module"));
VLOG(3) << "[CUDA] C:\n" << source_code;
SourceCodePrint::GetInstance()->write(source_code);
device_fn_code_ += source_code;
Expand Down Expand Up @@ -393,7 +398,8 @@ void Compiler::ExportObject(const std::string& path) {
}

void* Compiler::Lookup(absl::string_view fn_name) {
CHECK(engine_);
PADDLE_ENFORCE_NOT_NULL(
engine_, phi::errors::InvalidArgument("Sorry, engine_ is nullptr"));
if (engine_->Lookup(fn_name) != nullptr) {
return engine_->Lookup(fn_name);
}
Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/backends/extern_func_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ void ExternFunctionEmitterRegistry::Register(const ExternFuncID& name,
utils::GetStreamCnt(name).c_str());
}
#endif // CINN_WITH_DEBUG
CHECK(!x.empty()) << "Extern Function name is empty.";
PADDLE_ENFORCE_EQ(
!x.empty(),
true,
phi::errors::InvalidArgument("Extern Function name is empty."));
data_[name] = x;
}

Expand All @@ -68,7 +71,10 @@ ExternFunctionEmitterRegistry::ExternFunctionEmitterRegistry() {}

const FunctionProto& ExternFunctionEmitter::func_proto() const {
auto* proto = ExternFunctionProtoRegistry::Global().Lookup(func_name());
CHECK(proto) << "No prototype of function [" << func_name() << "]";
PADDLE_ENFORCE_NOT_NULL(
proto,
phi::errors::InvalidArgument("No prototype of function [" +
std::string(func_name()) + "]"));
return *proto;
}

Expand Down
11 changes: 7 additions & 4 deletions paddle/cinn/backends/llvm/execution_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,18 @@ void ExecutionEngine::Link(const ir::Module &module) {
VLOG(3) << "ir_emitter->Compile(module) Begin";
ir_emitter->Compile(module);
VLOG(3) << "ir_emitter->Compile(module) Succeed!";
CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";

PADDLE_ENFORCE_EQ(!llvm::verifyModule(*m, &llvm::errs()),
true,
phi::errors::InvalidArgument("Sorry,Invalid module found"));
auto machine = std::move(llvm::cantFail(
llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost())
.createTargetMachine()));
LLVMModuleOptimizer optimize(machine.get(), 3, {}, true);
optimize(m.get());
CHECK(!llvm::verifyModule(*m, &llvm::errs()))
<< "Invalid optimized module detected";
PADDLE_ENFORCE_EQ(
!llvm::verifyModule(*m, &llvm::errs()),
true,
phi::errors::InvalidArgument("Invalid optimized module detected"));
for (auto &f : *m) {
VLOG(5) << "function: " << DumpToString(f);
}
Expand Down
34 changes: 23 additions & 11 deletions paddle/cinn/common/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,19 @@ GiNaC::ex ExprToGinacConverter::operator()(Expr expr) {
n->As<IfThenElse>();
});

CHECK(complex_nodes.empty()) << "Ginac converter can only deal with simple "
"math expression, but get some complex nodes"
<< expr;

PADDLE_ENFORCE_EQ(complex_nodes.empty(),
true,
::common::errors::InvalidArgument(
"Ginac converter can only deal with simple math "
"expression, but get some complex nodes."));
return BuildHelper(expr);
}

GiNaC::symbol ExprToGinacConverter::CreateGinacSymbol(const std::string& repr) {
CHECK(!repr.empty());
PADDLE_ENFORCE_EQ(
!repr.empty(),
true,
::common::errors::InvalidArgument("The repr should not be empty."));
auto it = repr_to_ginac_.find(repr);
if (it != repr_to_ginac_.end()) return it->second;

Expand All @@ -165,7 +169,9 @@ GiNaC::symbol ExprToGinacConverter::CreateGinacSymbol(const std::string& repr) {
}

GiNaC::symbol ExprToGinacConverter::CreateGinacSymbol(const ir::Expr& var) {
CHECK(var.As<_Var_>());
PADDLE_ENFORCE_NOT_NULL(
var.As<_Var_>(),
::common::errors::InvalidArgument("The var should not be nullptr."));
return CreateGinacSymbol(Repr(var));
}

Expand All @@ -191,8 +197,10 @@ class GiNaCToExprVisitor : public GiNaC::symbol::visitor,

void visit(const GiNaC::symbol& node) override {
auto it = repr_to_expr.find(node.get_name());
CHECK(it != repr_to_expr.end())
<< "node [" << node.get_name() << "] not found";
PADDLE_ENFORCE_NE(
it,
repr_to_expr.end(),
::common::errors::InvalidArgument("The node should be found."));
cur = it->second;
}

Expand Down Expand Up @@ -221,7 +229,9 @@ class GiNaCToExprVisitor : public GiNaC::symbol::visitor,
node.op(1).accept(*this);

auto* intv = cur.As<IntImm>();
CHECK(intv);
PADDLE_ENFORCE_NOT_NULL(
intv,
::common::errors::InvalidArgument("The intv should not be nullptr."));
PADDLE_ENFORCE_EQ(
intv->value,
-1,
Expand Down Expand Up @@ -313,8 +323,10 @@ std::tuple<Expr, bool /*positive*/> Solve(Expr lhs, Expr rhs, Var var) {
// tell the symbol
auto diff = lhs_ex - rhs_ex;
auto diff_res = ginac::diff(diff, symbol);
CHECK(!diff_res.is_zero());

PADDLE_ENFORCE_EQ(
!diff_res.is_zero(),
true,
::common::errors::InvalidArgument("The diff_res should not be zero."));
return std::make_tuple(value, diff_res > 0);
}

Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/common/dev_info_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,17 @@ class DevInfoMgr final {
using RetType = typename GetDevType<arch>::DevType;

const RetType* operator->() const {
CHECK(!std::is_void<RetType>()) << "Current device can't be recognized!\n";
PADDLE_ENFORCE_EQ(
!std::is_void<RetType>(),
true,
phi::errors::InvalidArgument("Current device can't be recognized!"));
return dynamic_cast<const RetType*>(impl_.get());
}
RetType* operator->() {
CHECK(!std::is_void<RetType>()) << "Current device can't be recognized!\n";
PADDLE_ENFORCE_EQ(
!std::is_void<RetType>(),
true,
phi::errors::InvalidArgument("Current device can't be recognized!"));
return dynamic_cast<RetType*>(impl_.get());
}
};
Expand Down
11 changes: 8 additions & 3 deletions paddle/cinn/hlir/op/contrib/bitcast_convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,21 @@ std::shared_ptr<framework::OpStrategy> StrategyForBitcastConvert(

framework::CINNCompute bitcast_convert_compute([=](lang::Args args,
lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of " << op_name
<< " compute is empty! Please check.";
PADDLE_ENFORCE_EQ(
!args.empty(),
true,
phi::errors::InvalidArgument(
"The input argument of %s compute is empty!", op_name));
CINNValuePack pack_args = args[0];
PADDLE_ENFORCE_GE(pack_args.size(),
1U,
phi::errors::InvalidArgument(
"The size of pack_args should be greater than 0 . "));
std::string tensor_name = UniqName(op_name + "_Out");
Expr A_expr = pack_args[0];
CHECK(A_expr.as_tensor());
PADDLE_ENFORCE_NOT_NULL(
A_expr.as_tensor(),
phi::errors::InvalidArgument("The input argument A is not a tensor."));
ir::Tensor A = A_expr.as_tensor_ref();
auto out = BitcastConvert(A, out_type[0], tensor_name);
std::vector<CINNValue> res;
Expand Down
15 changes: 9 additions & 6 deletions paddle/cinn/hlir/op/op_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ namespace hlir {
template <typename T>
T GetAttr(const cinn::utils::AttributeMap &attr_map,
const std::string &attr_name) {
CHECK(attr_map.count(attr_name))
<< "Cannot found attribute \"" << attr_name << "\"";
PADDLE_ENFORCE_EQ(attr_map.count(attr_name),
true,
phi::errors::InvalidArgument(
"Sorry, cannot found attribute %s", attr_name));
const auto &attr = attr_map.at(attr_name);

CHECK(absl::holds_alternative<T>(attr))
<< "The type of attribute \"" << attr_name << "\" isn't "
<< typeid(T).name();
PADDLE_ENFORCE_EQ(
absl::holds_alternative<T>(attr),
true,
phi::errors::InvalidArgument(
"The type of attribute %s isn't %s", attr_name, typeid(T).name()));
return absl::get<T>(attr_map.at(attr_name));
}

Expand Down
Loading