@@ -162,6 +162,32 @@ def elementwise_pow(name : str, x, y, axis, in_dtype):
162162
163163 return outs [0 ]
164164
165+
166+ def elementwise_floordiv (name : str , x , y , axis , in_dtype , paddle_ver = "1.8" ):
167+ import paddle
168+ paddle .enable_static ()
169+
170+ with paddle .static .program_guard (paddle .static .Program (), paddle .static .Program ()):
171+ node_x = paddle .static .data (name = 'x' , shape = x .shape , dtype = in_dtype )
172+ node_y = paddle .static .data (name = 'y' , shape = y .shape , dtype = in_dtype )
173+ if paddle_ver == "1.8" :
174+ out = paddle .fluid .layers .nn .elementwise_floordiv (node_x , node_y , axis = axis )
175+ else :
176+ out = paddle .floor_divide (node_x , node_y )
177+
178+ cpu = paddle .static .cpu_places (1 )
179+ exe = paddle .static .Executor (cpu [0 ])
180+
181+ # startup program will call initializer to initialize the parameters.
182+ exe .run (paddle .static .default_startup_program ())
183+ outs = exe .run (
184+ feed = {'x' : x , 'y' : y },
185+ fetch_list = [out ])
186+ saveModel (name , exe , feedkeys = ['x' , 'y' ], fetchlist = [out ], inputs = [x , y ], outputs = [outs [0 ]], target_dir = sys .argv [1 ])
187+
188+ return outs [0 ]
189+
190+
165191def elementwise_ops (name : str , data_x , data_y , axis , in_dtype ):
166192 elementwise_add ("elementwise_add" + name , data_x , data_y , axis , in_dtype )
167193 elementwise_sub ("elementwise_sub" + name , data_x , data_y , axis , in_dtype )
@@ -193,5 +219,39 @@ def main():
193219 axis = 0
194220 elementwise_ops ("4" , data_x , data_y , axis , in_dtype )
195221
222+ # test for elementwise_floordiv, support int and int64
223+ # paddle1.8 support axis = [0, x_last_dims]
224+ # paddle2.x only support axis = -1
225+ floordiv_support_dtype = ['int64' , 'int32' ]
226+ data_x = np .array ([- 2 , 0 , 4 ])
227+ data_y = np .array ([1 , 5 , 2 ])
228+ axis = - 1
229+ for dtype in floordiv_support_dtype :
230+ elementwise_floordiv ("elementwise_floordiv_for_paddle1.8_" + dtype + "_1" ,
231+ data_x .astype (dtype ), data_y .astype (dtype ), axis , dtype )
232+ elementwise_floordiv ("elementwise_floordiv_for_paddle2.x_" + dtype + "_1" ,
233+ data_x .astype (dtype ), data_y .astype (dtype ), axis , dtype , paddle_ver = "2.x" )
234+
235+ data_x = np .random .randint (1 , 10 , [2 , 5 , 3 , 4 ])
236+ data_y = np .random .randint (1 , 5 , [3 , 4 ])
237+ for dtype in floordiv_support_dtype :
238+ elementwise_floordiv ("elementwise_floordiv_for_paddle1.8_" + dtype + "_2" ,
239+ data_x .astype (dtype ), data_y .astype (dtype ), axis , dtype )
240+ elementwise_floordiv ("elementwise_floordiv_for_paddle2.x_" + dtype + "_2" ,
241+ data_x .astype (dtype ), data_y .astype (dtype ), axis , dtype , paddle_ver = "2.x" )
242+
243+ data_y = np .random .randint (1 , 5 , [5 ])
244+ axis = 1
245+ for dtype in floordiv_support_dtype :
246+ elementwise_floordiv ("elementwise_floordiv_for_paddle1.8_" + dtype + "_3" ,
247+ data_x .astype (dtype ), data_y .astype (dtype ), axis , dtype )
248+
249+ data_y = np .random .randint (1 , 5 , [2 , 5 , 3 ])
250+ axis = 0
251+ for dtype in floordiv_support_dtype :
252+ elementwise_floordiv ("elementwise_floordiv_for_paddle1.8_" + dtype + "_4" ,
253+ data_x .astype (dtype ), data_y .astype (dtype ), axis , dtype )
254+
255+
196256if __name__ == "__main__" :
197257 main ()
0 commit comments