File tree Expand file tree Collapse file tree 3 files changed +29
-2
lines changed Expand file tree Collapse file tree 3 files changed +29
-2
lines changed Original file line number Diff line number Diff line change @@ -1215,7 +1215,8 @@ set(TEST_CINN_OPS
12151215    test_elementwise_pow_op
12161216    test_transpose_op
12171217    test_reshape_op
1218-     test_mean_op)
1218+     test_mean_op
1219+     test_unsqueeze2_op)
12191220
12201221foreach (TEST_CINN_OPS ${TEST_CINN_OPS} )
12211222  if (WITH_CINN)
Original file line number Diff line number Diff line change @@ -36,9 +36,12 @@ def setUp(self):
3636            "Out" : self .inputs ["X" ].reshape (self .new_shape ),
3737            "XShape" : np .random .random (self .ori_shape ).astype ("float64" ),
3838        }
39+         self .prim_op_type  =  "comp" 
3940
4041    def  test_check_output (self ):
41-         self .check_output (no_check_set = ["XShape" ], check_eager = True )
42+         self .check_output (
43+             no_check_set = ["XShape" ], check_eager = True , check_prim = True 
44+         )
4245
4346    def  test_check_grad (self ):
4447        self .check_grad (["X" ], "Out" , check_eager = True )
@@ -89,20 +92,23 @@ def init_test_case(self):
8992        self .ori_shape  =  ()
9093        self .axes  =  (- 1 ,)
9194        self .new_shape  =  1 
95+         self .enable_cinn  =  False 
9296
9397
9498class  TestUnsqueezeOp_ZeroDim2 (TestUnsqueezeOp ):
9599    def  init_test_case (self ):
96100        self .ori_shape  =  ()
97101        self .axes  =  (- 1 , 1 )
98102        self .new_shape  =  (1 , 1 )
103+         self .enable_cinn  =  False 
99104
100105
101106class  TestUnsqueezeOp_ZeroDim3 (TestUnsqueezeOp ):
102107    def  init_test_case (self ):
103108        self .ori_shape  =  ()
104109        self .axes  =  (0 , 1 , 2 )
105110        self .new_shape  =  (1 , 1 , 1 )
111+         self .enable_cinn  =  False 
106112
107113
108114# axes is a list(with tensor) 
Original file line number Diff line number Diff line change @@ -371,3 +371,23 @@ def relu_composite(x):
371371    """define composite rule of op relu.""" 
372372    # relu(x) = max(x, 0) 
373373    return  maximum (x , zeros_like (x ))
374+ 
375+ 
376+ @REGISTER_COMPOSITE ('unsqueeze2' ) 
377+ def  unsqueeze_composite (x , axis ):
378+     """define composite rule of op unsqueeze""" 
379+     """using reshape to implement unsqueeze op""" 
380+     x_shape  =  list (x .shape )
381+     axis_list  =  list (axis )
382+     for  i  in  axis_list :
383+         if  i  <  0 :
384+             i  +=  len (x_shape ) +  1 
385+         x_shape  =  (
386+             x_shape [:i ]
387+             +  [
388+                 1 ,
389+             ]
390+             +  x_shape [i :]
391+         )
392+     out  =  reshape (x , x_shape )
393+     return  [out , None ]
    
 
   
 
     
   
   
          
     
  
    
     
 
    
      
     
 
     
    You can’t perform that action at this time.
  
 
    
  
     
    
      
        
     
 
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments