
Commit 65e5eb4

[CINN][New Hardware]Fix dcu compile (#69159)
* hip_squeezenet
* fix paddle hip compile
* fix_hip_compile
* fix_dcu_compile
* fix_dcu_compile
* fix_dcu_compile
* fix_dcu_compile
* fix_dcu_compile
* fix_dcu_compile
1 parent 7658a9f commit 65e5eb4

5 files changed: +30, -16 lines

paddle/cinn/backends/codegen_device_util.cc

Lines changed: 1 addition & 1 deletion

@@ -217,7 +217,7 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
         },
         [&](common::HygonDCUArchHIP) {
 #ifdef CINN_WITH_HIP
-          shared_mem_bytes = hip::CalculateSharedMemory(func);
+          shared_mem_bytes = CalculateSharedMemory(func);
 #endif
         });

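For orientation, the surrounding code dispatches on the target architecture with one lambda per arch, and each backend branch sits behind its build macro; the change drops the stale hip:: qualifier, presumably because the helper is reachable unqualified in this translation unit when CINN_WITH_HIP is defined. Below is a minimal standalone sketch of that shape using hypothetical tag names and a std::variant visitor, not the actual CINN arch API.

#include <cstdint>
#include <iostream>
#include <variant>

// Hypothetical arch tags standing in for common::NVGPUArch / common::HygonDCUArchHIP.
struct NVGPUArch {};
struct HygonDCUArchHIP {};
using Arch = std::variant<NVGPUArch, HygonDCUArchHIP>;

// Build an overload set out of per-arch lambdas (C++17).
template <class... Ts>
struct Overloaded : Ts... {
  using Ts::operator()...;
};
template <class... Ts>
Overloaded(Ts...) -> Overloaded<Ts...>;

int64_t SharedMemBytes(const Arch& arch) {
  int64_t shared_mem_bytes = 0;
  std::visit(Overloaded{
                 [&](NVGPUArch) {
#ifdef CINN_WITH_CUDA
                   shared_mem_bytes = 48 * 1024;  // stand-in for the CUDA-side CalculateSharedMemory(func)
#endif
                 },
                 [&](HygonDCUArchHIP) {
#ifdef CINN_WITH_HIP
                   shared_mem_bytes = 64 * 1024;  // stand-in for the HIP-side CalculateSharedMemory(func)
#endif
                 },
             },
             arch);
  return shared_mem_bytes;
}

int main() { std::cout << SharedMemBytes(HygonDCUArchHIP{}) << "\n"; }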
paddle/cinn/common/dev_info_manager.h

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
 #include "paddle/cinn/common/macros.h"
 #include "paddle/cinn/common/nvgpu_dev_info.h"
 #include "paddle/cinn/common/target.h"
+#include "paddle/common/enforce.h"

 namespace cinn {
 namespace common {

paddle/cinn/common/target.cc

Lines changed: 4 additions & 0 deletions

@@ -242,6 +242,7 @@ std::string Target::arch_str() const {
 }

 std::string Target::device_name_str() const {
+#ifdef CINN_WITH_CUDA
   int device_idx = 0;
   cudaError_t result = cudaGetDevice(&device_idx);
   if (result != cudaSuccess) {
@@ -265,6 +266,9 @@ std::string Target::device_name_str() const {
   std::string device_name = properties.name;
   device_name = std::regex_replace(device_name, std::regex(" "), "_");
   return std::regex_replace(device_name, std::regex("-"), "_");
+#else
+  CINN_NOT_IMPLEMENTED
+#endif
 }

 std::ostream &operator<<(std::ostream &os, const Target &target) {

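For context, the body that this hunk wraps in #ifdef CINN_WITH_CUDA is an ordinary CUDA runtime query plus name sanitization; on a HIP/DCU-only build it now falls through to CINN_NOT_IMPLEMENTED instead of referencing cudaGetDevice. A minimal standalone sketch of the CUDA path (illustrative, not the Paddle sources):

// device_name_sketch.cc — link against the CUDA runtime; assumes a CUDA build.
#include <cuda_runtime.h>

#include <iostream>
#include <regex>
#include <string>

std::string DeviceNameStr() {
  int device_idx = 0;
  if (cudaGetDevice(&device_idx) != cudaSuccess) {
    return "unknown";
  }
  cudaDeviceProp properties;
  if (cudaGetDeviceProperties(&properties, device_idx) != cudaSuccess) {
    return "unknown";
  }
  // Same sanitization as the diff context: spaces and dashes become '_'.
  std::string device_name = properties.name;
  device_name = std::regex_replace(device_name, std::regex(" "), "_");
  return std::regex_replace(device_name, std::regex("-"), "_");
}

int main() { std::cout << DeviceNameStr() << "\n"; }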
paddle/cinn/hlir/framework/pir/op_lowering_impl.cc

Lines changed: 1 addition & 1 deletion

@@ -798,7 +798,7 @@ ir::Expr OpLowererImpl::LowerX86(const OpLoweringGroupPtr& group,

   std::vector<ir::Expr> func_bodies =
       LowerOps(group, ops, &group_func_arg_tensors, &tensor_map);
-  this->target_ = common::DefaultNVGPUTarget();
+  this->target_ = common::DefaultDeviceTarget();
   cinn::runtime::CurrentTarget::SetCurrentTarget(this->target_);
   ir::ModuleExpr mod_expr(func_bodies);
   ir::IRSchedule ir_sch(

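The point of the swap above is that LowerX86 no longer hard-codes the NVGPU target; common::DefaultDeviceTarget() presumably resolves to whichever device backend the build enables (NVGPU for CUDA, Hygon DCU for HIP). A hypothetical sketch of that kind of selector, not the actual CINN implementation:

#include <iostream>

enum class Backend { NVGPU, HygonDCUHIP, None };

// Hypothetical stand-in for common::DefaultDeviceTarget(): choose the device
// backend from the build configuration instead of assuming NVGPU.
Backend DefaultDeviceBackend() {
#if defined(CINN_WITH_CUDA)
  return Backend::NVGPU;
#elif defined(CINN_WITH_HIP)
  return Backend::HygonDCUHIP;
#else
  return Backend::None;  // no device backend compiled in
#endif
}

int main() { std::cout << static_cast<int>(DefaultDeviceBackend()) << "\n"; }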
paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc

Lines changed: 23 additions & 14 deletions

@@ -20,6 +20,8 @@
 #include "paddle/common/errors.h"
 #include "paddle/common/performance_statistician.h"
 #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/backends/gpu/gpu_resources.h"
 #if defined(PADDLE_WITH_CUDA)
 #include "paddle/cinn/runtime/cinn_runtime.h"
 #endif
@@ -112,32 +114,39 @@ class CinnJitInstruction::FnPtrImpl {
       ::common::PerformanceStatistician& ps =
           ::common::PerformanceStatistician::Instance();
       auto data_p = static_cast<void*>(func_args_.data());
-      cudaStream_t stream;
-      cudaStreamCreate(&stream);
-      cudaDeviceSynchronize();
+      phi::gpuStream_t stream;
+      phi::InitStream(&stream);
+      phi::backends::gpu::GpuDeviceSync();
       if (is_gpu) {
         ps.SetGraphNodesNum(25);
         int graph_nodes_num = ps.GetGraphNodesNum();
-        cudaGraph_t graph;
-        cudaGraphExec_t instance;
-        cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
+        phi::gpuGraph_t graph;
+        phi::gpuGraphExec_t instance;
+        phi::gpuStreamBeginCapture(
+            stream, gpuStreamCaptureMode(0));  // StreamCaptureModeGlobal
         for (int ikrnl = 0; ikrnl < graph_nodes_num; ikrnl++) {
           ((lower_func_ptr_g)cinn_kernel_info_.fn_ptr)(
               static_cast<void*>(func_args_.data()), func_args_.size(), stream);
         }
-        cudaStreamEndCapture(stream, &graph);
+        phi::gpuStreamEndCapture(stream, &graph);
+#ifdef PADDLE_WITH_CUDA
         cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
+#elif defined(PADDLE_WITH_HIP)
+        hipGraphInstantiate(&instance, graph, NULL, NULL, 0);
+#else
+        CINN_NOT_IMPLEMENTED
+#endif
         ps.CudaStart(FLAGS_cinn_kernel_execution_label);
-        cudaGraphLaunch(instance, stream);
+        phi::gpuGraphLaunch(instance, stream);
         ps.CudaEnd(FLAGS_cinn_kernel_execution_label);
-        cudaGraphDestroy(graph);
-        cudaGraphExecDestroy(instance);
-        cudaStreamDestroy(stream);
+        phi::gpuGraphDestroy(graph);
+        phi::gpuGraphExecDestroy(instance);
+        phi::DestoryStream(stream);
       } else {
         ((lower_func_ptr_g)cinn_kernel_info_.CX86_fn_ptr)(
             static_cast<void*>(func_args_.data()), func_args_.size(), stream);
       }
-      cudaDeviceSynchronize();
+      phi::backends::gpu::GpuDeviceSync();
     } else {
       if (is_gpu) {
         ((lower_func_ptr_g)cinn_kernel_info_.fn_ptr)(
@@ -267,7 +276,7 @@ CinnJitInstruction::CinnJitInstruction(
 }

 void CinnJitInstruction::Run() {
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   void* running_stream = nullptr;
   bool is_gpu = false;

@@ -298,7 +307,7 @@ void CinnJitInstruction::Run() {
   }
 #else
   VLOG(0) << "Not Supported: cinn jit instruction currently does not "
-             "support non-CUDA kernel";
+             "support CUDA/HIP kernel";
 #endif
 }

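The profiling path above keeps its previous structure but routes it through phi's gpu* aliases so one code path serves both CUDA and HIP; only graph instantiation still needs an explicit #ifdef. For reference, a minimal CUDA-only sketch of the underlying capture-and-replay sequence (a self-contained demo, not the Paddle phi wrappers); the HIP build uses the equivalent hip* calls.

// graph_capture_demo.cu — compile with nvcc. Stream capture into a CUDA graph,
// instantiate it once, then replay the whole kernel sequence in one launch.
#include <cuda_runtime.h>

#include <cstdio>

__global__ void AddOne(float* x) { x[threadIdx.x] += 1.0f; }

int main() {
  float* buf = nullptr;
  cudaMalloc(&buf, 32 * sizeof(float));
  cudaMemset(buf, 0, 32 * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  // Capture a fixed number of launches (mirroring the graph_nodes_num loop).
  cudaGraph_t graph;
  cudaGraphExec_t instance;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  for (int i = 0; i < 25; ++i) {
    AddOne<<<1, 32, 0, stream>>>(buf);
  }
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);

  // Replay the captured sequence with a single launch.
  cudaGraphLaunch(instance, stream);
  cudaStreamSynchronize(stream);

  float host[32];
  cudaMemcpy(host, buf, sizeof(host), cudaMemcpyDeviceToHost);
  std::printf("buf[0] after replay: %.1f\n", host[0]);  // expect 25.0

  cudaGraphExecDestroy(instance);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(buf);
  return 0;
}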