Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ set(XPU_XFA_LIB_NAME "libxpu_flash_attention.so")
set(XPU_XPUDNN_LIB_NAME "libxpu_dnn.so")

if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "dev/20250304")
set(XPU_XHPC_BASE_DATE "dev/20250305")
endif()
set(XPU_XCCL_BASE_VERSION "3.0.2.3") # For XRE5
if(NOT DEFINED XPU_XFT_BASE_VERSION)
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ XPUOpMap& get_kl3_ops() {
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16})},
{"conv3d_transpose",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"conv2d_transpose_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"conv2d_transpose",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
Expand Down
9 changes: 7 additions & 2 deletions paddle/phi/kernels/xpu/conv_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,13 @@ void Conv3DKernel(const Context& dev_ctx,
// that avoids modifying the variable in the Scope.
dev_ctx.template Alloc<T>(out);

phi::DDim in_data_dims =
common::slice_ddim(input.dims(), 2, input.dims().size());
phi::DDim in_data_dims;
if (data_format == "NDHWC") {
in_data_dims = common::slice_ddim(input.dims(), 1, input.dims().size() - 1);
} else {
in_data_dims = common::slice_ddim(input.dims(), 2, input.dims().size());
}

phi::DDim filter_data_dims =
common::slice_ddim(filter.dims(), 2, filter.dims().size());
std::vector<int64_t> ksize = common::vectorize<int64_t>(filter_data_dims);
Expand Down
138 changes: 138 additions & 0 deletions paddle/phi/kernels/xpu/conv_transpose_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
#include "paddle/phi/kernels/xpu/conv_utils_xpu.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
#ifdef PADDLE_WITH_XPU_XRE5
#include "xpudnn/xpudnn.h"
Expand Down Expand Up @@ -352,6 +353,136 @@ void Conv2dTransposeKernel(const Context& ctx,
}
#endif
}

template <typename T, typename Context>
void Conv3dTransposeKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& filter,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::vector<int>& output_padding,
const std::vector<int>& output_size,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;

dev_ctx.template Alloc<T>(out);

// data_format follow the legacy interface:
// https://github.com/PaddlePaddle/Paddle/blob/639abfd4/python/paddle/nn/functional/conv.py#L1726
PADDLE_ENFORCE_EQ(
data_format == "NCHW" || data_format == "NHWC",
true,
errors::InvalidArgument(
("XPU only support data_format is NCHW(in Python, it is specified as "
"NCDHW) or NHWC(in Python, it is specified as NDHWC) in "
"conv3d_transpose op.")));

phi::DDim in_data_dims;
if (data_format == "NHWC") {
in_data_dims = common::slice_ddim(x.dims(), 1, x.dims().size() - 1);
} else {
in_data_dims = common::slice_ddim(x.dims(), 2, x.dims().size());
}
phi::DDim filter_data_dims =
common::slice_ddim(filter.dims(), 2, filter.dims().size());

std::vector<int64_t> ksize = common::vectorize<int64_t>(filter_data_dims);
std::vector<int64_t> paddings_ =
std::vector<int64_t>(paddings.begin(), paddings.end());
std::vector<int64_t> dilations_ =
std::vector<int64_t>(dilations.begin(), dilations.end());
std::vector<int64_t> strides_ =
std::vector<int64_t>(strides.begin(), strides.end());
UpdatePaddingAndDilation(&paddings_,
&dilations_,
padding_algorithm,
in_data_dims,
strides_,
ksize);

for (int64_t dilation : dilations_) {
PADDLE_ENFORCE_LE(
dilation,
1,
errors::Unimplemented(
"XPU do not support dilation > 1 in conv3d_transpose."));
}

int64_t batch_size = x.dims()[0];
int64_t img_yc = x.dims()[1];
int64_t img_yd = x.dims()[2];
int64_t img_yh = x.dims()[3];
int64_t img_yw = x.dims()[4];
int64_t img_xc = out->dims()[1];

bool is_ndhwc = false;
if (data_format == "NHWC") {
img_yc = x.dims()[4];
img_yd = x.dims()[1];
img_yh = x.dims()[2];
img_yw = x.dims()[3];
img_xc = out->dims()[4];
is_ndhwc = true;

PADDLE_ENFORCE_LE(
groups,
1,
errors::Unimplemented("XPU do not support group > 1 when data_format "
"is NHWC(in Python, it is specified as NDHWC) "
"in conv3d_transpose."));
}

const XPUType* filter_data =
reinterpret_cast<const XPUType*>(filter.data<T>());

int fc_calc_type = GetConvCalcType<XPUType>();
PD_VISIT_XPU_CONV_TYPES(XPUType, fc_calc_type, "conv3d_transpose", [&] {
using XPUTypeFP16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
using RealTGEMM = std::conditional_t<
(
// 如果 XPUType 是 XPUTypeFP16 且 TGEMM 不是 FP16 或 int16
(std::is_same_v<XPUType, XPUTypeFP16> &&
!std::is_same_v<TGEMM, XPUTypeFP16> &&
!std::is_same_v<TGEMM, int16_t>) ||

// 如果 XPUType 是 float 且 TGEMM 不是 int32、int16 或 tfloat32
(std::is_same_v<XPUType, float> &&
!std::is_same_v<TGEMM, int32_t> &&
!std::is_same_v<TGEMM, int16_t> &&
!std::is_same_v<TGEMM, tfloat32>)),
std::conditional_t<std::is_same_v<XPUType, XPUTypeFP16>,
XPUTypeFP16,
tfloat32>,
TGEMM>;

int ret = xpudnn::conv3d_transpose<XPUType, XPUType, XPUType, RealTGEMM>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
filter_data, // filter的shape固定为[yc, xc, fd, fh, fw],
reinterpret_cast<XPUType*>(out->data<T>()),
batch_size,
img_yc,
img_yd,
img_yh,
img_yw,
img_xc,
ksize,
strides_,
paddings_,
dilations_,
groups,
nullptr,
nullptr,
nullptr,
is_ndhwc);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "conv3d_transpose");
});
}

template <typename T, typename Context>
void DepthwiseConv2dTransposeKernel(const Context& ctx,
const DenseTensor& x,
Expand Down Expand Up @@ -393,3 +524,10 @@ PD_REGISTER_KERNEL(conv2d_transpose,
phi::Conv2dTransposeKernel,
float,
phi::dtype::float16) {}

// Registers phi::Conv3dTransposeKernel as the XPU kernel for the
// conv3d_transpose op, under ALL_LAYOUT, instantiated for float and
// phi::dtype::float16. The empty braces mean no extra per-kernel setup.
PD_REGISTER_KERNEL(conv3d_transpose,
XPU,
ALL_LAYOUT,
phi::Conv3dTransposeKernel,
float,
phi::dtype::float16) {}
Loading