1818import copy
1919import paddle
2020from paddle .fluid .dygraph import guard
21- from paddle .fluid .framework import default_main_program
21+ from paddle .fluid .framework import default_main_program , Variable
2222import paddle .fluid .core as core
2323from paddle .fluid .executor import Executor
2424import paddle .fluid .io as io
2525from paddle .fluid .initializer import ConstantInitializer
2626import numpy as np
2727
28+ paddle .enable_static ()
2829main_program = default_main_program ()
2930
3031
3132class ParameterChecks (unittest .TestCase ):
32- def check_parameter (self ):
33+ def test_parameter (self ):
3334 shape = [784 , 100 ]
3435 val = 1.0625
3536 b = main_program .global_block ()
@@ -43,13 +44,13 @@ def check_parameter(self):
4344 self .assertEqual ((784 , 100 ), param .shape )
4445 self .assertEqual (core .VarDesc .VarType .FP32 , param .dtype )
4546 self .assertEqual (0 , param .block .idx )
46- exe = Executor (core .CPUPlace ())
47+ exe = Executor (paddle .CPUPlace ())
4748 p = exe .run (main_program , fetch_list = [param ])[0 ]
48- self .assertTrue (np .allclose (p , np .ones (shape ) * val ))
49+ self .assertTrue (np .array_equal (p , np .ones (shape ) * val ))
4950 p = io .get_parameter_value_by_name ('fc.w' , exe , main_program )
50- self .assertTrue (np .allclose ( np . array ( p ) , np .ones (shape ) * val ))
51+ self .assertTrue (np .array_equal ( p , np .ones (shape ) * val ))
5152
52- def check_parambase (self ):
53+ def test_parambase (self ):
5354 with guard ():
5455 linear = paddle .nn .Linear (10 , 10 )
5556 param = linear .weight
@@ -71,7 +72,7 @@ def check_parambase(self):
7172 pram_copy2 = copy .deepcopy (param , memo )
7273 self .assertEqual (id (param_copy ), id (pram_copy2 ))
7374
74- def check_exceptions (self ):
75+ def test_exception (self ):
7576 b = main_program .global_block ()
7677 with self .assertRaises (ValueError ):
7778 b .create_parameter (
@@ -86,16 +87,52 @@ def check_exceptions(self):
8687 b .create_parameter (
8788 name = 'test' , shape = [- 1 ], dtype = 'float32' , initializer = None )
8889
90+ def test_parambase_to_vector (self ):
91+ with guard ():
92+ linear1 = paddle .nn .Linear (
93+ 10 ,
94+ 15 ,
95+ paddle .ParamAttr (
96+ initializer = paddle .nn .initializer .Constant (3. )))
8997
90- class TestParameter ( ParameterChecks ):
91- def _test_parameter ( self ):
92- self .check_parameter ( )
98+ vec = paddle . nn . utils . parameters_to_vector ( linear1 . parameters ())
99+ self . assertTrue ( isinstance ( vec , Variable ))
100+ self .assertTrue ( vec . shape , [ 165 ] )
93101
94- def test_parambase (self ):
95- self .check_parambase ()
102+ linear2 = paddle .nn .Linear (10 , 15 )
103+ paddle .nn .utils .vector_to_parameters (vec , linear2 .parameters ())
104+ self .assertTrue (
105+ np .array_equal (linear1 .weight .numpy (), linear2 .weight .numpy ()),
106+ True )
107+ self .assertTrue (
108+ np .array_equal (linear1 .bias .numpy (), linear2 .bias .numpy ()),
109+ True )
110+ self .assertTrue (linear2 .weight .is_leaf , True )
111+ self .assertTrue (linear2 .bias .is_leaf , True )
112+
113+ def test_parameter_to_vector (self ):
114+ main_program = paddle .static .Program ()
115+ start_program = paddle .static .Program ()
116+ with paddle .static .program_guard (main_program , start_program ):
117+ linear1 = paddle .nn .Linear (
118+ 10 ,
119+ 15 ,
120+ paddle .ParamAttr (
121+ initializer = paddle .nn .initializer .Constant (3. )))
122+
123+ vec = paddle .nn .utils .parameters_to_vector (linear1 .parameters ())
124+ self .assertTrue (isinstance (vec , Variable ))
125+ self .assertTrue (vec .shape , [165 ])
126+
127+ linear2 = paddle .nn .Linear (10 , 15 )
128+ paddle .nn .utils .vector_to_parameters (vec , linear2 .parameters ())
96129
97- def test_exceptions (self ):
98- self .check_exceptions ()
130+ exe = paddle .static .Executor ()
131+ exe .run (start_program )
132+ outs = exe .run (main_program ,
133+ fetch_list = [linear1 .parameters (), linear2 .parameters ()])
134+ self .assertTrue (np .array_equal (outs [0 ], outs [2 ]))
135+ self .assertTrue (np .array_equal (outs [1 ], outs [3 ]))
99136
100137
101138if __name__ == '__main__' :
0 commit comments