@@ -36,50 +36,71 @@ void Pad3dCompute<T>::Run() {
36
36
auto * in_data = x->template data <T>();
37
37
auto * out = param.Out ;
38
38
T* out_data = out->template mutable_data <T>(TARGET (kXPU ));
39
+ bool is_ncdhw;
40
+ int n, c, d, h, w;
41
+ if (data_format == " NCDHW" ) {
42
+ is_ncdhw = true ;
43
+ n = in_dims[0 ];
44
+ c = in_dims[1 ];
45
+ d = in_dims[2 ];
46
+ h = in_dims[3 ];
47
+ w = in_dims[4 ];
48
+ } else if (data_format == " NDHWC" ) {
49
+ is_ncdhw = false ;
50
+ n = in_dims[0 ];
51
+ c = in_dims[4 ];
52
+ d = in_dims[1 ];
53
+ h = in_dims[2 ];
54
+ w = in_dims[3 ];
55
+ } else {
56
+ LOG (FATAL) << " xpu unsupport data_format: " << data_format;
57
+ }
58
+ // trans pad format
59
+ std::vector<int > padding (6 );
60
+ padding[0 ] = pads[4 ];
61
+ padding[1 ] = pads[5 ];
62
+ padding[2 ] = pads[2 ];
63
+ padding[3 ] = pads[3 ];
64
+ padding[4 ] = pads[0 ];
65
+ padding[5 ] = pads[1 ];
39
66
40
- if (mode == " reflect" || mode == " constant" || mode == " replicate" ||
41
- mode == " circular" ) {
42
- if (data_format == " NCDHW" ) {
43
- std::vector<int > pad_left = {0 , 0 , pads[4 ], pads[2 ], pads[0 ]};
44
- std::vector<int > pad_right = {0 , 0 , pads[5 ], pads[3 ], pads[1 ]};
45
-
46
- int n_shape = in_dims[0 ];
47
- int c_shape = in_dims[1 ];
48
- int d_shape = in_dims[2 ];
49
- int h_shape = in_dims[3 ];
50
- int w_shape = in_dims[4 ];
51
-
52
- std::vector<int > xshape = {n_shape, c_shape, d_shape, h_shape, w_shape};
53
-
54
- int r = xdnn::pad<T>(ctx.GetRawContext (),
55
- in_data,
56
- out_data,
57
- xshape,
58
- pad_left,
59
- pad_right,
60
- value);
61
- CHECK_EQ (r, 0 );
62
- } else if (data_format == " NDHWC" ) {
63
- std::vector<int > pad_left = {0 , pads[4 ], pads[2 ], pads[0 ], 0 };
64
- std::vector<int > pad_right = {0 , pads[5 ], pads[3 ], pads[1 ], 0 };
65
-
66
- int n_shape = in_dims[0 ];
67
- int d_shape = in_dims[1 ];
68
- int h_shape = in_dims[2 ];
69
- int w_shape = in_dims[3 ];
70
- int c_shape = in_dims[4 ];
71
- std::vector<int > xshape = {n_shape, d_shape, h_shape, w_shape, c_shape};
72
-
73
- int r = xdnn::pad<T>(ctx.GetRawContext (),
74
- in_data,
75
- out_data,
76
- xshape,
77
- pad_left,
78
- pad_right,
79
- value);
80
- CHECK_EQ (r, 0 );
81
- }
82
-
67
+ if (mode == " constant" ) {
68
+ int r = xdnn::constant_pad3d<T>(ctx.GetRawContext (),
69
+ in_data,
70
+ out_data,
71
+ n,
72
+ c,
73
+ d,
74
+ h,
75
+ w,
76
+ padding,
77
+ value,
78
+ is_ncdhw);
79
+ CHECK_EQ (r, 0 );
80
+ } else if (mode == " reflect" ) {
81
+ int r = xdnn::reflection_pad3d<T>(ctx.GetRawContext (),
82
+ in_data,
83
+ out_data,
84
+ n,
85
+ c,
86
+ d,
87
+ h,
88
+ w,
89
+ padding,
90
+ is_ncdhw);
91
+ CHECK_EQ (r, 0 );
92
+ } else if (mode == " replicate" ) {
93
+ int r = xdnn::replication_pad3d<T>(ctx.GetRawContext (),
94
+ in_data,
95
+ out_data,
96
+ n,
97
+ c,
98
+ d,
99
+ h,
100
+ w,
101
+ padding,
102
+ is_ncdhw);
103
+ CHECK_EQ (r, 0 );
83
104
} else {
84
105
LOG (FATAL) << " xpu unsupport mode: " << mode;
85
106
}
0 commit comments