@@ -238,7 +238,8 @@ def parse_kernel(self, kernel_config):
238238 'param' : None ,
239239 'backend' : None ,
240240 'layout' : None ,
241- 'data_type' : None
241+ 'data_type' : None ,
242+ 'use_cudnn' : 'false'
242243 }
243244 if 'backend' in kernel_config and len (kernel_config ['backend' ]) > 0 :
244245 kernel ['backend' ] = kernel_config ['backend' ]
@@ -248,6 +249,10 @@ def parse_kernel(self, kernel_config):
248249 kernel ['data_type' ] = kernel_config ['data_type' ]
249250 if 'param' in kernel_config :
250251 kernel ['param' ] = kernel_config ['param' ]
252+ if 'use_cudnn' in kernel_config :
253+ kernel ['use_cudnn' ] = kernel_config ['use_cudnn' ]
254+ if isinstance (kernel ['use_cudnn' ], bool ):
255+ kernel ['use_cudnn' ] = str (kernel ['use_cudnn' ]).lower ()
251256 kernel ['func' ] = [
252257 kernel_fn .strip () for kernel_fn in kernel_config ['func' ].split (',' )
253258 ]
@@ -713,10 +718,12 @@ def gen_dense_tensor_kernel_code(self, code_indent, inplace_flag=False):
713718 outputs_args , kernel_output_names , output_create = self .gene_output (
714719 self .outputs ['types' ], 'SetKernelOutput' , code_indent , inplace_flag )
715720 api_func_name = self .get_api_func_name () + ('_' if inplace_flag else '' )
721+ cudnn_args = '' if self .kernel [
722+ 'use_cudnn' ] == 'false' else ', ' + self .kernel ['use_cudnn' ]
716723 return f"""
717724{ code_indent } VLOG(6) << "{ self .api } API kernel key: [" << kernel_backend << ", " << kernel_layout << ", "<< kernel_data_type << "]";
718725{ code_indent } const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
719- { code_indent } "{ self .kernel ['func' ][0 ]} ", {{kernel_backend, kernel_layout, kernel_data_type}});
726+ { code_indent } "{ self .kernel ['func' ][0 ]} ", {{kernel_backend, kernel_layout, kernel_data_type}}{ cudnn_args } );
720727{ code_indent } VLOG(6) << "{ self .api } API kernel: " << kernel;
721728
722729{ code_indent } auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
0 commit comments