
Commit e2d45fe

wanghuancoder authored and lixcli committed
[PIR] fix some pir test case (PaddlePaddle#65998)
* fix some pir test case
* refine
1 parent 8960332 commit e2d45fe

File tree: 3 files changed, +69 −45 lines


paddle/fluid/prim/api/auto_code_generated/eager_gen.py

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ def parse_input_and_attr(self, inputs_list, attrs_list):
             'Scalar(int)': 'const Scalar&',
             'Scalar(int64_t)': 'const Scalar&',
             'Scalar(float)': 'const Scalar&',
-            'Scalar(dobule)': 'const Scalar&',
+            'Scalar(double)': 'const Scalar&',
             'Scalar[]': 'const std::vector<phi::Scalar>&',
             'int': 'int',
             'int32_t': 'int32_t',
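This map translates attribute type strings from op definitions into the C++ parameter types used by the generated prim API; with the key misspelled as 'Scalar(dobule)', a double-typed Scalar attribute could never match. A minimal sketch of such a lookup, where attr_type_map and to_cpp_type are illustrative names rather than the generator's actual API:

# Illustrative sketch only: attr_type_map / to_cpp_type are hypothetical
# stand-ins for the code generator's real lookup.
attr_type_map = {
    'Scalar(float)': 'const Scalar&',
    'Scalar(double)': 'const Scalar&',  # the old 'Scalar(dobule)' key never matched this type
    'Scalar[]': 'const std::vector<phi::Scalar>&',
}

def to_cpp_type(attr_type: str) -> str:
    # A KeyError here is how a misspelled map entry would surface when
    # generating code for an op with a double-typed Scalar attribute.
    return attr_type_map[attr_type]

print(to_cpp_type('Scalar(double)'))  # const Scalar&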

test/deprecated/legacy_test/test_switch_autotune.py

Lines changed: 41 additions & 40 deletions
@@ -121,49 +121,50 @@ def test_disable_autotune(self):
 
 class TestStaticAutoTuneStatus(TestAutoTune):
     def run_program(self, enable_autotune):
-        paddle.enable_static()
-
-        data_shape = [1, 1, 8, 8]
-        main_program = paddle.static.Program()
-        startup_program = paddle.static.Program()
-        with paddle.static.program_guard(main_program, startup_program):
-            data = paddle.static.data(
-                name='X', shape=data_shape, dtype='float32'
+        with paddle.pir_utils.OldIrGuard():
+            paddle.enable_static()
+
+            data_shape = [1, 1, 8, 8]
+            main_program = paddle.static.Program()
+            startup_program = paddle.static.Program()
+            with paddle.static.program_guard(main_program, startup_program):
+                data = paddle.static.data(
+                    name='X', shape=data_shape, dtype='float32'
+                )
+                net = SimpleNet()
+                loss = static_program(net, data)
+            place = (
+                paddle.CUDAPlace(0)
+                if paddle.base.core.is_compiled_with_cuda()
+                else paddle.CPUPlace()
             )
-            net = SimpleNet()
-            loss = static_program(net, data)
-        place = (
-            paddle.CUDAPlace(0)
-            if paddle.base.core.is_compiled_with_cuda()
-            else paddle.CPUPlace()
-        )
-        exe = paddle.static.Executor(place)
-        exe.run(startup_program)
-        x = np.random.random(size=data_shape).astype('float32')
-
-        # Node(tizheng): warmup run to make sure the following runs
-        # are in the same thread. Necessary for CUDNNv8 tests
-        exe.run(program=main_program, feed={'X': x}, fetch_list=[loss])
+            exe = paddle.static.Executor(place)
+            exe.run(startup_program)
+            x = np.random.random(size=data_shape).astype('float32')
 
-        self.set_flags(enable_autotune)
-        if enable_autotune:
-            config = {"kernel": {"enable": True, "tuning_range": [1, 2]}}
-            tfile = tempfile.NamedTemporaryFile(mode="w+", delete=False)
-            json.dump(config, tfile)
-            tfile.close()
-            paddle.incubate.autotune.set_config(tfile.name)
-            os.remove(tfile.name)
-        else:
-            paddle.incubate.autotune.set_config(
-                config={"kernel": {"enable": False, "tuning_range": [1, 2]}}
-            )
-
-        for i in range(3):
+            # Node(tizheng): warmup run to make sure the following runs
+            # are in the same thread. Necessary for CUDNNv8 tests
             exe.run(program=main_program, feed={'X': x}, fetch_list=[loss])
-            status = paddle.base.core.autotune_status()
-            expected_res = self.get_expected_res(i, enable_autotune)
-            self.check_status(expected_res)
-        paddle.disable_static()
+
+            self.set_flags(enable_autotune)
+            if enable_autotune:
+                config = {"kernel": {"enable": True, "tuning_range": [1, 2]}}
+                tfile = tempfile.NamedTemporaryFile(mode="w+", delete=False)
+                json.dump(config, tfile)
+                tfile.close()
+                paddle.incubate.autotune.set_config(tfile.name)
+                os.remove(tfile.name)
+            else:
+                paddle.incubate.autotune.set_config(
+                    config={"kernel": {"enable": False, "tuning_range": [1, 2]}}
+                )
+
+            for i in range(3):
+                exe.run(program=main_program, feed={'X': x}, fetch_list=[loss])
+                status = paddle.base.core.autotune_status()
+                expected_res = self.get_expected_res(i, enable_autotune)
+                self.check_status(expected_res)
+            paddle.disable_static()
 
     def func_enable_autotune(self):
        self.run_program(enable_autotune=True)
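The fix wraps the whole static-graph section of the test in paddle.pir_utils.OldIrGuard(), so the Program-based code keeps building the legacy IR even on builds where PIR is the default. A minimal sketch of the same pattern, assuming a PIR-enabled Paddle build (SimpleNet and the autotune checks from the test are replaced with a trivial mean op):

import numpy as np
import paddle

# Minimal sketch of the guard pattern: inside OldIrGuard,
# paddle.static.Program() creates legacy-IR programs.
with paddle.pir_utils.OldIrGuard():
    paddle.enable_static()
    main_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(main_program, startup_program):
        data = paddle.static.data(name='X', shape=[1, 1, 8, 8], dtype='float32')
        loss = paddle.mean(data)
    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_program)
    x = np.random.random(size=[1, 1, 8, 8]).astype('float32')
    (loss_val,) = exe.run(main_program, feed={'X': x}, fetch_list=[loss])
    paddle.disable_static()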

test/deprecated/legacy_test/test_transformer_api.py

Lines changed: 27 additions & 4 deletions
@@ -283,8 +283,15 @@ class TestTransformer(unittest.TestCase):
     def test_multi_head_attention(self):
         def multihead_attention_test_helper(self_attention, cache):
             paddle.seed(2020)
-            paddle.framework.random._manual_program_seed(2020)
-            # self_attention|cross_attention, cache|No cache
+            if paddle.framework.use_pir_api():
+                with paddle.pir_utils.OldIrGuard():
+                    # Note: dygraph use self.main_program.global_block().create_parameter(), it's need manual seed to old Program
+                    paddle.framework.random._manual_program_seed(2020)
+                paddle.framework.random._manual_program_seed(2020)
+            else:
+                paddle.framework.random._manual_program_seed(
+                    2020
+                )  # self_attention|cross_attention, cache|No cache
             with base.dygraph.guard(base.CPUPlace()):
                 # generate params for multi_head_attention
                 (
@@ -401,7 +408,15 @@ def multihead_attention_test_helper(self_attention, cache):
     def test_transformer_encoder_layer(self):
         with base.dygraph.guard(base.CPUPlace()):
             paddle.framework.seed(2020)
-            paddle.framework.random._manual_program_seed(2020)
+            if paddle.framework.use_pir_api():
+                with paddle.pir_utils.OldIrGuard():
+                    # Note: dygraph use self.main_program.global_block().create_parameter(), it's need manual seed to old Program
+                    paddle.framework.random._manual_program_seed(2020)
+                paddle.framework.random._manual_program_seed(2020)
+            else:
+                paddle.framework.random._manual_program_seed(
+                    2020
+                )  # self_attention|cross_attention, cache|No cache
 
             ffn_fc1_act = "relu"
             # 1.generate basic params
@@ -466,7 +481,15 @@ def test_transformer_encoder_layer(self):
     def test_transformer_encoder_layer_attr_1(self):
         with base.dygraph.guard(base.CPUPlace()):
             paddle.framework.seed(2020)
-            paddle.framework.random._manual_program_seed(2020)
+            if paddle.framework.use_pir_api():
+                with paddle.pir_utils.OldIrGuard():
+                    # Note: dygraph use self.main_program.global_block().create_parameter(), it's need manual seed to old Program
+                    paddle.framework.random._manual_program_seed(2020)
+                paddle.framework.random._manual_program_seed(2020)
+            else:
+                paddle.framework.random._manual_program_seed(
+                    2020
+                )  # self_attention|cross_attention, cache|No cache
 
             ffn_fc1_act = "relu"
             # 1.generate basic params
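All three hunks apply the same pattern: when the PIR API is active, _manual_program_seed must be called once under OldIrGuard, because dygraph parameter creation still goes through the legacy Program's global block, and once for the default (PIR) program; otherwise the single legacy call suffices. Condensed from the hunks above:

import paddle

paddle.seed(2020)
if paddle.framework.use_pir_api():
    # Dygraph parameter creation still uses the legacy Program's global
    # block, so the old IR must be seeded as well as the default one.
    with paddle.pir_utils.OldIrGuard():
        paddle.framework.random._manual_program_seed(2020)
    paddle.framework.random._manual_program_seed(2020)
else:
    paddle.framework.random._manual_program_seed(2020)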
