|
17 | 17 | import unittest |
18 | 18 | import numpy as np |
19 | 19 | import paddle.fluid as fluid |
| 20 | +import paddle |
20 | 21 | from op_test import OpTest |
21 | 22 |
|
22 | 23 |
|
@@ -69,6 +70,25 @@ def init_test_case(self): |
69 | 70 | self.new_shape = (36, 16) |
70 | 71 |
|
71 | 72 |
|
| 73 | +class TestStaticFlattenPythonAPI(unittest.TestCase): |
| 74 | + def execute_api(self, x, axis=1): |
| 75 | + return fluid.layers.flatten(x, axis=axis) |
| 76 | + |
| 77 | + def test_static_api(self): |
| 78 | + paddle.enable_static() |
| 79 | + np_x = np.random.rand(2, 3, 4, 4).astype('float32') |
| 80 | + |
| 81 | + main_prog = paddle.static.Program() |
| 82 | + with paddle.static.program_guard(main_prog, paddle.static.Program()): |
| 83 | + x = paddle.static.data( |
| 84 | + name="x", shape=[-1, 3, -1, -1], dtype='float32') |
| 85 | + out = self.execute_api(x, axis=2) |
| 86 | + |
| 87 | + exe = paddle.static.Executor(place=paddle.CPUPlace()) |
| 88 | + fetch_out = exe.run(main_prog, feed={"x": np_x}, fetch_list=[out]) |
| 89 | + self.assertTrue((6, 16) == fetch_out[0].shape) |
| 90 | + |
| 91 | + |
72 | 92 | class TestFlatten2OpError(unittest.TestCase): |
73 | 93 | def test_errors(self): |
74 | 94 | with fluid.program_guard(fluid.Program(), fluid.Program()): |
|
0 commit comments