Skip to content

Commit 5879c62

Browse files
authored
[CINN][new hardware] SYCL third PR: complete the SYCL logic (#71204)
* complete sycl logic * fix_bugs * fix_bugs
1 parent e29a10c commit 5879c62

File tree

8 files changed

+127
-13
lines changed

8 files changed

+127
-13
lines changed

paddle/cinn/backends/codegen_device_util.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
#ifdef CINN_WITH_HIP
2626
#include "paddle/cinn/backends/hip/codegen_hip_dev.h"
2727
#endif
28+
#ifdef CINN_WITH_SYCL
29+
#include "paddle/cinn/backends/sycl/codegen_sycl_dev.h"
30+
#endif
2831
#include "paddle/cinn/cinn.h"
2932
#include "paddle/cinn/ir/ir.h"
3033
#include "paddle/cinn/ir/ir_mutator.h"
@@ -140,7 +143,14 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
140143
shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
141144
#endif
142145
},
143-
[&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED });
146+
[&](common::HygonDCUArchSYCL) {
147+
#ifdef CINN_WITH_SYCL
148+
sycl::CodeGenSyclDevice codegen_dev(
149+
cinn::common::DefaultHygonDcuSyclTarget());
150+
codegen_dev.Compile(ir::LoweredFunc(func));
151+
shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
152+
#endif
153+
});
144154

145155
VLOG(6) << "Add a call node for func->name " << func->name << "\n"
146156
<< "grid_dim: (" << func->cuda_axis_info.grid_dim(0) << ", "

paddle/cinn/backends/compiler.cc

Lines changed: 82 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@
3535
#include "paddle/cinn/backends/hip/compiler_hip.h"
3636
#include "paddle/cinn/runtime/hip/hip_module.h"
3737
#endif
38+
#ifdef CINN_WITH_SYCL
39+
#include "paddle/cinn/backends/sycl/codegen_sycl_dev.h"
40+
#include "paddle/cinn/backends/sycl/compiler_sycl.h"
41+
#include "paddle/cinn/runtime/sycl/sycl_module.h"
42+
#endif
3843
#include "paddle/cinn/adt/adt.h"
3944

4045
PD_DECLARE_string(cinn_source_code_save_path);
@@ -288,7 +293,19 @@ std::string Compiler::GetSourceCode(const ir::Module& module) {
288293
CINN_NOT_IMPLEMENTED
289294
#endif
290295
},
291-
[&](common::HygonDCUArchSYCL) -> std::string { CINN_NOT_IMPLEMENTED });
296+
[&](common::HygonDCUArchSYCL) -> std::string {
297+
#ifdef CINN_WITH_SYCL
298+
auto _host_module_device_module_ =
299+
SplitDeviceAndHostModule(module); // NOLINT
300+
auto& host_module = std::get<0>(_host_module_device_module_);
301+
auto& device_module = std::get<1>(_host_module_device_module_);
302+
sycl::CodeGenSyclDevice codegen(target_);
303+
auto source_code = codegen.Compile(device_module);
304+
return source_code;
305+
#else
306+
CINN_NOT_IMPLEMENTED
307+
#endif
308+
});
292309
}
293310

294311
void Compiler::BuildDefault(const Module& module) {
@@ -393,7 +410,38 @@ void Compiler::RegisterHipModuleSymbol() {
393410
#endif
394411
}
395412

396-
void Compiler::RegisterSyclModuleSymbol() { CINN_NOT_IMPLEMENTED }
413+
void Compiler::RegisterSyclModuleSymbol() {
414+
#ifdef CINN_WITH_SYCL
415+
syclrtc::Compiler compiler;
416+
std::string source_code =
417+
sycl::CodeGenSyclDevice::GetSourceHeader() + device_fn_code_;
418+
std::string hsaco = compiler(source_code);
419+
PADDLE_ENFORCE_EQ(
420+
!hsaco.empty(),
421+
true,
422+
::common::errors::Fatal("Compile hsaco failed from source code:\n%s",
423+
source_code));
424+
using runtime::sycl::SYCLModule;
425+
sycl_module_.reset(new SYCLModule(source_code, hsaco, SYCLModule::Kind::so));
426+
// get device id
427+
using cinn::runtime::BackendAPI;
428+
int device_id = BackendAPI::get_backend(target_)->get_device();
429+
// register kernel
430+
RuntimeSymbols symbols;
431+
for (const auto& kernel_fn_name : device_fn_name_) {
432+
auto fn_kernel = sycl_module_->GetFunction(kernel_fn_name);
433+
PADDLE_ENFORCE_NOT_NULL(
434+
fn_kernel,
435+
::common::errors::Fatal("HIP GetFunction Error: get valid kernel."));
436+
fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel));
437+
symbols.RegisterVar(kernel_fn_name + "_ptr_",
438+
reinterpret_cast<void*>(fn_kernel));
439+
}
440+
engine_->RegisterModuleRuntimeSymbols(std::move(symbols));
441+
#else
442+
CINN_NOT_IMPLEMENTED
443+
#endif
444+
}
397445

398446
void Compiler::CompileCudaModule(const Module& module,
399447
const std::string& code) {
@@ -473,7 +521,39 @@ void Compiler::CompileHipModule(const Module& module, const std::string& code) {
473521

474522
void Compiler::CompileSyclModule(const Module& module,
475523
const std::string& code) {
524+
#ifdef CINN_WITH_SYCL
525+
auto _host_module_device_module_ =
526+
SplitDeviceAndHostModule(module); // NOLINT
527+
auto& host_module = std::get<0>(_host_module_device_module_);
528+
auto& device_module = std::get<1>(_host_module_device_module_);
529+
VLOG(3) << "[SYCL] host module:\n" << host_module;
530+
VLOG(3) << "[SYCL] device module:\n" << device_module;
531+
std::string source_code;
532+
if (!FLAGS_cinn_debug_custom_code_path.empty()) {
533+
std::string file_path = FLAGS_cinn_debug_custom_code_path;
534+
source_code = GetFileContent(file_path);
535+
} else if (code.empty()) {
536+
sycl::CodeGenSyclDevice codegen(target_);
537+
source_code = codegen.Compile(device_module);
538+
} else {
539+
source_code = code;
540+
}
541+
PADDLE_ENFORCE_EQ(
542+
!source_code.empty(),
543+
true,
544+
::common::errors::Fatal(
545+
"Compile SYCL code failed from device module:\n%s", device_module));
546+
VLOG(3) << "[SYCL]:\n" << source_code;
547+
SourceCodePrint::GetInstance()->write(source_code);
548+
device_fn_code_ += source_code;
549+
for (auto& fn : device_module.functions()) {
550+
std::string kernel_fn_name = fn->name;
551+
device_fn_name_.emplace_back(kernel_fn_name);
552+
}
553+
engine_->Link<CodeGenGpuHost>(host_module);
554+
#else
476555
CINN_NOT_IMPLEMENTED
556+
#endif
477557
}
478558

479559
void Compiler::CompileX86Module(const Module& module) {

paddle/cinn/backends/compiler.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
#ifdef CINN_WITH_HIP
3333
#include "paddle/cinn/runtime/hip/hip_module.h"
3434
#endif
35+
#ifdef CINN_WITH_SYCL
36+
#include "paddle/cinn/runtime/sycl/sycl_module.h"
37+
#endif
3538

3639
namespace cinn {
3740
namespace backends {
@@ -167,6 +170,9 @@ class Compiler final {
167170
#ifdef CINN_WITH_HIP
168171
std::unique_ptr<runtime::hip::HIPModule> hip_module_;
169172
#endif
173+
#ifdef CINN_WITH_SYCL
174+
std::unique_ptr<runtime::sycl::SYCLModule> sycl_module_;
175+
#endif
170176
};
171177

172178
} // namespace backends

paddle/cinn/common/target.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,10 @@ const Target &DefaultHygonDcuSyclTarget() {
359359
const Target &DefaultDeviceTarget() {
360360
#ifdef CINN_WITH_CUDA
361361
return DefaultNVGPUTarget();
362-
#elif defined(CINN_WITH_HIP)
363-
return DefaultHygonDcuHipTarget();
364362
#elif defined(CINN_WITH_SYCL)
365363
return DefaultHygonDcuSyclTarget();
364+
#elif defined(CINN_WITH_HIP)
365+
return DefaultHygonDcuHipTarget();
366366
#endif
367367
}
368368

@@ -400,10 +400,10 @@ int GetMaxBlocks() {
400400
const Target &DefaultTarget() {
401401
#ifdef CINN_WITH_CUDA
402402
return DefaultNVGPUTarget();
403-
#elif defined(CINN_WITH_HIP)
404-
return DefaultHygonDcuHipTarget();
405403
#elif defined(CINN_WITH_SYCL)
406404
return DefaultHygonDcuSyclTarget();
405+
#elif defined(CINN_WITH_HIP)
406+
return DefaultHygonDcuHipTarget();
407407
#else
408408
return DefaultHostTarget();
409409
#endif

paddle/cinn/optim/optimize.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
107107
VLOG(10) << "After Optimize TransBufferWithDynamicShape:" << copied;
108108
#endif
109109
},
110-
[&](common::HygonDCUArchHIP) {
110+
[&](std::variant<common::HygonDCUArchHIP, common::HygonDCUArchSYCL>) {
111111
#ifdef CINN_WITH_HIP
112112
ir::SetCudaAxisInfo(copied);
113113
if (remove_gpu_for_loops) {
@@ -124,7 +124,6 @@ ir::LoweredFunc Optimize(ir::LoweredFunc fn,
124124
VLOG(10) << "After Optimize CudaSyncThreadsDropIfThenElse:" << copied;
125125
#endif
126126
},
127-
[&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED },
128127
[](auto) {});
129128

130129
SimplifyUnitBlock(&copied->body);

paddle/cinn/optim/trans_buffer_with_dynamic_shape.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,19 @@ LogicalResult TransBufferWithDynamicShapePass::Run(ir::LoweredFunc func) {
200200
"The shared memory size used by current kernel is greater "
201201
"than the max shared memory per block"));
202202
},
203-
[&](common::HygonDCUArchSYCL) { CINN_NOT_IMPLEMENTED });
203+
[&](common::HygonDCUArchSYCL) {
204+
using cinn::runtime::BackendAPI;
205+
size_t max_shm_per_block =
206+
BackendAPI::get_backend(common::HygonDCUArchSYCL{})
207+
->get_device_property(
208+
BackendAPI::DeviceProperty::MaxSharedMemoryPerBlock);
209+
PADDLE_ENFORCE_LE(
210+
mutator.shared_mem_size_used(),
211+
max_shm_per_block,
212+
::common::errors::InvalidArgument(
213+
"The shared memory size used by current kernel is greater "
214+
"than the max shared memory per block"));
215+
});
204216
return LogicalResult::success();
205217
}
206218

paddle/cinn/runtime/sycl/sycl_backend_api.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,11 @@ void SYCLBackendAPI::Init(Arch arch) {
4545
backend = ::sycl::backend::ext_oneapi_hip;
4646
});
4747
// look for matched devices
48-
for (auto device : devices) {
49-
if (device.get_backend() == backend) {
50-
this->devices.push_back(device);
48+
if (this->devices.size() < 8) {
49+
for (auto device : devices) {
50+
if (device.get_backend() == backend) {
51+
this->devices.push_back(device);
52+
}
5153
}
5254
}
5355
if (this->devices.size() == 0) {

paddle/cinn/runtime/use_extern_funcs.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
#ifdef CINN_WITH_CUDA
1919
#include "paddle/cinn/runtime/cuda/use_extern_funcs.h"
2020
#endif
21+
2122
#ifdef CINN_WITH_HIP
2223
#include "paddle/cinn/runtime/hip/use_extern_funcs.h"
2324
#endif
25+
26+
#ifdef CINN_WITH_SYCL
27+
#include "paddle/cinn/runtime/sycl/use_extern_funcs.h"
28+
#endif

0 commit comments

Comments
 (0)