Skip to content

Commit d54a0c1

Browse files
phlrainwu.zeng
authored andcommitted
add depthwise conv hip support (PaddlePaddle#41537)
1 parent f7c1643 commit d54a0c1

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

paddle/phi/kernels/gpudnn/conv_grad_kernel.cu

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,39 @@ void Conv3DCudnnGradKernel(const Context& dev_ctx,
627627
filter_grad);
628628
}
629629

630+
template <typename T, typename Context>
631+
void DepthwiseConvCudnnGradKernel(const Context& dev_ctx,
632+
const DenseTensor& input,
633+
const DenseTensor& filter,
634+
const DenseTensor& out_grad,
635+
const std::vector<int>& strides,
636+
const std::vector<int>& paddings,
637+
const std::string& paddding_algorithm,
638+
int groups,
639+
const std::vector<int>& dilations,
640+
const std::string& data_format,
641+
bool use_addto,
642+
int workspace_size_MB,
643+
bool exhaustive_search,
644+
DenseTensor* input_grad,
645+
DenseTensor* filter_grad) {
646+
ConvCudnnGradKernel<T>(dev_ctx,
647+
input,
648+
filter,
649+
out_grad,
650+
strides,
651+
paddings,
652+
paddding_algorithm,
653+
groups,
654+
dilations,
655+
data_format,
656+
use_addto,
657+
workspace_size_MB,
658+
exhaustive_search,
659+
input_grad,
660+
filter_grad);
661+
}
662+
630663
} // namespace phi
631664

632665
#ifdef PADDLE_WITH_HIP
@@ -643,6 +676,13 @@ PD_REGISTER_KERNEL(conv3d_grad,
643676
phi::Conv3DCudnnGradKernel,
644677
float,
645678
phi::dtype::float16) {}
679+
680+
PD_REGISTER_KERNEL(depthwise_conv2d_grad,
681+
GPUDNN,
682+
ALL_LAYOUT,
683+
phi::DepthwiseConvCudnnGradKernel,
684+
float,
685+
phi::dtype::float16) {}
646686
#else
647687
#if CUDNN_VERSION_MIN(8, 1, 0)
648688
PD_REGISTER_KERNEL(conv2d_grad,

paddle/phi/kernels/gpudnn/conv_kernel.cu

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,35 @@ void Conv3DCudnnKernel(const Context& dev_ctx,
416416
out);
417417
}
418418

419+
template <typename T, typename Context>
420+
void DepthwiseConvCudnnKernel(const Context& dev_ctx,
421+
const DenseTensor& input,
422+
const DenseTensor& filter,
423+
const std::vector<int>& strides,
424+
const std::vector<int>& paddings,
425+
const std::string& padding_algorithm,
426+
int groups,
427+
const std::vector<int>& dilations,
428+
const std::string& data_format,
429+
bool use_addto,
430+
int workspace_size_MB,
431+
bool exhaustive_search,
432+
DenseTensor* out) {
433+
ConvCudnnKernel<T>(dev_ctx,
434+
input,
435+
filter,
436+
strides,
437+
paddings,
438+
padding_algorithm,
439+
groups,
440+
dilations,
441+
data_format,
442+
use_addto,
443+
workspace_size_MB,
444+
exhaustive_search,
445+
out);
446+
}
447+
419448
} // namespace phi
420449

421450
#ifdef PADDLE_WITH_HIP
@@ -432,6 +461,14 @@ PD_REGISTER_KERNEL(conv3d,
432461
phi::Conv3DCudnnKernel,
433462
float,
434463
phi::dtype::float16) {}
464+
465+
PD_REGISTER_KERNEL(depthwise_conv2d,
466+
GPUDNN,
467+
ALL_LAYOUT,
468+
phi::DepthwiseConvCudnnKernel,
469+
float,
470+
phi::dtype::float16) {}
471+
435472
#else
436473
#if CUDNN_VERSION_MIN(8, 1, 0)
437474
PD_REGISTER_KERNEL(conv2d,

0 commit comments

Comments
 (0)