Skip to content

Commit f4800e7

Browse files
committed
[Phi] Add swish yaml and final state api (PaddlePaddle#41479)
* add swish yaml and final state api * skip mkldnn test * fix grad mkldnn test
1 parent 54991bd commit f4800e7

File tree

5 files changed

+36
-3
lines changed

5 files changed

+36
-3
lines changed

python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def setUp(self):
113113
super(TestMKLDNNSwishDim2, self).setUp()
114114

115115
self.attrs["use_mkldnn"] = True
116+
self.check_eager = False
116117

117118
def init_dtype(self):
118119
self.dtype = np.float32
@@ -284,6 +285,7 @@ def setUp(self):
284285
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
285286
self.outputs = {'Out': out}
286287
self.attrs = {"use_mkldnn": True, "beta": beta}
288+
self.check_eager = False
287289

288290
def init_dtype(self):
289291
self.dtype = np.float32

python/paddle/fluid/tests/unittests/test_activation_op.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2928,7 +2928,9 @@ def ref_swish(x):
29282928
class TestSwish(TestActivation):
29292929
def setUp(self):
29302930
self.op_type = "swish"
2931+
self.python_api = paddle.nn.functional.swish
29312932
self.init_dtype()
2933+
self.check_eager = True
29322934

29332935
np.random.seed(1024)
29342936
x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
@@ -2940,7 +2942,10 @@ def setUp(self):
29402942
def test_check_grad(self):
29412943
if self.dtype == np.float16:
29422944
return
2943-
self.check_grad(['X'], 'Out')
2945+
check_eager = False
2946+
if hasattr(self, 'check_eager'):
2947+
check_eager = self.check_eager
2948+
self.check_grad(['X'], 'Out', check_eager=check_eager)
29442949

29452950

29462951
class TestSwishAPI(unittest.TestCase):
@@ -2975,6 +2980,10 @@ def test_dygraph_api(self):
29752980
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
29762981
paddle.enable_static()
29772982

2983+
def test_dygraph_final_state_api(self):
2984+
with _test_eager_guard():
2985+
self.test_dygraph_api()
2986+
29782987
def test_fluid_api(self):
29792988
paddle.enable_static()
29802989
with fluid.program_guard(fluid.Program()):

python/paddle/nn/functional/activation.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,8 +1175,9 @@ def swish(x, name=None):
11751175
x = paddle.to_tensor(np.array([-2., 0., 1.]))
11761176
out = F.swish(x) # [-0.238406, 0., 0.731059]
11771177
"""
1178-
1179-
if in_dynamic_mode():
1178+
if in_dygraph_mode():
1179+
return _C_ops.final_state_swish(x, 1.0)
1180+
if _in_legacy_dygraph():
11801181
return _C_ops.swish(x, 'beta', 1.0)
11811182

11821183
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'swish')

python/paddle/utils/code_gen/api.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1732,6 +1732,17 @@
17321732
data_type : x
17331733
backward : sum_grad
17341734

1735+
# The python API paddle.nn.functional.swish has no `bete` argument, it may be removed later
1736+
- api : swish
1737+
args : (Tensor x, float beta=1.0)
1738+
output : Tensor(out)
1739+
infer_meta :
1740+
func : UnchangedInferMeta
1741+
param : [x]
1742+
kernel :
1743+
func : swish
1744+
backward : swish_grad
1745+
17351746
# take_along_axis
17361747
- api : take_along_axis
17371748
args : (Tensor x, Tensor index, int axis)

python/paddle/utils/code_gen/backward.yaml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,6 +1317,16 @@
13171317
kernel :
13181318
func : sum_grad
13191319

1320+
- backward_api : swish_grad
1321+
forward : swish (Tensor x, float beta=1.0) -> Tensor(out)
1322+
args : (Tensor x, Tensor out_grad, float bete=1.0)
1323+
output : Tensor(x_grad)
1324+
infer_meta :
1325+
func : GeneralUnaryGradInferMeta
1326+
param : [x]
1327+
kernel :
1328+
func : swish_grad
1329+
13201330
- backward_api : take_along_axis_grad
13211331
forward : take_along_axis (Tensor x, Tensor index, int axis) -> Tensor(out)
13221332
args : (Tensor x, Tensor index, Tensor out_grad, int axis)

0 commit comments

Comments
 (0)