Skip to content

Commit 6cc1f71

Browse files
committed
add static mode backward test
1 parent c833770 commit 6cc1f71

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

test/legacy_test/test_set_value_op.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1978,5 +1978,59 @@ def test_check_grad(self):
19781978
self.check_grad_with_place(place, ['Input'], 'Out', check_dygraph=False)
19791979

19801980

1981+
class TestSetValueWithScalarInStatic(unittest.TestCase):
1982+
def setUp(self):
1983+
paddle.enable_static()
1984+
self.shape = (10, 2)
1985+
self.exe = paddle.static.Executor()
1986+
self.train_program = paddle.static.Program()
1987+
self.startup_program = paddle.static.Program()
1988+
1989+
def test_value_input_is_scalar(self):
1990+
with paddle.static.program_guard(
1991+
self.train_program, self.startup_program
1992+
):
1993+
x = paddle.ones(self.shape)
1994+
x.stop_gradient = False
1995+
y = x * 1
1996+
1997+
# mock test case x[0, 0] = 10 with no ValueTensor input
1998+
inputs = {
1999+
'Input': y,
2000+
}
2001+
attrs = {
2002+
'axes': [0, 1],
2003+
'starts': [0, 0],
2004+
'ends': [1, 1],
2005+
'steps': [1, 1],
2006+
'values': [10],
2007+
'shape': [1],
2008+
}
2009+
2010+
helper = LayerHelper("set_value")
2011+
out = helper.create_variable_for_type_inference(dtype=y.dtype)
2012+
2013+
helper.append_op(
2014+
type="set_value",
2015+
inputs=inputs,
2016+
outputs={'Out': out},
2017+
attrs=attrs,
2018+
)
2019+
2020+
np_data = np.ones(self.shape).astype('float32')
2021+
2022+
paddle.static.append_backward(out.sum())
2023+
res = self.exe.run(
2024+
self.train_program, fetch_list=[out, x.grad_name]
2025+
)
2026+
2027+
np_data[0, 0] = 10
2028+
expected_x_grad = np.ones(self.shape)
2029+
expected_x_grad[0, 0] = 0
2030+
2031+
np.testing.assert_array_equal(res[0], np_data)
2032+
np.testing.assert_array_equal(res[1], expected_x_grad)
2033+
2034+
19812035
if __name__ == '__main__':
19822036
unittest.main()

0 commit comments

Comments
 (0)