|
35 | 35 | #include "paddle/cinn/backends/hip/compiler_hip.h" |
36 | 36 | #include "paddle/cinn/runtime/hip/hip_module.h" |
37 | 37 | #endif |
| 38 | +#ifdef CINN_WITH_SYCL |
| 39 | +#include "paddle/cinn/backends/sycl/codegen_sycl_dev.h" |
| 40 | +#include "paddle/cinn/backends/sycl/compiler_sycl.h" |
| 41 | +#include "paddle/cinn/runtime/sycl/sycl_module.h" |
| 42 | +#endif |
38 | 43 | #include "paddle/cinn/adt/adt.h" |
39 | 44 |
|
40 | 45 | PD_DECLARE_string(cinn_source_code_save_path); |
@@ -288,7 +293,19 @@ std::string Compiler::GetSourceCode(const ir::Module& module) { |
288 | 293 | CINN_NOT_IMPLEMENTED |
289 | 294 | #endif |
290 | 295 | }, |
291 | | - [&](common::HygonDCUArchSYCL) -> std::string { CINN_NOT_IMPLEMENTED }); |
| 296 | + [&](common::HygonDCUArchSYCL) -> std::string { |
| 297 | +#ifdef CINN_WITH_SYCL |
| 298 | + auto _host_module_device_module_ = |
| 299 | + SplitDeviceAndHostModule(module); // NOLINT |
| 300 | + auto& host_module = std::get<0>(_host_module_device_module_); |
| 301 | + auto& device_module = std::get<1>(_host_module_device_module_); |
| 302 | + sycl::CodeGenSyclDevice codegen(target_); |
| 303 | + auto source_code = codegen.Compile(device_module); |
| 304 | + return source_code; |
| 305 | +#else |
| 306 | + CINN_NOT_IMPLEMENTED |
| 307 | +#endif |
| 308 | + }); |
292 | 309 | } |
293 | 310 |
|
294 | 311 | void Compiler::BuildDefault(const Module& module) { |
@@ -393,7 +410,38 @@ void Compiler::RegisterHipModuleSymbol() { |
393 | 410 | #endif |
394 | 411 | } |
395 | 412 |
|
396 | | -void Compiler::RegisterSyclModuleSymbol() { CINN_NOT_IMPLEMENTED } |
| 413 | +void Compiler::RegisterSyclModuleSymbol() { |
| 414 | +#ifdef CINN_WITH_SYCL |
| 415 | + syclrtc::Compiler compiler; |
| 416 | + std::string source_code = |
| 417 | + sycl::CodeGenSyclDevice::GetSourceHeader() + device_fn_code_; |
| 418 | + std::string hsaco = compiler(source_code); |
| 419 | + PADDLE_ENFORCE_EQ( |
| 420 | + !hsaco.empty(), |
| 421 | + true, |
| 422 | + ::common::errors::Fatal("Compile hsaco failed from source code:\n%s", |
| 423 | + source_code)); |
| 424 | + using runtime::sycl::SYCLModule; |
| 425 | + sycl_module_.reset(new SYCLModule(source_code, hsaco, SYCLModule::Kind::so)); |
| 426 | + // get device id |
| 427 | + using cinn::runtime::BackendAPI; |
| 428 | + int device_id = BackendAPI::get_backend(target_)->get_device(); |
| 429 | + // register kernel |
| 430 | + RuntimeSymbols symbols; |
| 431 | + for (const auto& kernel_fn_name : device_fn_name_) { |
| 432 | + auto fn_kernel = sycl_module_->GetFunction(kernel_fn_name); |
| 433 | + PADDLE_ENFORCE_NOT_NULL( |
| 434 | + fn_kernel, |
| 435 | + ::common::errors::Fatal("HIP GetFunction Error: get valid kernel.")); |
| 436 | + fn_ptr_.push_back(reinterpret_cast<void*>(fn_kernel)); |
| 437 | + symbols.RegisterVar(kernel_fn_name + "_ptr_", |
| 438 | + reinterpret_cast<void*>(fn_kernel)); |
| 439 | + } |
| 440 | + engine_->RegisterModuleRuntimeSymbols(std::move(symbols)); |
| 441 | +#else |
| 442 | + CINN_NOT_IMPLEMENTED |
| 443 | +#endif |
| 444 | +} |
397 | 445 |
|
398 | 446 | void Compiler::CompileCudaModule(const Module& module, |
399 | 447 | const std::string& code) { |
@@ -473,7 +521,39 @@ void Compiler::CompileHipModule(const Module& module, const std::string& code) { |
473 | 521 |
|
474 | 522 | void Compiler::CompileSyclModule(const Module& module, |
475 | 523 | const std::string& code) { |
| 524 | +#ifdef CINN_WITH_SYCL |
| 525 | + auto _host_module_device_module_ = |
| 526 | + SplitDeviceAndHostModule(module); // NOLINT |
| 527 | + auto& host_module = std::get<0>(_host_module_device_module_); |
| 528 | + auto& device_module = std::get<1>(_host_module_device_module_); |
| 529 | + VLOG(3) << "[SYCL] host module:\n" << host_module; |
| 530 | + VLOG(3) << "[SYCL] device module:\n" << device_module; |
| 531 | + std::string source_code; |
| 532 | + if (!FLAGS_cinn_debug_custom_code_path.empty()) { |
| 533 | + std::string file_path = FLAGS_cinn_debug_custom_code_path; |
| 534 | + source_code = GetFileContent(file_path); |
| 535 | + } else if (code.empty()) { |
| 536 | + sycl::CodeGenSyclDevice codegen(target_); |
| 537 | + source_code = codegen.Compile(device_module); |
| 538 | + } else { |
| 539 | + source_code = code; |
| 540 | + } |
| 541 | + PADDLE_ENFORCE_EQ( |
| 542 | + !source_code.empty(), |
| 543 | + true, |
| 544 | + ::common::errors::Fatal( |
| 545 | + "Compile SYCL code failed from device module:\n%s", device_module)); |
| 546 | + VLOG(3) << "[SYCL]:\n" << source_code; |
| 547 | + SourceCodePrint::GetInstance()->write(source_code); |
| 548 | + device_fn_code_ += source_code; |
| 549 | + for (auto& fn : device_module.functions()) { |
| 550 | + std::string kernel_fn_name = fn->name; |
| 551 | + device_fn_name_.emplace_back(kernel_fn_name); |
| 552 | + } |
| 553 | + engine_->Link<CodeGenGpuHost>(host_module); |
| 554 | +#else |
476 | 555 | CINN_NOT_IMPLEMENTED |
| 556 | +#endif |
477 | 557 | } |
478 | 558 |
|
479 | 559 | void Compiler::CompileX86Module(const Module& module) { |
|
0 commit comments