@@ -121,49 +121,50 @@ def test_disable_autotune(self):
121121
122122class TestStaticAutoTuneStatus (TestAutoTune ):
123123 def run_program (self , enable_autotune ):
124- paddle .enable_static ()
125-
126- data_shape = [1 , 1 , 8 , 8 ]
127- main_program = paddle .static .Program ()
128- startup_program = paddle .static .Program ()
129- with paddle .static .program_guard (main_program , startup_program ):
130- data = paddle .static .data (
131- name = 'X' , shape = data_shape , dtype = 'float32'
124+ with paddle .pir_utils .OldIrGuard ():
125+ paddle .enable_static ()
126+
127+ data_shape = [1 , 1 , 8 , 8 ]
128+ main_program = paddle .static .Program ()
129+ startup_program = paddle .static .Program ()
130+ with paddle .static .program_guard (main_program , startup_program ):
131+ data = paddle .static .data (
132+ name = 'X' , shape = data_shape , dtype = 'float32'
133+ )
134+ net = SimpleNet ()
135+ loss = static_program (net , data )
136+ place = (
137+ paddle .CUDAPlace (0 )
138+ if paddle .base .core .is_compiled_with_cuda ()
139+ else paddle .CPUPlace ()
132140 )
133- net = SimpleNet ()
134- loss = static_program (net , data )
135- place = (
136- paddle .CUDAPlace (0 )
137- if paddle .base .core .is_compiled_with_cuda ()
138- else paddle .CPUPlace ()
139- )
140- exe = paddle .static .Executor (place )
141- exe .run (startup_program )
142- x = np .random .random (size = data_shape ).astype ('float32' )
143-
144- # Node(tizheng): warmup run to make sure the following runs
145- # are in the same thread. Necessary for CUDNNv8 tests
146- exe .run (program = main_program , feed = {'X' : x }, fetch_list = [loss ])
141+ exe = paddle .static .Executor (place )
142+ exe .run (startup_program )
143+ x = np .random .random (size = data_shape ).astype ('float32' )
147144
148- self .set_flags (enable_autotune )
149- if enable_autotune :
150- config = {"kernel" : {"enable" : True , "tuning_range" : [1 , 2 ]}}
151- tfile = tempfile .NamedTemporaryFile (mode = "w+" , delete = False )
152- json .dump (config , tfile )
153- tfile .close ()
154- paddle .incubate .autotune .set_config (tfile .name )
155- os .remove (tfile .name )
156- else :
157- paddle .incubate .autotune .set_config (
158- config = {"kernel" : {"enable" : False , "tuning_range" : [1 , 2 ]}}
159- )
160-
161- for i in range (3 ):
145+ # Note(tizheng): warmup run to make sure the following runs
146+ # are in the same thread. Necessary for CUDNNv8 tests
162147 exe .run (program = main_program , feed = {'X' : x }, fetch_list = [loss ])
163- status = paddle .base .core .autotune_status ()
164- expected_res = self .get_expected_res (i , enable_autotune )
165- self .check_status (expected_res )
166- paddle .disable_static ()
148+
149+ self .set_flags (enable_autotune )
150+ if enable_autotune :
151+ config = {"kernel" : {"enable" : True , "tuning_range" : [1 , 2 ]}}
152+ tfile = tempfile .NamedTemporaryFile (mode = "w+" , delete = False )
153+ json .dump (config , tfile )
154+ tfile .close ()
155+ paddle .incubate .autotune .set_config (tfile .name )
156+ os .remove (tfile .name )
157+ else :
158+ paddle .incubate .autotune .set_config (
159+ config = {"kernel" : {"enable" : False , "tuning_range" : [1 , 2 ]}}
160+ )
161+
162+ for i in range (3 ):
163+ exe .run (program = main_program , feed = {'X' : x }, fetch_list = [loss ])
164+ status = paddle .base .core .autotune_status ()
165+ expected_res = self .get_expected_res (i , enable_autotune )
166+ self .check_status (expected_res )
167+ paddle .disable_static ()
167168
168169 def func_enable_autotune (self ):
169170 self .run_program (enable_autotune = True )
0 commit comments