21 changes: 15 additions & 6 deletions paddle/cinn/backends/function_prototype.cc
@@ -91,12 +91,18 @@ void FunctionProto::AssertMatch(const ir::Call *op) const {
 
 void FunctionProto::CheckValid() {
   if (ret_type.is_void()) {
-    CHECK(!mutable_arg_types.empty())
-        << "A void function should have at least one mutable argument to "
-           "output something";
+    PADDLE_ENFORCE_EQ(
+        !mutable_arg_types.empty(),
+        true,
+        phi::errors::InvalidArgument(
+            "A void function should have at least one mutable argument to "
+            "output something."));
   } else {
-    CHECK(mutable_arg_types.empty())
-        << "A function with return should not have mutable argument";
+    PADDLE_ENFORCE_EQ(
+        mutable_arg_types.empty(),
+        true,
+        phi::errors::InvalidArgument(
+            "A function with return should not have mutable argument."));
   }
 }
 
@@ -107,7 +113,10 @@ FunctionProto::shape_inference_t FunctionProto::ShapeFollowNthArgument(int n) {
                       ::common::errors::InvalidArgument(
                           "The argument index is out of range"));
     auto x = args[n].as_tensor();
-    CHECK(x);
+    PADDLE_ENFORCE_NOT_NULL(
+        x,
+        phi::errors::InvalidArgument(
+            "The argument at index (%d) must be a tensor.", n));
     return x->shape;
   };
 }
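The hunks above establish the pattern this PR applies throughout: a boolean CHECK becomes PADDLE_ENFORCE_EQ(cond, true, ...) and a pointer CHECK becomes PADDLE_ENFORCE_NOT_NULL(ptr, ...), each carrying a phi::errors::InvalidArgument message. Below is a minimal self-contained sketch of the two shapes; the SKETCH_* macros are simplified stand-ins for Paddle's real macros, which additionally capture the call site and build richer error strings.

#include <stdexcept>

// Simplified stand-ins for PADDLE_ENFORCE_EQ / PADDLE_ENFORCE_NOT_NULL;
// the real Paddle macros also record file/line information.
#define SKETCH_ENFORCE_EQ(lhs, rhs, msg) \
  do { \
    if ((lhs) != (rhs)) throw std::runtime_error(msg); \
  } while (0)

#define SKETCH_ENFORCE_NOT_NULL(ptr, msg) \
  do { \
    if ((ptr) == nullptr) throw std::runtime_error(msg); \
  } while (0)

void Demo(const int* arg, bool has_mutable_output) {
  // Boolean condition: compare the (possibly negated) condition to true.
  SKETCH_ENFORCE_EQ(has_mutable_output, true,
                    "A void function should have at least one mutable "
                    "argument to output something.");
  // Pointer condition: fail fast with a descriptive message.
  SKETCH_ENFORCE_NOT_NULL(arg, "The argument must not be null.");
}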
14 changes: 10 additions & 4 deletions paddle/cinn/backends/llvm/simple_jit.cc
file mode changed 100755 → 100644
@@ -49,8 +49,11 @@ void SimpleJIT::AddModule(std::unique_ptr<llvm::Module> module, bool optimize) {
     LOG(INFO) << "fn:\n" << DumpToString(fn);
   }
   */
-  CHECK(!llvm::verifyModule(*module, &llvm::errs()))
-      << "Transformation resulted in an invalid module\n\nmodule:\n";
+  PADDLE_ENFORCE_EQ(
+      !llvm::verifyModule(*module, &llvm::errs()),
+      true,
+      phi::errors::InvalidArgument(
+          "Transformation resulted in an invalid module\n\nmodule:\n"));
 
   bool debug = false;
   if (optimize) {
@@ -99,7 +102,8 @@ SimpleJIT::SimpleJIT() : context_(std::make_unique<llvm::LLVMContext>()) {
   llvm::InitializeAllAsmPrinters();
 
   jit_ = llvm::cantFail(llvm::orc::LLJITBuilder().create());
-  CHECK(jit_) << "JIT create failed";
+  PADDLE_ENFORCE_NOT_NULL(jit_,
+                          phi::errors::InvalidArgument("JIT creation failed."));
 
   auto proc_symbols_generator = llvm::cantFail(
       llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess(
@@ -129,7 +133,9 @@ void SimpleJIT::Link(ir::Module module, bool optimize) {
   auto ir_emitter = std::make_unique<CodeGenT>(m.get(), b.get());
   ir_emitter->Compile(module);
 
-  CHECK(!llvm::verifyModule(*m, &llvm::errs())) << "Invalid module found";
+  PADDLE_ENFORCE_EQ(!llvm::verifyModule(*m, &llvm::errs()),
+                    true,
+                    phi::errors::InvalidArgument("Invalid module found."));
 
   AddModule(std::move(m), optimize);
 }
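Both verifyModule checks in this file keep a leading negation for a reason: llvm::verifyModule() returns true when the module is broken (writing diagnostics to the supplied stream), so the value compared against true is the negated result. A small sketch of that semantic:

#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"

// llvm::verifyModule returns true if the module is BROKEN, so a valid
// module is signalled by the negated return value.
bool ModuleIsValid(const llvm::Module& m) {
  return !llvm::verifyModule(m, &llvm::errs());
}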
19 changes: 14 additions & 5 deletions paddle/cinn/hlir/pe/reduction.cc
@@ -51,7 +51,9 @@ using lang::Compute;
 void GetRealAxes(int ndim,
                  const std::vector<int>& axes,
                  std::vector<int>* real_axes) {
-  CHECK(real_axes);
+  PADDLE_ENFORCE_NOT_NULL(real_axes,
+                          phi::errors::InvalidArgument(
+                              "The 'real_axes' pointer must not be null."));
   if (axes.empty()) {
     for (int i = 0; i < ndim; ++i) {
       real_axes->push_back(i);
@@ -120,7 +122,9 @@ void GetOutputShape(const std::vector<int>& real_axes,
                     std::vector<Expr>* output_shape,
                     const Tensor& tensor,
                     bool keep_dims) {
-  CHECK(output_shape);
+  PADDLE_ENFORCE_NOT_NULL(output_shape,
+                          phi::errors::InvalidArgument(
+                              "The 'output_shape' pointer must not be null."));
   auto ndim = tensor->shape.size();
   if (keep_dims) {
     for (size_t i = 0; i < ndim; ++i) {
@@ -141,7 +145,10 @@
     output_shape->push_back(cinn::common::make_one());
   }
 
-  CHECK(!tensor->shape.empty());
+  PADDLE_ENFORCE_EQ(
+      !tensor->shape.empty(),
+      true,
+      phi::errors::InvalidArgument("The 'tensor' shape must not be empty."));
   if (tensor->shape[0]->type() == Int(64)) {
     for (auto& shape_item : *output_shape) {
       shape_item->convert_int32_to_int64();
@@ -868,8 +875,10 @@ std::vector<ir::Tensor> TwoStepBlockReduceInternal(
     ReduceFunc reduce_func,
     BlockReduceFunc block_reduce_func,
     ir::Expr initial) {
-  CHECK(!WithoutLastDimInReduce(A->shape, axes))
-      << "Can't find last axis in reduce!";
+  PADDLE_ENFORCE_EQ(
+      !WithoutLastDimInReduce(A->shape, axes),
+      true,
+      phi::errors::InvalidArgument("Can't find last axis in reduce!"));
   // If the number of current device SM is smaller than the number of SM
   // required by Warp Reduce, the performance of Warp Reduce is better.
   // Otherwise, use Block Reduce.
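GetRealAxes and GetOutputShape both write through out-parameters, so each new PADDLE_ENFORCE_NOT_NULL guards the writes that follow. A standalone sketch of the GetRealAxes shape, using a plain exception in place of phi::errors::InvalidArgument; the non-empty branch is an assumption, since the hunk only shows the empty-axes case:

#include <stdexcept>
#include <vector>

void GetRealAxesSketch(int ndim, const std::vector<int>& axes,
                       std::vector<int>* real_axes) {
  // Guard the out-parameter before any write through it.
  if (real_axes == nullptr) {
    throw std::invalid_argument("The 'real_axes' pointer must not be null.");
  }
  if (axes.empty()) {
    // An empty axis list means: reduce over every dimension.
    for (int i = 0; i < ndim; ++i) real_axes->push_back(i);
  } else {
    *real_axes = axes;  // assumption: explicit axes are taken as given
  }
}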
24 changes: 19 additions & 5 deletions paddle/cinn/hlir/pe/schedule.cc
@@ -261,9 +261,21 @@ void GetConv2dFactors(absl::flat_hash_map<std::string, int> *factors,
   auto &params = ScheduleParam::get_x86_instance().GetParam();
   if (params.count(key)) {
     VLOG(3) << "find saved param, key is: " << key;
-    CHECK(!params[key]["oc_bn"].empty());
-    CHECK(!params[key]["ic_bn"].empty());
-    CHECK(!params[key]["ow_bn"].empty());
+    PADDLE_ENFORCE_EQ(
+        !params[key]["oc_bn"].empty(),
+        true,
+        phi::errors::InvalidArgument(
+            "The parameter 'oc_bn' for key '%s' must not be empty.", key));
+    PADDLE_ENFORCE_EQ(
+        !params[key]["ic_bn"].empty(),
+        true,
+        phi::errors::InvalidArgument(
+            "The parameter 'ic_bn' for key '%s' must not be empty.", key));
+    PADDLE_ENFORCE_EQ(
+        !params[key]["ow_bn"].empty(),
+        true,
+        phi::errors::InvalidArgument(
+            "The parameter 'ow_bn' for key '%s' must not be empty.", key));
     (*factors)["oc_bn"] = params[key]["oc_bn"].back();
     (*factors)["ic_bn"] = params[key]["ic_bn"].back();
     (*factors)["ow_bn"] = params[key]["ow_bn"].back();
@@ -495,8 +507,10 @@ inline void InputDirectConvCudaParam(
   schedule_data["f"] = int_data[3];
   schedule_data["y"] = int_data[4];
   schedule_data["x"] = int_data[5];
-  CHECK(model_data.count(key) == 0)
-      << "Key " << key << "in conv cuda param already exists.";
+  PADDLE_ENFORCE_EQ(model_data.count(key),
+                    0,
+                    phi::errors::InvalidArgument(
+                        "Key '%s' in conv CUDA param already exists.", key));
   model_data[key] = schedule_data;
 }
 
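The InputDirectConvCudaParam check is the one place in this PR where PADDLE_ENFORCE_EQ compares against 0 directly rather than against a negated boolean: count(key) == 0 must hold before the entry is written, so a duplicate key fails loudly instead of being silently overwritten. A standalone sketch of that guard, with a plain exception standing in for phi::errors::InvalidArgument:

#include <map>
#include <stdexcept>
#include <string>

// Insert a schedule entry, refusing to overwrite an existing key.
void InsertOnce(std::map<std::string, int>& model_data,
                const std::string& key, int value) {
  if (model_data.count(key) != 0) {
    throw std::invalid_argument("Key '" + key + "' already exists.");
  }
  model_data[key] = value;
}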
25 changes: 19 additions & 6 deletions paddle/cinn/pybind/ir/ir_context.cc
@@ -84,7 +84,10 @@ void ElseContextNode::ExitWithContext() {
 }
 
 Expr IRBuilderNode::GetResult() const {
-  CHECK(result.defined()) << "No result generated in IRBuilder";
+  PADDLE_ENFORCE_EQ(
+      result.defined(),
+      true,
+      phi::errors::InvalidArgument("No result generated in IRBuilder."));
   return result;
 }
 
@@ -100,22 +103,32 @@ IRBuilder::IRBuilder() {
 }
 
 void IRBuilder::EnterWithContext() {
-  CHECK(data_->contexts.empty())
-      << "There are still Contexts in IRBuilder that has not been fully "
-         "converted. Please build a new IR with the new IRbuilder";
+  PADDLE_ENFORCE_EQ(
+      data_->contexts.empty(),
+      true,
+      phi::errors::InvalidArgument(
+          "There are still contexts in IRBuilder that have not been fully "
+          "converted. Please build a new IR with the new IRBuilder."));
 
   data_->result.Reset();
   std::vector<IRBuilder>* st = IRBuilderStack();
   st->push_back(*this);
 }
 
 void IRBuilder::ExitWithContext() {
   std::vector<IRBuilder>* st = IRBuilderStack();
-  CHECK(!st->empty());
+  PADDLE_ENFORCE_EQ(
+      !st->empty(),
+      true,
+      phi::errors::InvalidArgument("The IRBuilder stack must not be empty."));
   st->pop_back();
 }
 IRBuilder IRBuilder::CurrentIRBuilder() {
   std::vector<IRBuilder>* st = IRBuilderStack();
-  CHECK(!st->empty()) << "No IRBuilder Found";
+  PADDLE_ENFORCE_EQ(
+      !st->empty(),
+      true,
+      phi::errors::InvalidArgument("No IRBuilder found in the stack."));
   return st->back();
 }
 std::vector<IRBuilder>* IRBuilderStack() {
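The checks in ir_context.cc all protect one protocol: entering a builder context pushes onto a process-wide stack, exiting pops it, and CurrentIRBuilder() reads the top. A standalone sketch of that stack discipline; the names are illustrative stand-ins, and plain exceptions replace phi::errors::InvalidArgument:

#include <stdexcept>
#include <vector>

struct BuilderSketch {};  // stand-in for IRBuilder

std::vector<BuilderSketch>* BuilderStack() {
  static std::vector<BuilderSketch> stack;  // one stack per process
  return &stack;
}

void Enter(const BuilderSketch& b) { BuilderStack()->push_back(b); }

void Exit() {
  // Exiting without a matching Enter indicates unbalanced contexts.
  if (BuilderStack()->empty())
    throw std::logic_error("The IRBuilder stack must not be empty.");
  BuilderStack()->pop_back();
}

BuilderSketch Current() {
  if (BuilderStack()->empty())
    throw std::logic_error("No IRBuilder found in the stack.");
  return BuilderStack()->back();
}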