Skip to content

Commit 1888d87

Browse files
authored
add cudnn flag in yaml (#41368)
1 parent 77cf305 commit 1888d87

File tree

5 files changed

+34
-4
lines changed

5 files changed

+34
-4
lines changed

paddle/phi/core/kernel_factory.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,31 @@ bool KernelFactory::IsSelectKernelValid(const std::string& kernel_name,
7575
}
7676

7777
const Kernel& KernelFactory::SelectKernelOrThrowError(
78-
const std::string& kernel_name, const KernelKey& kernel_key) const {
78+
const std::string& kernel_name,
79+
const KernelKey& kernel_key,
80+
bool use_cudnn) const {
7981
auto iter = kernels_.find(kernel_name);
8082
PADDLE_ENFORCE_NE(
8183
iter,
8284
kernels_.end(),
8385
phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name));
8486

87+
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
88+
if (use_cudnn && kernel_key.backend() == Backend::GPU) {
89+
auto kernel_iter = iter->second.find(
90+
{Backend::GPUDNN, kernel_key.layout(), kernel_key.dtype()});
91+
if (kernel_iter == iter->second.end() &&
92+
kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
93+
kernel_iter = iter->second.find(
94+
{Backend::GPUDNN, DataLayout::ALL_LAYOUT, kernel_key.dtype()});
95+
}
96+
if (kernel_iter != iter->second.end()) {
97+
return kernel_iter->second;
98+
}
99+
LOG(WARNING) << "The cudnn kernel for [" << kernel_name
100+
<< "] is not registered.";
101+
}
102+
#endif
85103
auto kernel_iter = iter->second.find(kernel_key);
86104
// TODO(chenweihang): polish refind impl here
87105
if (kernel_iter == iter->second.end() &&

paddle/phi/core/kernel_factory.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ class KernelFactory {
238238
}
239239

240240
const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
241-
const KernelKey& kernel_key) const;
241+
const KernelKey& kernel_key,
242+
bool use_cudnn = false) const;
242243

243244
const Kernel& SelectKernelOrThrowError(const std::string& kernel_name,
244245
Backend backend,

python/paddle/utils/code_gen/api_base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ def parse_kernel(self, kernel_config):
238238
'param': None,
239239
'backend': None,
240240
'layout': None,
241-
'data_type': None
241+
'data_type': None,
242+
'use_cudnn': 'false'
242243
}
243244
if 'backend' in kernel_config and len(kernel_config['backend']) > 0:
244245
kernel['backend'] = kernel_config['backend']
@@ -248,6 +249,10 @@ def parse_kernel(self, kernel_config):
248249
kernel['data_type'] = kernel_config['data_type']
249250
if 'param' in kernel_config:
250251
kernel['param'] = kernel_config['param']
252+
if 'use_cudnn' in kernel_config:
253+
kernel['use_cudnn'] = kernel_config['use_cudnn']
254+
if isinstance(kernel['use_cudnn'], bool):
255+
kernel['use_cudnn'] = str(kernel['use_cudnn']).lower()
251256
kernel['func'] = [
252257
kernel_fn.strip() for kernel_fn in kernel_config['func'].split(',')
253258
]
@@ -713,10 +718,12 @@ def gen_dense_tensor_kernel_code(self, code_indent, inplace_flag=False):
713718
outputs_args, kernel_output_names, output_create = self.gene_output(
714719
self.outputs['types'], 'SetKernelOutput', code_indent, inplace_flag)
715720
api_func_name = self.get_api_func_name() + ('_' if inplace_flag else '')
721+
cudnn_args = '' if self.kernel[
722+
'use_cudnn'] == 'false' else ', ' + self.kernel['use_cudnn']
716723
return f"""
717724
{code_indent} VLOG(6) << "{self.api} API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
718725
{code_indent} const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
719-
{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}});
726+
{code_indent} "{self.kernel['func'][0]}", {{kernel_backend, kernel_layout, kernel_data_type}}{cudnn_args});
720727
{code_indent} VLOG(6) << "{self.api} API kernel: " << kernel;
721728
722729
{code_indent} auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);

python/paddle/utils/code_gen/api_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ def source_include(header_file_path):
163163
#include "paddle/phi/infermeta/ternary.h"
164164
165165
#include "paddle/fluid/platform/profiler/event_tracing.h"
166+
167+
DECLARE_bool(conv2d_disable_cudnn);
166168
"""
167169

168170

python/paddle/utils/code_gen/backward_api_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ def source_include(header_file_path):
179179
180180
#include "paddle/fluid/eager/api/utils/global_utils.h"
181181
#include "paddle/fluid/platform/profiler/event_tracing.h"
182+
183+
DECLARE_bool(conv2d_disable_cudnn);
182184
"""
183185

184186

0 commit comments

Comments
 (0)