Skip to content

Commit 8381047

Browse files
committed
fix flatten infershape; test=develop
1 parent 6fc314d commit 8381047

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

python/paddle/fluid/tests/unittests/test_flatten2_op.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import unittest
1818
import numpy as np
1919
import paddle.fluid as fluid
20+
import paddle
2021
from op_test import OpTest
2122

2223

@@ -69,6 +70,25 @@ def init_test_case(self):
6970
self.new_shape = (36, 16)
7071

7172

73+
class TestStaticFlattenPythonAPI(unittest.TestCase):
74+
def execute_api(self, x, axis=1):
75+
return fluid.layers.flatten(x, axis=axis)
76+
77+
def test_static_api(self):
78+
paddle.enable_static()
79+
np_x = np.random.rand(2, 3, 4, 4).astype('float32')
80+
81+
main_prog = paddle.static.Program()
82+
with paddle.static.program_guard(main_prog, paddle.static.Program()):
83+
x = paddle.static.data(
84+
name="x", shape=[-1, 3, -1, -1], dtype='float32')
85+
out = self.execute_api(x, axis=2)
86+
87+
exe = paddle.static.Executor(place=paddle.CPUPlace())
88+
fetch_out = exe.run(main_prog, feed={"x": np_x}, fetch_list=[out])
89+
self.assertTrue((6, 16) == fetch_out[0].shape)
90+
91+
7292
class TestFlatten2OpError(unittest.TestCase):
7393
def test_errors(self):
7494
with fluid.program_guard(fluid.Program(), fluid.Program()):

0 commit comments

Comments
 (0)