Skip to content

Commit db7b57a

Browse files
0x45f0x45f
authored andcommitted
remove no_value using var.name (PaddlePaddle#36513)
* remove no_value using var.name * fix unit test for CI * fix unit test * add test case * fix test case * add more test case
1 parent 36edb0e commit db7b57a

File tree

5 files changed

+151
-14
lines changed

5 files changed

+151
-14
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from paddle.fluid.layers import assign, fill_constant, slice, reduce_all, reduce_any
2121
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
2222
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
23+
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
2324

2425

2526
def convert_while_loop(cond, body, loop_vars):
@@ -204,10 +205,45 @@ def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars):
204205
205206
"""
206207
if isinstance(pred, Variable):
207-
return _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
208-
return_vars)
208+
out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,
209+
return_vars)
209210
else:
210-
return _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
211+
out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
212+
213+
return _remove_no_value_return_var(out)
214+
215+
216+
def _remove_no_value_return_var(out):
217+
if out and isinstance(out, tuple):
218+
processed_out = out
219+
align_ret = out[0]
220+
if isinstance(align_ret, tuple):
221+
for index, item in enumerate(align_ret):
222+
if isinstance(item, Variable) and (
223+
RETURN_NO_VALUE_VAR_NAME in item.name):
224+
# return None
225+
if index == 0:
226+
processed_out = (None, ) + out[1:]
227+
elif index == 1:
228+
processed_out = align_ret[:1] + out[1:]
229+
else:
230+
processed_out = (align_ret[:index], ) + out[1:]
231+
break
232+
233+
for index, item in enumerate(processed_out):
234+
if isinstance(item, Variable) and (
235+
RETURN_NO_VALUE_VAR_NAME in item.name):
236+
processed_out = processed_out[:index]
237+
238+
if not processed_out:
239+
return None
240+
elif len(processed_out) == 1:
241+
return processed_out[0]
242+
else:
243+
return processed_out
244+
245+
else:
246+
return out
211247

212248

213249
def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args,

python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,14 @@ def create_fill_constant_node(name, value):
9393
func_code = "{} = paddle.fluid.layers.fill_constant(shape=[1], ".format(
9494
name)
9595
if isinstance(value, bool):
96-
func_code += "dtype='bool', value={})".format(value)
96+
func_code += "dtype='bool', value={}, name='{}')".format(value, name)
9797
return gast.parse(func_code).body[0]
9898
if isinstance(value, float):
99-
func_code += "dtype='float64', value={})".format(value)
99+
func_code += "dtype='float64', value={}, name='{}')".format(value, name)
100100
return gast.parse(func_code).body[0]
101101

102102
if isinstance(value, int):
103-
func_code += "dtype='int64', value={})".format(value)
103+
func_code += "dtype='int64', value={}, name='{}')".format(value, name)
104104
return gast.parse(func_code).body[0]
105105

106106

python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,5 +261,100 @@ def test_tensor_shape(self):
261261
self.assertTrue(np.array_equal(out.numpy(), x.numpy()))
262262

263263

264+
class TestIfElseNoValue(unittest.TestCase):
265+
def test_else_ret_none(self):
266+
input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
267+
268+
@paddle.jit.to_static
269+
def with_common_value(x, use_cache=False):
270+
if use_cache:
271+
y = x + 1
272+
z = x + 2
273+
return y, z
274+
else:
275+
c = x + 1
276+
z = x - 1
277+
return None
278+
279+
@paddle.jit.to_static
280+
def without_common_value(x, use_cache=False):
281+
if use_cache:
282+
y = x + 1
283+
z = x + 2
284+
return y, z
285+
else:
286+
c = x + 1
287+
return None
288+
289+
out = with_common_value(input_x, False)
290+
self.assertIsNone(out)
291+
out = without_common_value(input_x, False)
292+
self.assertIsNone(out)
293+
294+
def test_else_ret_c(self):
295+
input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
296+
297+
@paddle.jit.to_static
298+
def with_common_value(x, use_cache=False):
299+
if use_cache:
300+
y = x + 1
301+
z = x + 2
302+
return y, z
303+
else:
304+
c = x + 1
305+
z = x - 1
306+
return c
307+
308+
@paddle.jit.to_static
309+
def without_common_value(x, use_cache=False):
310+
if use_cache:
311+
y = x + 1
312+
z = x + 2
313+
return y, z
314+
else:
315+
c = x + 1
316+
return c
317+
318+
out = with_common_value(input_x, False)
319+
self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1))
320+
out = without_common_value(input_x, False)
321+
self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1))
322+
y, z = with_common_value(input_x, True)
323+
self.assertListEqual(paddle.tolist(y), paddle.tolist(input_x + 1))
324+
self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x + 2))
325+
326+
def test_else_ret_cz(self):
327+
input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])
328+
329+
@paddle.jit.to_static
330+
def with_common_value(x, use_cache=False):
331+
if use_cache:
332+
y = x + 1
333+
z = x + 2
334+
return y, z, 1
335+
else:
336+
c = x + 1
337+
z = x - 1
338+
return c, z
339+
340+
@paddle.jit.to_static
341+
def without_common_value(x, use_cache=False):
342+
if use_cache:
343+
y = x + 1
344+
z = x + 2
345+
return y, z, 1
346+
else:
347+
c = x + 1
348+
d = x - 1
349+
return c, d
350+
351+
c, z = with_common_value(input_x, False)
352+
self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1))
353+
self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x - 1))
354+
c, d = without_common_value(input_x, False)
355+
self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1))
356+
self.assertListEqual(paddle.tolist(d), paddle.tolist(input_x - 1))
357+
358+
264359
if __name__ == '__main__':
265360
unittest.main()

python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def get_source_code(func):
6464
class StaticCode1():
6565
def dyfunc_with_if_else(x_v, label=None):
6666
__return_value_init_0 = paddle.fluid.layers.fill_constant(
67-
shape=[1], dtype='float64', value=0.0)
67+
shape=[1], dtype='float64', value=0.0, name='__return_value_init_0')
6868
__return_value_0 = __return_value_init_0
6969

7070
def true_fn_0(x_v):
@@ -116,7 +116,7 @@ class StaticCode2():
116116
# TODO: Transform return statement
117117
def dyfunc_with_if_else(x_v, label=None):
118118
__return_value_init_1 = paddle.fluid.layers.fill_constant(
119-
shape=[1], dtype='float64', value=0.0)
119+
shape=[1], dtype='float64', value=0.0, name='__return_value_init_1')
120120
__return_value_1 = __return_value_init_1
121121

122122
def true_fn_3(x_v):

python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,22 @@ def test_feed_mismatch_shape(self):
5050
class TestVariableTransFunc(unittest.TestCase):
5151
def test_create_fill_constant_node(self):
5252
node = create_fill_constant_node("a", 1.0)
53-
source = "a = paddle.fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0)"
54-
self.assertEqual(ast_to_source_code(node).strip(), source)
53+
source = "a = paddle.fluid.layers.fill_constant(shape=[1], dtype='float64', value=1.0, name='a')"
54+
self.assertEqual(
55+
ast_to_source_code(node).replace('\n', '').replace(' ', ''),
56+
source.replace(' ', ''))
5557

5658
node = create_fill_constant_node("b", True)
57-
source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True)"
58-
self.assertEqual(ast_to_source_code(node).strip(), source)
59+
source = "b = paddle.fluid.layers.fill_constant(shape=[1], dtype='bool', value=True, name='b')"
60+
self.assertEqual(
61+
ast_to_source_code(node).replace('\n', '').replace(' ', ''),
62+
source.replace(' ', ''))
5963

6064
node = create_fill_constant_node("c", 4293)
61-
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293)"
62-
self.assertEqual(ast_to_source_code(node).strip(), source)
65+
source = "c = paddle.fluid.layers.fill_constant(shape=[1], dtype='int64', value=4293, name='c')"
66+
self.assertEqual(
67+
ast_to_source_code(node).replace('\n', '').replace(' ', ''),
68+
source.replace(' ', ''))
6369

6470
self.assertIsNone(create_fill_constant_node("e", None))
6571
self.assertIsNone(create_fill_constant_node("e", []))

0 commit comments

Comments
 (0)