Skip to content

Commit 4d12653

Browse files
【pir_save_load】modify pir ci for test_paddle_save_load.py and test_jit_save_load.py (#65446)
* modify test_jit_save_load * modify test_jit_save_load * modify ci save_load
1 parent 4609910 commit 4d12653

File tree

4 files changed

+306
-117
lines changed

4 files changed

+306
-117
lines changed

python/paddle/pir/core.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import numpy as np
1717

18+
import paddle
1819
from paddle.base.core import Place, VarDesc
1920
from paddle.base.libpaddle import DataType
2021
from paddle.base.libpaddle.pir import (
@@ -476,3 +477,77 @@ def static_op_arg_cast_guard(hook):
476477
yield
477478
finally:
478479
set_static_op_arg_pre_cast_hook(original_callback)
480+
481+
482+
def set_state_dict(program, state_dict, scope=None):
    """
    Set parameters and persistable buffers in state_dict to program.
    An exception will be thrown if the shape or dtype of a parameter does not match.

    The ``state_dict`` passed by the caller is never modified: names are
    remapped and values converted into a fresh dict before it is handed to
    ``program.set_state_dict``.

    .. note::
        This function MUST be called after running the startup program.

    Args:
        program: the object whose ``set_state_dict(state_dict, scope)`` method
            receives the normalized state dict.
        state_dict(dict): the dict store parameters and persistable buffers.
            The key is the name of the parameter or the name of the buffer.
            The value is the tensor of this variable in the given scope.
            Values may be ``paddle.base.libpaddle.Tensor`` or ``numpy.ndarray``
            (ndarrays are converted with ``paddle.to_tensor``). If the reserved
            key ``"StructuredToParameterName@@"`` is present, its value is used
            to translate the other keys and the reserved entry itself is dropped.
        scope(Scope, optional) : If scope is None, state_dict will be set to global scope
            obtained through 'paddle.static.global_scope()'. Otherwise, value will be set to scope.
            Default: None

    Returns:
        None

    Raises:
        TypeError: if ``state_dict`` is not a dict, or if one of its values is
            neither a Tensor nor a ``numpy.ndarray``.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> import paddle.static as static

            >>> paddle.enable_static()

            >>> x = static.data(name="x", shape=[10, 10], dtype='float32')
            >>> y = static.nn.fc(x, 10)
            >>> z = static.nn.fc(y, 10)

            >>> place = paddle.CPUPlace()
            >>> exe = static.Executor(place)
            >>> exe.run(static.default_startup_program())
            >>> prog = static.default_main_program()

            >>> path = "./temp/model.pdparams"
            >>> paddle.save(prog.state_dict(), path)
            >>> state_dict_load = paddle.load(path)
            >>> prog.set_state_dict(state_dict_load)
    """
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"Type of `state_dict` should be dict, but received {type(state_dict)}."
        )

    name_map_key = "StructuredToParameterName@@"
    # Optional structured-name -> parameter-name table stored next to the values.
    structured_to_param = state_dict.get(name_map_key, {})

    # Always build a new dict so the caller's state_dict is left untouched,
    # even when ndarray values have to be replaced by tensors.
    clear_state_dict = {}
    for name, value in state_dict.items():
        if name == name_map_key:
            continue
        if name in structured_to_param:
            name = structured_to_param[name]
        if not isinstance(value, paddle.base.libpaddle.Tensor):
            if not isinstance(value, np.ndarray):
                raise TypeError(
                    f"The type of `{name}` should be Tensor, ndarray, but received {type(value)}."
                )
            value = paddle.to_tensor(value)
        clear_state_dict[name] = value

    if scope is None:
        scope = paddle.static.global_scope()
    program.set_state_dict(clear_state_dict, scope)
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import pickle
17+
import tempfile
18+
import unittest
19+
from io import BytesIO
20+
21+
import numpy as np
22+
from test_imperative_base import new_program_scope
23+
24+
import paddle
25+
from paddle import base, nn
26+
from paddle.jit.api import to_static
27+
from paddle.jit.translated_layer import INFER_PARAMS_INFO_SUFFIX
28+
from paddle.nn import Linear
29+
from paddle.static import InputSpec
30+
31+
# Input feature width and output width used when building the test networks.
IMAGE_SIZE = 784
CLASS_NUM = 10

# Seed handed to paddle's seeding calls in the test fixtures.
SEED = 10
35+
36+
37+
class LinearNet(nn.Layer):
    """Network made of a single ``nn.Linear(IMAGE_SIZE, CLASS_NUM)`` layer."""

    def __init__(self):
        super().__init__()
        fc = nn.Linear(IMAGE_SIZE, CLASS_NUM)
        self._linear = fc

    def forward(self, x):
        """Return the output of the linear layer applied to ``x``."""
        out = self._linear(x)
        return out
44+
45+
46+
class LinearNetReturnHidden(paddle.nn.Layer):
    """Two chained linear layers returning the hidden activation and a mean loss.

    Both layers are built as ``Linear(in_size, out_size)`` and the second one
    consumes the first one's output, so the chain only lines up when
    ``in_size == out_size``.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self._linear_1 = Linear(in_size, out_size)
        self._linear_2 = Linear(in_size, out_size)

    @to_static
    def forward(self, x):
        """Return ``(hidden, loss)`` where ``loss`` is the mean of the second layer's output."""
        hidden = self._linear_1(x)
        projected = self._linear_2(hidden)
        return hidden, paddle.mean(projected)
58+
59+
60+
class TestSaveLoadProgram(unittest.TestCase):
    """Round-trip the default main/startup programs through paddle.save / paddle.load."""

    def test_save_load_program(self):
        """Programs loaded from disk must serialize to the same bytes as the originals."""
        paddle.enable_static()
        temp_dir = tempfile.TemporaryDirectory()
        # Register the cleanup immediately so the directory is removed even
        # when a save/load call or an assertion fails part-way through.
        self.addCleanup(temp_dir.cleanup)

        with new_program_scope():
            layer = LinearNet()
            data = paddle.static.data(
                name='x_static_save', shape=(None, IMAGE_SIZE), dtype='float32'
            )
            # The returned value is not needed; the call is kept because it
            # runs the layer on the static input (presumably recording ops in
            # the default programs -- confirm against new_program_scope).
            layer(data)
            main_program = paddle.static.default_main_program()
            startup_program = paddle.static.default_startup_program()
            origin_main = main_program.desc.serialize_to_string()
            origin_startup = startup_program.desc.serialize_to_string()
            path1 = os.path.join(
                temp_dir.name,
                "test_paddle_save_load_program/main_program.pdmodel",
            )
            path2 = os.path.join(
                temp_dir.name,
                "test_paddle_save_load_program/startup_program.pdmodel",
            )
            paddle.save(main_program, path1)
            paddle.save(startup_program, path2)

        with new_program_scope():
            load_main = paddle.load(path1).desc.serialize_to_string()
            load_startup = paddle.load(path2).desc.serialize_to_string()
            # assertEqual reports both operands on failure, unlike assertTrue(a == b).
            self.assertEqual(origin_main, load_main)
            self.assertEqual(origin_startup, load_startup)
92+
93+
94+
class TestJitPruneModelAndLoad(unittest.TestCase):
    """Train a small to_static layer, save it with a pruned output, and check load failure handling."""

    def setUp(self):
        self.linear_size = 4
        self.temp_dir = tempfile.TemporaryDirectory()
        model_root = self.temp_dir.name
        self.model_path = os.path.join(
            model_root, "jit_prune_model_and_load/model"
        )
        # Switch to dygraph mode, then fix the random seeds.
        base.enable_dygraph()
        paddle.seed(SEED)
        paddle.framework.random._manual_program_seed(SEED)

    def tearDown(self):
        self.temp_dir.cleanup()

    def train_and_save(self):
        """Run ten Adam steps on random input, save the layer with only its first output, and return it."""
        net = LinearNetReturnHidden(8, 8)
        net = to_static(
            net,
            input_spec=[InputSpec([None, 8], name='x')],
            full_graph=True,
        )
        optimizer = paddle.optimizer.Adam(
            learning_rate=0.1, parameters=net.parameters()
        )
        batch = paddle.to_tensor(np.random.random((4, 8)).astype('float32'))
        for _ in range(10):
            _hidden, loss = net(batch)
            loss.backward()
            optimizer.minimize(loss)
            net.clear_gradients()

        # Keep only the first forward output in the saved model.
        paddle.jit.save(
            layer=net,
            path=self.model_path,
            input_spec=[batch],
            output_spec=net.forward.outputs[:1],
        )

        return net

    # pir has no need to save extra var info: params are always saved with the
    # program, and trainable info is saved in the program's op attrs.
    def test_load_var_not_in_extra_var_info(self):
        """Loading must raise RuntimeError once the extra var info file has been emptied."""
        self.train_and_save()

        # Change the extra var info: read it, empty it, and write it back.
        info_file = self.model_path + INFER_PARAMS_INFO_SUFFIX
        with open(info_file, 'rb') as reader:
            var_info = pickle.load(reader)
        var_info.clear()
        with open(info_file, 'wb') as writer:
            pickle.dump(var_info, writer, protocol=2)

        with self.assertRaises(RuntimeError):
            paddle.jit.load(self.model_path)
152+
153+
154+
class TestSaveLoadToMemory(unittest.TestCase):
    """Save a program, a tensor and a state dict into BytesIO buffers and load them back."""

    def test_static_save_to_memory(self):
        """Objects loaded from in-memory buffers must equal what was saved."""
        paddle.enable_static()
        with new_program_scope():
            # Build the network.
            x = paddle.static.data(
                name="x", shape=[None, IMAGE_SIZE], dtype='float32'
            )
            z = paddle.static.nn.fc(x, 10, bias_attr=False)
            z = paddle.static.nn.fc(z, 128, bias_attr=False)
            loss = paddle.mean(z)
            if paddle.base.core.is_compiled_with_cuda():
                place = base.CUDAPlace(0)
            else:
                place = base.CPUPlace()
            prog = paddle.static.default_main_program()
            exe = paddle.static.Executor(place)
            exe.run(paddle.static.default_startup_program())

            state_dict = prog.state_dict()
            first_key = list(state_dict.keys())[0]
            tensor = state_dict[first_key]

            # One buffer holds the tensor followed by the state dict; the
            # program goes into a buffer of its own.
            byio = BytesIO()
            byio2 = BytesIO()
            paddle.save(prog, byio2)
            paddle.save(tensor, byio)
            paddle.save(state_dict, byio)
            byio.seek(0)
            byio2.seek(0)

            prog_load = paddle.load(byio2)
            saved_desc = prog.desc.serialize_to_string()
            loaded_desc = prog_load.desc.serialize_to_string()
            self.assertTrue(saved_desc == loaded_desc)

            # The two loads read from the shared buffer in the order of the saves.
            tensor_load = paddle.load(byio, return_numpy=True)
            np.testing.assert_array_equal(tensor_load, np.array(tensor))

            state_dict_load = paddle.load(byio, return_numpy=True)
            for key, saved in state_dict.items():
                np.testing.assert_array_equal(
                    np.array(saved), state_dict_load[key]
                )
198+
199+
200+
# Run the test cases in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()

test/legacy_test/test_jit_save_load.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import copy
1717
import os
18-
import pickle
1918
import shutil
2019
import tempfile
2120
import unittest
@@ -27,7 +26,6 @@
2726
from paddle import base
2827
from paddle.base import unique_name
2928
from paddle.jit.api import to_static
30-
from paddle.jit.translated_layer import INFER_PARAMS_INFO_SUFFIX
3129
from paddle.nn import Linear
3230
from paddle.pir_utils import test_with_dygraph_pir
3331
from paddle.static import InputSpec
@@ -1112,22 +1110,6 @@ def test_load_pruned_model(self):
11121110
train_layer(x)[0].numpy(), infer_layer(x).numpy()
11131111
)
11141112

1115-
# pir has no need to save extra var info, param always saved with program,
1116-
# and trainable info saved in program's op attr
1117-
def test_load_var_not_in_extra_var_info(self):
1118-
self.train_and_save()
1119-
1120-
# chage extra var info
1121-
var_info_path = self.model_path + INFER_PARAMS_INFO_SUFFIX
1122-
with open(var_info_path, 'rb') as f:
1123-
extra_var_info = pickle.load(f)
1124-
extra_var_info.clear()
1125-
with open(var_info_path, 'wb') as f:
1126-
pickle.dump(extra_var_info, f, protocol=2)
1127-
1128-
with self.assertRaises(RuntimeError):
1129-
paddle.jit.load(self.model_path)
1130-
11311113

11321114
class TestJitSaveMultiCases(unittest.TestCase):
11331115
def setUp(self):

0 commit comments

Comments
 (0)