Skip to content

Commit a43a11f

Browse files
authored
[CINN] Disable some op symbol_infer check in same_operand_with_result.cc (#68841)
* disable check_symbol_infer * Fix * Open full_like op sym_infer_check
1 parent 3f71fa3 commit a43a11f

23 files changed

+158
-65
lines changed

test/legacy_test/test_activation_op.py

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ def setUp(self):
130130

131131
def test_check_output(self):
132132
self.check_output(
133-
check_pir=True, check_pir_onednn=self.check_pir_onednn
133+
check_pir=True,
134+
check_pir_onednn=self.check_pir_onednn,
135+
check_symbol_infer=False,
134136
)
135137

136138
def test_check_grad(self):
@@ -186,7 +188,9 @@ def setUp(self):
186188

187189
def test_check_output(self):
188190
self.check_output(
189-
check_pir=True, check_pir_onednn=self.check_pir_onednn
191+
check_pir=True,
192+
check_pir_onednn=self.check_pir_onednn,
193+
check_symbol_infer=False,
190194
)
191195

192196
def test_check_grad(self):
@@ -279,7 +283,9 @@ def test_check_grad(self):
279283

280284
def test_check_output(self):
281285
self.check_output(
282-
check_pir=True, check_pir_onednn=self.check_pir_onednn
286+
check_pir=True,
287+
check_pir_onednn=self.check_pir_onednn,
288+
check_symbol_infer=False,
283289
)
284290

285291

@@ -956,7 +962,9 @@ def setUp(self):
956962

957963
def test_check_output(self):
958964
self.check_output(
959-
check_pir=True, check_pir_onednn=self.check_pir_onednn
965+
check_pir=True,
966+
check_pir_onednn=self.check_pir_onednn,
967+
check_symbol_infer=False,
960968
)
961969

962970
def test_check_grad(self):
@@ -1158,7 +1166,9 @@ def setUp(self):
11581166

11591167
def test_check_output(self):
11601168
self.check_output(
1161-
check_pir=True, check_pir_onednn=self.check_pir_onednn
1169+
check_pir=True,
1170+
check_pir_onednn=self.check_pir_onednn,
1171+
check_symbol_infer=False,
11621172
)
11631173

11641174
def test_check_grad(self):
@@ -1914,6 +1924,7 @@ def test_check_output(self):
19141924
self.check_output(
19151925
check_pir=True,
19161926
check_pir_onednn=self.check_pir_onednn,
1927+
check_symbol_infer=False,
19171928
)
19181929

19191930
def test_check_grad(self):
@@ -1969,7 +1980,9 @@ def if_enable_cinn(self):
19691980

19701981
def test_check_output(self):
19711982
self.check_output(
1972-
check_pir=True, check_pir_onednn=self.check_pir_onednn
1983+
check_pir=True,
1984+
check_pir_onednn=self.check_pir_onednn,
1985+
check_symbol_infer=False,
19731986
)
19741987

19751988
def test_check_grad(self):
@@ -2010,7 +2023,9 @@ def init_shape(self):
20102023

20112024
def test_check_output(self):
20122025
self.check_output(
2013-
check_pir=True, check_pir_onednn=self.check_pir_onednn
2026+
check_pir=True,
2027+
check_pir_onednn=self.check_pir_onednn,
2028+
check_symbol_infer=False,
20142029
)
20152030

20162031
# The same reason with TestFloor
@@ -2049,7 +2064,9 @@ def if_enable_cinn(self):
20492064

20502065
def test_check_output(self):
20512066
self.check_output(
2052-
check_pir=True, check_pir_onednn=self.check_pir_onednn
2067+
check_pir=True,
2068+
check_pir_onednn=self.check_pir_onednn,
2069+
check_symbol_infer=False,
20532070
)
20542071

20552072
# the gradient on floor, ceil, round is undefined.
@@ -2107,7 +2124,9 @@ def init_shape(self):
21072124

21082125
def test_check_output(self):
21092126
self.check_output(
2110-
check_pir=True, check_pir_onednn=self.check_pir_onednn
2127+
check_pir=True,
2128+
check_pir_onednn=self.check_pir_onednn,
2129+
check_symbol_infer=False,
21112130
)
21122131

21132132
def test_check_grad(self):
@@ -2284,7 +2303,9 @@ def init_shape(self):
22842303

22852304
def test_check_output(self):
22862305
self.check_output(
2287-
check_pir=True, check_pir_onednn=self.check_pir_onednn
2306+
check_pir=True,
2307+
check_pir_onednn=self.check_pir_onednn,
2308+
check_symbol_infer=False,
22882309
)
22892310

22902311
def test_check_grad(self):
@@ -2412,7 +2433,9 @@ def init_shape(self):
24122433

24132434
def test_check_output(self):
24142435
self.check_output(
2415-
check_pir=True, check_pir_onednn=self.check_pir_onednn
2436+
check_pir=True,
2437+
check_pir_onednn=self.check_pir_onednn,
2438+
check_symbol_infer=False,
24162439
)
24172440

24182441
def test_check_grad(self):
@@ -2463,7 +2486,9 @@ def init_shape(self):
24632486

24642487
def test_check_output(self):
24652488
self.check_output(
2466-
check_pir=True, check_pir_onednn=self.check_pir_onednn
2489+
check_pir=True,
2490+
check_pir_onednn=self.check_pir_onednn,
2491+
check_symbol_infer=False,
24672492
)
24682493

24692494
def test_check_grad(self):
@@ -2527,7 +2552,9 @@ def init_shape(self):
25272552

25282553
def test_check_output(self):
25292554
self.check_output(
2530-
check_pir=True, check_pir_onednn=self.check_pir_onednn
2555+
check_pir=True,
2556+
check_pir_onednn=self.check_pir_onednn,
2557+
check_symbol_infer=False,
25312558
)
25322559

25332560
def test_check_grad(self):
@@ -2591,7 +2618,9 @@ def init_shape(self):
25912618

25922619
def test_check_output(self):
25932620
self.check_output(
2594-
check_pir=True, check_pir_onednn=self.check_pir_onednn
2621+
check_pir=True,
2622+
check_pir_onednn=self.check_pir_onednn,
2623+
check_symbol_infer=False,
25952624
)
25962625

25972626
def test_check_grad(self):
@@ -2642,7 +2671,9 @@ def init_decimals(self):
26422671

26432672
def test_check_output(self):
26442673
self.check_output(
2645-
check_pir=True, check_pir_onednn=self.check_pir_onednn
2674+
check_pir=True,
2675+
check_pir_onednn=self.check_pir_onednn,
2676+
check_symbol_infer=False,
26462677
)
26472678

26482679
def test_check_grad(self):
@@ -2719,6 +2750,7 @@ def test_check_output(self):
27192750
check_pir=True,
27202751
check_prim_pir=True,
27212752
check_pir_onednn=self.check_pir_onednn,
2753+
check_symbol_infer=False,
27222754
)
27232755

27242756
def if_enable_cinn(self):
@@ -3000,6 +3032,7 @@ def test_check_output(self):
30003032
check_pir=True,
30013033
check_prim_pir=False,
30023034
check_pir_onednn=self.check_pir_onednn,
3035+
check_symbol_infer=False,
30033036
)
30043037

30053038
def test_check_grad(self):
@@ -3052,6 +3085,7 @@ def test_check_output(self):
30523085
check_pir=True,
30533086
check_prim_pir=True,
30543087
check_pir_onednn=self.check_pir_onednn,
3088+
check_symbol_infer=False,
30553089
)
30563090

30573091
def test_check_grad(self):
@@ -3160,7 +3194,9 @@ def setUp(self):
31603194

31613195
def test_check_output(self):
31623196
self.check_output(
3163-
check_pir=True, check_pir_onednn=self.check_pir_onednn
3197+
check_pir=True,
3198+
check_pir_onednn=self.check_pir_onednn,
3199+
check_symbol_infer=False,
31643200
)
31653201

31663202
def test_check_grad(self):
@@ -3206,6 +3242,7 @@ def test_check_output(self):
32063242
check_pir=True,
32073243
check_prim_pir=True,
32083244
check_pir_onednn=self.check_pir_onednn,
3245+
check_symbol_infer=False,
32093246
)
32103247

32113248
def test_check_grad(self):
@@ -3557,7 +3594,9 @@ def test_check_grad(self):
35573594

35583595
def test_check_output(self):
35593596
self.check_output(
3560-
check_prim_pir=True, check_pir_onednn=self.check_pir_onednn
3597+
check_prim_pir=True,
3598+
check_pir_onednn=self.check_pir_onednn,
3599+
check_symbol_infer=False,
35613600
)
35623601

35633602
def get_alpha(self):
@@ -3803,6 +3842,7 @@ def test_check_output(self):
38033842
check_pir=True,
38043843
check_prim_pir=True,
38053844
check_pir_onednn=self.check_pir_onednn,
3845+
check_symbol_infer=False,
38063846
)
38073847

38083848

@@ -3861,7 +3901,9 @@ def if_enable_cinn(self):
38613901

38623902
def test_check_output(self):
38633903
self.check_output(
3864-
check_pir=True, check_pir_onednn=self.check_pir_onednn
3904+
check_pir=True,
3905+
check_pir_onednn=self.check_pir_onednn,
3906+
check_symbol_infer=False,
38653907
)
38663908

38673909
def test_check_grad(self):
@@ -4108,7 +4150,9 @@ def setUp(self):
41084150

41094151
def test_check_output(self):
41104152
self.check_output(
4111-
check_pir=True, check_pir_onednn=self.check_pir_onednn
4153+
check_pir=True,
4154+
check_pir_onednn=self.check_pir_onednn,
4155+
check_symbol_infer=False,
41124156
)
41134157

41144158
def test_check_grad(self):
@@ -4227,7 +4271,9 @@ def setUp(self):
42274271

42284272
def test_check_output(self):
42294273
self.check_output(
4230-
check_pir=True, check_pir_onednn=self.check_pir_onednn
4274+
check_pir=True,
4275+
check_pir_onednn=self.check_pir_onednn,
4276+
check_symbol_infer=False,
42314277
)
42324278

42334279
def test_check_grad(self):
@@ -4502,6 +4548,7 @@ def test_check_output(self):
45024548
check_prim_pir=True,
45034549
check_pir=True,
45044550
check_pir_onednn=self.check_pir_onednn,
4551+
check_symbol_infer=False,
45054552
)
45064553

45074554
def test_check_grad(self):
@@ -5388,7 +5435,9 @@ def init_shape(self):
53885435

53895436
def test_check_output(self):
53905437
self.check_output(
5391-
check_pir=True, check_pir_onednn=self.check_pir_onednn
5438+
check_pir=True,
5439+
check_pir_onednn=self.check_pir_onednn,
5440+
check_symbol_infer=False,
53925441
)
53935442

53945443
def test_check_grad(self):

test/legacy_test/test_angle_op.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def setUp(self):
4949
self.outputs = {'Out': out_ref}
5050

5151
def test_check_output(self):
52-
self.check_output(check_pir=True)
52+
self.check_output(check_pir=True, check_symbol_infer=False)
5353

5454
def test_check_grad(self):
5555
self.check_grad(
@@ -94,7 +94,9 @@ def setUp(self):
9494
self.place = core.CUDAPlace(0)
9595

9696
def test_check_output(self):
97-
self.check_output_with_place(self.place, check_pir=True)
97+
self.check_output_with_place(
98+
self.place, check_pir=True, check_symbol_infer=False
99+
)
98100

99101
def test_check_grad(self):
100102
self.check_grad_with_place(
@@ -121,7 +123,7 @@ def setUp(self):
121123
self.outputs = {'Out': out_ref}
122124

123125
def test_check_output(self):
124-
self.check_output(check_pir=True)
126+
self.check_output(check_pir=True, check_symbol_infer=False)
125127

126128
def test_check_grad(self):
127129
self.check_grad(

test/legacy_test/test_bitwise_op.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,9 @@ def setUp(self):
369369
self.outputs = {'Out': out}
370370

371371
def test_check_output(self):
372-
self.check_output(check_cinn=True, check_pir=True)
372+
self.check_output(
373+
check_cinn=True, check_pir=True, check_symbol_infer=False
374+
)
373375

374376
def test_check_grad(self):
375377
pass

test/legacy_test/test_clip_op.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ def setUp(self):
5858
def test_check_output(self):
5959
paddle.enable_static()
6060
self.check_output(
61-
check_cinn=self.check_cinn, check_pir=True, check_prim_pir=True
61+
check_cinn=self.check_cinn,
62+
check_pir=True,
63+
check_prim_pir=True,
64+
check_symbol_infer=False,
6265
)
6366
paddle.disable_static()
6467

@@ -201,7 +204,10 @@ def test_check_output(self):
201204
place = paddle.CUDAPlace(0)
202205
paddle.enable_static()
203206
self.check_output_with_place(
204-
place, check_pir=True, check_prim_pir=True
207+
place,
208+
check_pir=True,
209+
check_prim_pir=True,
210+
check_symbol_infer=False,
205211
)
206212
paddle.disable_static()
207213

test/legacy_test/test_conj_op.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def init_input_output(self):
5151
self.outputs = {'Out': out}
5252

5353
def test_check_output(self):
54-
self.check_output(check_pir=True)
54+
self.check_output(check_pir=True, check_symbol_infer=False)
5555

5656
def test_check_grad_normal(self):
5757
self.check_grad(
@@ -180,7 +180,9 @@ def init_input_output(self):
180180

181181
def test_check_output(self):
182182
place = core.CUDAPlace(0)
183-
self.check_output_with_place(place, check_pir=True)
183+
self.check_output_with_place(
184+
place, check_pir=True, check_symbol_infer=False
185+
)
184186

185187
def test_check_grad(self):
186188
place = core.CUDAPlace(0)

test/legacy_test/test_digamma_op.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def init_dtype_type(self):
4343
self.dtype = np.float64
4444

4545
def test_check_output(self):
46-
self.check_output(check_pir=True)
46+
self.check_output(check_pir=True, check_symbol_infer=False)
4747

4848
def test_check_grad_normal(self):
4949
self.check_grad(['X'], 'Out', check_pir=True)
@@ -88,7 +88,9 @@ def init_dtype_type(self):
8888

8989
def test_check_output(self):
9090
# bfloat16 needs to set the parameter place
91-
self.check_output_with_place(core.CUDAPlace(0), check_pir=True)
91+
self.check_output_with_place(
92+
core.CUDAPlace(0), check_pir=True, check_symbol_infer=False
93+
)
9294

9395
def test_check_grad_normal(self):
9496
self.check_grad_with_place(

0 commit comments

Comments
 (0)