Skip to content

Commit 9499924

Browse files
fangfangssjfangfangssj
andauthored
【HEU】[Paddle Tensor 第二期 API支持 0-size Tensor] paddle.cross 支持 0-size tensor (#70103)
* fix cross * fix * fix test --------- Co-authored-by: fangfangssj <[email protected]>
1 parent 9a3f99d commit 9499924

File tree

3 files changed

+52
-0
lines changed

3 files changed

+52
-0
lines changed

paddle/phi/kernels/cpu/cross_kernel.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,11 @@ void CrossKernel(const Context& dev_ctx,
7070
"But received: Input(X/Y).dims() == [%s].",
7171
input_x_dims));
7272
}
73+
if (input_x.numel() == 0 || input_y.numel() == 0) {
74+
output->Resize(input_x.dims());
75+
dev_ctx.template Alloc<T>(output);
76+
return;
77+
}
7378
auto outer_loops = 1;
7479
for (auto i = 0; i < dim; i++) {
7580
outer_loops *= static_cast<int>(input_x_dims[i]);

paddle/phi/kernels/gpu/cross_kernel.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ void CrossKernel(const Context& dev_ctx,
101101
input_x_dims));
102102
}
103103

104+
if (input_x.numel() == 0 || input_y.numel() == 0) {
105+
output->Resize(input_x.dims());
106+
dev_ctx.template Alloc<T>(output);
107+
return;
108+
}
104109
std::vector<int> cal_dims;
105110
std::vector<int> left_strides;
106111
std::vector<int> full_strides;

test/legacy_test/test_cross_op.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ def input_data(self):
170170
self.data_y = np.array(
171171
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
172172
).astype('float32')
173+
self.data_x_zero = np.array([]).reshape(0, 3).astype('float32')
174+
self.data_y_zero = np.array([]).reshape(0, 3).astype('float32')
173175

174176
def test_cross_api(self):
175177
self.input_data()
@@ -212,6 +214,26 @@ def test_cross_api(self):
212214
)
213215
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
214216

217+
main = paddle.static.Program()
218+
startup = paddle.static.Program()
219+
# case 3:
220+
with paddle.static.program_guard(main, startup):
221+
x = paddle.static.data(name='x', shape=[0, 3], dtype="float32")
222+
y = paddle.static.data(name='y', shape=[0, 3], dtype="float32")
223+
z = paddle.cross(x, y, axis=1)
224+
exe = base.Executor(base.CPUPlace())
225+
(res,) = exe.run(
226+
main,
227+
feed={'x': self.data_x_zero, 'y': self.data_y_zero},
228+
fetch_list=[z],
229+
return_numpy=False,
230+
)
231+
expect_out = np.empty((0, 3))
232+
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)
233+
234+
main = paddle.static.Program()
235+
startup = paddle.static.Program()
236+
215237
def test_cross_api1(self):
216238
with paddle.pir_utils.OldIrGuard():
217239
self.input_data()
@@ -227,6 +249,17 @@ def test_cross_api1(self):
227249
y_1 = paddle.cross(x, y, name='result')
228250
self.assertEqual(('result' in y_1.name), True)
229251

252+
main = paddle.static.Program()
253+
startup = paddle.static.Program()
254+
255+
# case 2:
256+
with paddle.static.program_guard(main, startup):
257+
x = paddle.static.data(name="x", shape=[0, 3], dtype="float32")
258+
y = paddle.static.data(name='y', shape=[0, 3], dtype='float32')
259+
260+
y_1 = paddle.cross(x, y, axis=1, name='result')
261+
self.assertEqual(('result' in y_1.name), True)
262+
230263
def test_dygraph_api(self):
231264
self.input_data()
232265
# case 1:
@@ -250,6 +283,15 @@ def test_dygraph_api(self):
250283
)
251284
np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)
252285

286+
# case 3:
287+
with base.dygraph.guard():
288+
x = paddle.to_tensor(self.data_x_zero)
289+
y = paddle.to_tensor(self.data_y_zero)
290+
z = paddle.cross(x, y, axis=1)
291+
np_z = z.numpy()
292+
expect_out = np.empty((0, 3))
293+
np.testing.assert_allclose(expect_out, np_z, rtol=1e-05)
294+
253295

254296
if __name__ == '__main__':
255297
paddle.enable_static()

0 commit comments

Comments
 (0)