|
20 | 20 | #include "paddle/common/errors.h" |
21 | 21 | #include "paddle/common/performance_statistician.h" |
22 | 22 | #include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" |
| 23 | +#include "paddle/phi/backends/gpu/gpu_info.h" |
| 24 | +#include "paddle/phi/backends/gpu/gpu_resources.h" |
23 | 25 | #if defined(PADDLE_WITH_CUDA) |
24 | 26 | #include "paddle/cinn/runtime/cinn_runtime.h" |
25 | 27 | #endif |
@@ -112,32 +114,39 @@ class CinnJitInstruction::FnPtrImpl { |
112 | 114 | ::common::PerformanceStatistician& ps = |
113 | 115 | ::common::PerformanceStatistician::Instance(); |
114 | 116 | auto data_p = static_cast<void*>(func_args_.data()); |
115 | | - cudaStream_t stream; |
116 | | - cudaStreamCreate(&stream); |
117 | | - cudaDeviceSynchronize(); |
| 117 | + phi::gpuStream_t stream; |
| 118 | + phi::InitStream(&stream); |
| 119 | + phi::backends::gpu::GpuDeviceSync(); |
118 | 120 | if (is_gpu) { |
119 | 121 | ps.SetGraphNodesNum(25); |
120 | 122 | int graph_nodes_num = ps.GetGraphNodesNum(); |
121 | | - cudaGraph_t graph; |
122 | | - cudaGraphExec_t instance; |
123 | | - cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal); |
| 123 | + phi::gpuGraph_t graph; |
| 124 | + phi::gpuGraphExec_t instance; |
| 125 | + phi::gpuStreamBeginCapture( |
| 126 | + stream, gpuStreamCaptureMode(0)); // StreamCaptureModeGlobal |
124 | 127 | for (int ikrnl = 0; ikrnl < graph_nodes_num; ikrnl++) { |
125 | 128 | ((lower_func_ptr_g)cinn_kernel_info_.fn_ptr)( |
126 | 129 | static_cast<void*>(func_args_.data()), func_args_.size(), stream); |
127 | 130 | } |
128 | | - cudaStreamEndCapture(stream, &graph); |
| 131 | + phi::gpuStreamEndCapture(stream, &graph); |
| 132 | +#ifdef PADDLE_WITH_CUDA |
129 | 133 | cudaGraphInstantiate(&instance, graph, NULL, NULL, 0); |
| 134 | +#elif defined(PADDLE_WITH_HIP) |
| 135 | + hipGraphInstantiate(&instance, graph, NULL, NULL, 0); |
| 136 | +#else |
| 137 | + CINN_NOT_IMPLEMENTED |
| 138 | +#endif |
130 | 139 | ps.CudaStart(FLAGS_cinn_kernel_execution_label); |
131 | | - cudaGraphLaunch(instance, stream); |
| 140 | + phi::gpuGraphLaunch(instance, stream); |
132 | 141 | ps.CudaEnd(FLAGS_cinn_kernel_execution_label); |
133 | | - cudaGraphDestroy(graph); |
134 | | - cudaGraphExecDestroy(instance); |
135 | | - cudaStreamDestroy(stream); |
| 142 | + phi::gpuGraphDestroy(graph); |
| 143 | + phi::gpuGraphExecDestroy(instance); |
| 144 | + phi::DestoryStream(stream); |
136 | 145 | } else { |
137 | 146 | ((lower_func_ptr_g)cinn_kernel_info_.CX86_fn_ptr)( |
138 | 147 | static_cast<void*>(func_args_.data()), func_args_.size(), stream); |
139 | 148 | } |
140 | | - cudaDeviceSynchronize(); |
| 149 | + phi::backends::gpu::GpuDeviceSync(); |
141 | 150 | } else { |
142 | 151 | if (is_gpu) { |
143 | 152 | ((lower_func_ptr_g)cinn_kernel_info_.fn_ptr)( |
@@ -267,7 +276,7 @@ CinnJitInstruction::CinnJitInstruction( |
267 | 276 | } |
268 | 277 |
|
269 | 278 | void CinnJitInstruction::Run() { |
270 | | -#if defined(PADDLE_WITH_CUDA) |
| 279 | +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) |
271 | 280 | void* running_stream = nullptr; |
272 | 281 | bool is_gpu = false; |
273 | 282 |
|
@@ -298,7 +307,7 @@ void CinnJitInstruction::Run() { |
298 | 307 | } |
299 | 308 | #else |
300 | 309 | VLOG(0) << "Not Supported: cinn jit instruction currently does not " |
301 | | - "support non-CUDA kernel"; |
| 310 | +             "support non-CUDA/HIP kernel"; |
302 | 311 | #endif |
303 | 312 | } |
304 | 313 |
|
|
0 commit comments