@@ -170,6 +170,8 @@ def input_data(self):
170170 self .data_y = np .array (
171171 [[1.0 , 1.0 , 1.0 ], [1.0 , 1.0 , 1.0 ], [1.0 , 1.0 , 1.0 ]]
172172 ).astype ('float32' )
173+ self .data_x_zero = np .array ([]).reshape (0 , 3 ).astype ('float32' )
174+ self .data_y_zero = np .array ([]).reshape (0 , 3 ).astype ('float32' )
173175
174176 def test_cross_api (self ):
175177 self .input_data ()
@@ -212,6 +214,26 @@ def test_cross_api(self):
212214 )
213215 np .testing .assert_allclose (expect_out , np .array (res ), rtol = 1e-05 )
214216
217+ main = paddle .static .Program ()
218+ startup = paddle .static .Program ()
219+ # case 3:
220+ with paddle .static .program_guard (main , startup ):
221+ x = paddle .static .data (name = 'x' , shape = [0 , 3 ], dtype = "float32" )
222+ y = paddle .static .data (name = 'y' , shape = [0 , 3 ], dtype = "float32" )
223+ z = paddle .cross (x , y , axis = 1 )
224+ exe = base .Executor (base .CPUPlace ())
225+ (res ,) = exe .run (
226+ main ,
227+ feed = {'x' : self .data_x_zero , 'y' : self .data_y_zero },
228+ fetch_list = [z ],
229+ return_numpy = False ,
230+ )
231+ expect_out = np .empty ((0 , 3 ))
232+ np .testing .assert_allclose (expect_out , np .array (res ), rtol = 1e-05 )
233+
234+ main = paddle .static .Program ()
235+ startup = paddle .static .Program ()
236+
215237 def test_cross_api1 (self ):
216238 with paddle .pir_utils .OldIrGuard ():
217239 self .input_data ()
@@ -227,6 +249,17 @@ def test_cross_api1(self):
227249 y_1 = paddle .cross (x , y , name = 'result' )
228250 self .assertEqual (('result' in y_1 .name ), True )
229251
252+ main = paddle .static .Program ()
253+ startup = paddle .static .Program ()
254+
255+ # case 2:
256+ with paddle .static .program_guard (main , startup ):
257+ x = paddle .static .data (name = "x" , shape = [0 , 3 ], dtype = "float32" )
258+ y = paddle .static .data (name = 'y' , shape = [0 , 3 ], dtype = 'float32' )
259+
260+ y_1 = paddle .cross (x , y , axis = 1 , name = 'result' )
261+ self .assertEqual (('result' in y_1 .name ), True )
262+
230263 def test_dygraph_api (self ):
231264 self .input_data ()
232265 # case 1:
@@ -250,6 +283,15 @@ def test_dygraph_api(self):
250283 )
251284 np .testing .assert_allclose (expect_out , np_z , rtol = 1e-05 )
252285
286+ # case 3:
287+ with base .dygraph .guard ():
288+ x = paddle .to_tensor (self .data_x_zero )
289+ y = paddle .to_tensor (self .data_y_zero )
290+ z = paddle .cross (x , y , axis = 1 )
291+ np_z = z .numpy ()
292+ expect_out = np .empty ((0 , 3 ))
293+ np .testing .assert_allclose (expect_out , np_z , rtol = 1e-05 )
294+
253295
254296if __name__ == '__main__' :
255297 paddle .enable_static ()
0 commit comments