Skip to content

Commit b7f5753

Browse files
authored
[XPU] Fixed the mode error in pad3d (#9506)
1 parent 7098fad commit b7f5753

File tree

1 file changed

+64
-43
lines changed

1 file changed

+64
-43
lines changed

lite/kernels/xpu/pad3d_compute.cc

Lines changed: 64 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -36,50 +36,71 @@ void Pad3dCompute<T>::Run() {
3636
auto* in_data = x->template data<T>();
3737
auto* out = param.Out;
3838
T* out_data = out->template mutable_data<T>(TARGET(kXPU));
39+
bool is_ncdhw;
40+
int n, c, d, h, w;
41+
if (data_format == "NCDHW") {
42+
is_ncdhw = true;
43+
n = in_dims[0];
44+
c = in_dims[1];
45+
d = in_dims[2];
46+
h = in_dims[3];
47+
w = in_dims[4];
48+
} else if (data_format == "NDHWC") {
49+
is_ncdhw = false;
50+
n = in_dims[0];
51+
c = in_dims[4];
52+
d = in_dims[1];
53+
h = in_dims[2];
54+
w = in_dims[3];
55+
} else {
56+
LOG(FATAL) << "xpu unsupport data_format: " << data_format;
57+
}
58+
// trans pad format
59+
std::vector<int> padding(6);
60+
padding[0] = pads[4];
61+
padding[1] = pads[5];
62+
padding[2] = pads[2];
63+
padding[3] = pads[3];
64+
padding[4] = pads[0];
65+
padding[5] = pads[1];
3966

40-
if (mode == "reflect" || mode == "constant" || mode == "replicate" ||
41-
mode == "circular") {
42-
if (data_format == "NCDHW") {
43-
std::vector<int> pad_left = {0, 0, pads[4], pads[2], pads[0]};
44-
std::vector<int> pad_right = {0, 0, pads[5], pads[3], pads[1]};
45-
46-
int n_shape = in_dims[0];
47-
int c_shape = in_dims[1];
48-
int d_shape = in_dims[2];
49-
int h_shape = in_dims[3];
50-
int w_shape = in_dims[4];
51-
52-
std::vector<int> xshape = {n_shape, c_shape, d_shape, h_shape, w_shape};
53-
54-
int r = xdnn::pad<T>(ctx.GetRawContext(),
55-
in_data,
56-
out_data,
57-
xshape,
58-
pad_left,
59-
pad_right,
60-
value);
61-
CHECK_EQ(r, 0);
62-
} else if (data_format == "NDHWC") {
63-
std::vector<int> pad_left = {0, pads[4], pads[2], pads[0], 0};
64-
std::vector<int> pad_right = {0, pads[5], pads[3], pads[1], 0};
65-
66-
int n_shape = in_dims[0];
67-
int d_shape = in_dims[1];
68-
int h_shape = in_dims[2];
69-
int w_shape = in_dims[3];
70-
int c_shape = in_dims[4];
71-
std::vector<int> xshape = {n_shape, d_shape, h_shape, w_shape, c_shape};
72-
73-
int r = xdnn::pad<T>(ctx.GetRawContext(),
74-
in_data,
75-
out_data,
76-
xshape,
77-
pad_left,
78-
pad_right,
79-
value);
80-
CHECK_EQ(r, 0);
81-
}
82-
67+
if (mode == "constant") {
68+
int r = xdnn::constant_pad3d<T>(ctx.GetRawContext(),
69+
in_data,
70+
out_data,
71+
n,
72+
c,
73+
d,
74+
h,
75+
w,
76+
padding,
77+
value,
78+
is_ncdhw);
79+
CHECK_EQ(r, 0);
80+
} else if (mode == "reflect") {
81+
int r = xdnn::reflection_pad3d<T>(ctx.GetRawContext(),
82+
in_data,
83+
out_data,
84+
n,
85+
c,
86+
d,
87+
h,
88+
w,
89+
padding,
90+
is_ncdhw);
91+
CHECK_EQ(r, 0);
92+
} else if (mode == "replicate") {
93+
int r = xdnn::replication_pad3d<T>(ctx.GetRawContext(),
94+
in_data,
95+
out_data,
96+
n,
97+
c,
98+
d,
99+
h,
100+
w,
101+
padding,
102+
is_ncdhw);
103+
CHECK_EQ(r, 0);
83104
} else {
84105
LOG(FATAL) << "xpu unsupport mode: " << mode;
85106
}

0 commit comments

Comments
 (0)