12 changes: 11 additions & 1 deletion paddle/cinn/backends/codegen_device_util.h
@@ -25,6 +25,9 @@
#ifdef CINN_WITH_HIP
#include "paddle/cinn/backends/hip/codegen_hip_dev.h"
#endif
#ifdef CINN_WITH_SYCL
#include "paddle/cinn/backends/sycl/codegen_sycl_dev.h"
#endif
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
@@ -140,7 +143,14 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
#endif
},
[&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED });
[&](common::HygonDCUArchSYCL) {
#ifdef CINN_WITH_SYCL
sycl::CodeGenSyclDevice codegen_dev(
cinn::common::DefaultHygonDcuSyclTarget());
codegen_dev.Compile(ir::LoweredFunc(func));
shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
#endif
});

VLOG(6) << "Add a call node for func->name " << func->name << "\n"
<< "grid_dim: (" << func->cuda_axis_info.grid_dim(0) << ", "
84 changes: 82 additions & 2 deletions paddle/cinn/backends/compiler.cc
@@ -35,6 +35,11 @@
#include "paddle/cinn/backends/hip/compiler_hip.h"
#include "paddle/cinn/runtime/hip/hip_module.h"
#endif
#ifdef CINN_WITH_SYCL
#include "paddle/cinn/backends/sycl/codegen_sycl_dev.h"
#include "paddle/cinn/backends/sycl/compiler_sycl.h"
#include "paddle/cinn/runtime/sycl/sycl_module.h"
#endif
#include "paddle/cinn/adt/adt.h"

PD_DECLARE_string(cinn_source_code_save_path);
@@ -288,7 +293,19 @@ std::string Compiler::GetSourceCode(const ir::Module& module) {
CINN_NOT_IMPLEMENTED
#endif
},
[&](common::HygonDCUArchSYCL) -> std::string { CINN_NOT_IMPLEMENTED });
[&](common::HygonDCUArchSYCL) -> std::string {
#ifdef CINN_WITH_SYCL
auto _host_module_device_module_ =
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
sycl::CodeGenSyclDevice codegen(target_);
auto source_code = codegen.Compile(device_module);
return source_code;
#else
CINN_NOT_IMPLEMENTED
#endif
});
}

void Compiler::BuildDefault(const Module& module) {
@@ -393,7 +410,38 @@ void Compiler::RegisterHipModuleSymbol() {
#endif
}

void Compiler::RegisterSyclModuleSymbol() { CINN_NOT_IMPLEMENTED }
void Compiler::RegisterSyclModuleSymbol() {
#ifdef CINN_WITH_SYCL
syclrtc::Compiler compiler;
std::string source_code =
sycl::CodeGenSyclDevice::GetSourceHeader() + device_fn_code_;
std::string hsaco = compiler(source_code);
PADDLE_ENFORCE_EQ(
!hsaco.empty(),
true,
::common::errors::Fatal("Compile hsaco failed from source code:\n%s",
source_code));
using runtime::sycl::SYCLModule;
sycl_module_.reset(new SYCLModule(source_code, hsaco, SYCLModule::Kind::so));
// get device id
using cinn::runtime::BackendAPI;
int device_id = BackendAPI::get_backend(target_)->get_device();
// register kernel
RuntimeSymbols symbols;
for (const auto& kernel_fn_name : device_fn_name_) {
auto fn_kernel = sycl_module_->GetFunction(kernel_fn_name);
PADDLE_ENFORCE_NOT_NULL(
fn_kernel,
::common::errors::Fatal("HIP GetFunction Error: get valid kernel."));
fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel));
symbols.RegisterVar(kernel_fn_name + "_ptr_",
reinterpret_cast<void*>(fn_kernel));
}
engine_->RegisterModuleRuntimeSymbols(std::move(symbols));
#else
CINN_NOT_IMPLEMENTED
#endif
}

void Compiler::CompileCudaModule(const Module& module,
const std::string& code) {
@@ -473,7 +521,39 @@ void Compiler::CompileHipModule(const Module& module, const std::string& code) {

void Compiler::CompileSyclModule(const Module& module,
const std::string& code) {
#ifdef CINN_WITH_SYCL
auto _host_module_device_module_ =
SplitDeviceAndHostModule(module); // NOLINT
auto& host_module = std::get<0>(_host_module_device_module_);
auto& device_module = std::get<1>(_host_module_device_module_);
VLOG(3) << "[SYCL] host module:\n" << host_module;
VLOG(3) << "[SYCL] device module:\n" << device_module;
std::string source_code;
if (!FLAGS_cinn_debug_custom_code_path.empty()) {
std::string file_path = FLAGS_cinn_debug_custom_code_path;
source_code = GetFileContent(file_path);
} else if (code.empty()) {
sycl::CodeGenSyclDevice codegen(target_);
source_code = codegen.Compile(device_module);
} else {
source_code = code;
}
PADDLE_ENFORCE_EQ(
!source_code.empty(),
true,
::common::errors::Fatal(
"Compile SYCL code failed from device module:\n%s", device_module));
VLOG(3) << "[SYCL]:\n" << source_code;
SourceCodePrint::GetInstance()->write(source_code);
device_fn_code_ += source_code;
for (auto& fn : device_module.functions()) {
std::string kernel_fn_name = fn->name;
device_fn_name_.emplace_back(kernel_fn_name);
}
engine_->Link<CodeGenGpuHost>(host_module);
#else
CINN_NOT_IMPLEMENTED
#endif
}

void Compiler::CompileX86Module(const Module& module) {
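A note on the compiler.cc changes above: the SYCL path follows the same shape as the existing CUDA and HIP paths. CompileSyclModule splits the module into host and device parts, emits SYCL source for the device side and records the kernel names; RegisterSyclModuleSymbol later compiles the accumulated source into a shared object and registers each kernel pointer with the execution engine under the name "<kernel>_ptr_". The standalone sketch below shows only the core of that registration step, with dlopen/dlsym and a std::map standing in for syclrtc::Compiler, SYCLModule::GetFunction, and RuntimeSymbols; it is an illustration, not the CINN API.

// Illustrative only: compile output is assumed to already be a shared object at
// so_path; resolve each kernel by name and register "<name>_ptr_" -> pointer.
#include <dlfcn.h>

#include <map>
#include <stdexcept>
#include <string>
#include <vector>

std::map<std::string, void*> RegisterKernels(
    const std::string& so_path, const std::vector<std::string>& kernel_names) {
  void* handle = dlopen(so_path.c_str(), RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    throw std::runtime_error(std::string("dlopen failed: ") + dlerror());
  }
  std::map<std::string, void*> symbols;
  for (const auto& name : kernel_names) {
    void* fn = dlsym(handle, name.c_str());
    if (fn == nullptr) {
      throw std::runtime_error("failed to resolve kernel: " + name);
    }
    // Mirrors symbols.RegisterVar(kernel_fn_name + "_ptr_", fn) in the diff.
    symbols[name + "_ptr_"] = fn;
  }
  return symbols;
}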
6 changes: 6 additions & 0 deletions paddle/cinn/backends/compiler.h
@@ -32,6 +32,9 @@
#ifdef CINN_WITH_HIP
#include "paddle/cinn/runtime/hip/hip_module.h"
#endif
#ifdef CINN_WITH_SYCL
#include "paddle/cinn/runtime/sycl/sycl_module.h"
#endif

namespace cinn {
namespace backends {
@@ -167,6 +170,9 @@ class Compiler final {
#ifdef CINN_WITH_HIP
std::unique_ptr<runtime::hip::HIPModule> hip_module_;
#endif
#ifdef CINN_WITH_SYCL
std::unique_ptr<runtime::sycl::SYCLModule> sycl_module_;
#endif
};

} // namespace backends
8 changes: 4 additions & 4 deletions paddle/cinn/common/target.cc
@@ -359,10 +359,10 @@ const Target &DefaultHygonDcuSyclTarget() {
const Target &DefaultDeviceTarget() {
#ifdef CINN_WITH_CUDA
return DefaultNVGPUTarget();
#elif defined(CINN_WITH_HIP)
return DefaultHygonDcuHipTarget();
#elif defined(CINN_WITH_SYCL)
return DefaultHygonDcuSyclTarget();
#elif defined(CINN_WITH_HIP)
return DefaultHygonDcuHipTarget();
#endif
}

@@ -400,10 +400,10 @@ int GetMaxBlocks() {
const Target &DefaultTarget() {
#ifdef CINN_WITH_CUDA
return DefaultNVGPUTarget();
#elif defined(CINN_WITH_HIP)
return DefaultHygonDcuHipTarget();
#elif defined(CINN_WITH_SYCL)
return DefaultHygonDcuSyclTarget();
#elif defined(CINN_WITH_HIP)
return DefaultHygonDcuHipTarget();
#else
return DefaultHostTarget();
#endif
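The target.cc change above only reorders the #elif chain: when a build defines both CINN_WITH_SYCL and CINN_WITH_HIP, the first matching branch wins, so checking SYCL before HIP makes the SYCL target the default in such a build. Compressed to its essence (CINN_DEFAULT_DEVICE is a made-up macro for this sketch, not part of CINN):

#if defined(CINN_WITH_CUDA)
#define CINN_DEFAULT_DEVICE "cuda"
#elif defined(CINN_WITH_SYCL)  // now checked before HIP
#define CINN_DEFAULT_DEVICE "sycl"
#elif defined(CINN_WITH_HIP)
#define CINN_DEFAULT_DEVICE "hip"
#else
#define CINN_DEFAULT_DEVICE "host"
#endif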
3 changes: 1 addition & 2 deletions paddle/cinn/optim/optimize.cc
@@ -103,7 +103,7 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
VLOG(10) << "After Optimize TransBufferWithDynamicShape:" << copied;
#endif
},
[&](common::HygonDCUArchHIP) {
[&](std::variant<common::HygonDCUArchHIP, common::HygonDCUArchSYCL>) {
#ifdef CINN_WITH_HIP
ir::SetCudaAxisInfo(copied);
if (remove_gpu_for_loops) {
@@ -120,7 +120,6 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
VLOG(10) << "After Optimize CudaSyncThreadsDropIfThenElse:" << copied;
#endif
},
[&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED },
[](auto) {});

SimplifyBlocks(&copied->body);
14 changes: 13 additions & 1 deletion paddle/cinn/optim/trans_buffer_with_dynamic_shape.cc
@@ -200,7 +200,19 @@ LogicalResult TransBufferWithDynamicShapePass::Run(ir::LoweredFunc func) {
"The shared memory size used by current kernel is greater "
"than the max shared memory per block"));
},
[&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED });
[&](common::HygonDCUArchSYCL) {
using cinn::runtime::BackendAPI;
size_t max_shm_per_block =
BackendAPI::get_backend(common::HygonDCUArchSYCL{})
->get_device_property(
BackendAPI::DeviceProperty::MaxSharedMemoryPerBlock);
PADDLE_ENFORCE_LE(
mutator.shared_mem_size_used(),
max_shm_per_block,
::common::errors::InvalidArgument(
"The shared memory size used by current kernel is greater "
"than the max shared memory per block"));
});
return LogicalResult::success();
}

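The new HygonDCUArchSYCL branch in trans_buffer_with_dynamic_shape.cc performs the same bound check as the CUDA and HIP branches: ask BackendAPI for the device's per-block shared memory limit and reject kernels whose shared memory usage exceeds it. Stripped of the CINN plumbing, the check amounts to the following placeholder (the names here are illustrative, not CINN APIs):

#include <cstddef>
#include <stdexcept>
#include <string>

// Placeholder check: fail if the kernel's shared memory usage exceeds the
// per-block limit reported by the device.
void EnforceSharedMemLimit(size_t used_bytes, size_t max_bytes_per_block) {
  if (used_bytes > max_bytes_per_block) {
    throw std::runtime_error(
        "shared memory used (" + std::to_string(used_bytes) +
        " bytes) exceeds the per-block limit (" +
        std::to_string(max_bytes_per_block) + " bytes)");
  }
}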
8 changes: 5 additions & 3 deletions paddle/cinn/runtime/sycl/sycl_backend_api.cc
@@ -45,9 +45,11 @@ void SYCLBackendAPI::Init(Arch arch) {
backend = ::sycl::backend::ext_oneapi_hip;
});
// look for matched devices
for (auto device : devices) {
if (device.get_backend() == backend) {
this->devices.push_back(device);
if (this->devices.size() < 8) {
for (auto device : devices) {
if (device.get_backend() == backend) {
this->devices.push_back(device);
}
}
}
if (this->devices.size() == 0) {
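In sycl_backend_api.cc the device loop now only records devices whose SYCL backend matches the one selected for the target arch (ext_oneapi_hip for Hygon DCU), and the added devices.size() < 8 guard skips the enumeration once devices have already been collected; the reason for the specific threshold of 8 is not stated in the diff. In standard SYCL 2020 terms, the filtering step looks roughly like this sketch (MatchingDevices is a hypothetical helper, not CINN code):

#include <sycl/sycl.hpp>

#include <vector>

// Collect only the GPU devices exposed by the backend we care about.
std::vector<sycl::device> MatchingDevices(sycl::backend want) {
  std::vector<sycl::device> matched;
  for (const auto& dev :
       sycl::device::get_devices(sycl::info::device_type::gpu)) {
    if (dev.get_backend() == want) {
      matched.push_back(dev);
    }
  }
  return matched;
}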
5 changes: 5 additions & 0 deletions paddle/cinn/runtime/use_extern_funcs.h
@@ -18,6 +18,11 @@
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/runtime/cuda/use_extern_funcs.h"
#endif

#ifdef CINN_WITH_HIP
#include "paddle/cinn/runtime/hip/use_extern_funcs.h"
#endif

#ifdef CINN_WITH_SYCL
#include "paddle/cinn/runtime/sycl/use_extern_funcs.h"
#endif