Skip to content

Commit 29eaaa1

Browse files
committed
first fix the UT
1 parent 341afb7 commit 29eaaa1

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

test/legacy_test/test_zero_dim_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,7 @@ def test_setitem(self):
830830
np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10))
831831
self.assertEqual(x.grad.shape, [2, 3, 4, 5])
832832
x_grad_expected = np.ones((2, 3, 4, 5)) * 2
833+
x_grad_expected[1, 2, 3, 4] = 0
833834
np.testing.assert_allclose(x.grad, x_grad_expected)
834835

835836
# case2: 0-D Tensor indice in some axis
@@ -847,6 +848,7 @@ def test_setitem(self):
847848
self.assertEqual(out.shape, x.shape)
848849
np.testing.assert_allclose(out[1, 1], np.ones((4, 5)) * 0.5)
849850
x_grad_expected = np.ones((2, 3, 4, 5))
851+
x_grad_expected[1, 1] = 0
850852
np.testing.assert_allclose(x.grad, x_grad_expected)
851853

852854
# case3:0-D Tensor indice in some axis, value is a Tensor

test/xpu/test_zero_dim_tensor_xpu.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ def test_setitem(self):
441441
np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10))
442442
self.assertEqual(x.grad.shape, [2, 3, 4, 5])
443443
x_grad_expected = np.ones((2, 3, 4, 5)) * 2
444+
x_grad_expected[1, 2, 3, 4] = 0
444445
np.testing.assert_allclose(x.grad, x_grad_expected)
445446

446447
# case2: 0-D Tensor indice in some axis

0 commit comments

Comments
 (0)