Skip to content

Commit c7d49ec

Browse files
[Comp] Fix take_along_axis_grad when duplicated entries in indices (#70250)
* fix take_along_axis_grad when duplicated entries in indices * update unitest * add grad check * use include_self=True to avoid BUG
1 parent 9986a56 commit c7d49ec

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3341,8 +3341,12 @@ void take_along_axis_grad(const Tensor& arr,
33413341
arr_cast.dtype(),
33423342
arr_cast.place());
33433343
}
3344-
auto arr_grad_tmp =
3345-
put_along_axis<T>(zero_tensor, indices, out_grad_cast, axis);
3344+
auto arr_grad_tmp = put_along_axis<T>(zero_tensor,
3345+
indices,
3346+
out_grad_cast,
3347+
axis,
3348+
/*reduce*/ "add",
3349+
/*include_self*/ true);
33463350
set_output<T>(ConvertToOrig<T>(arr_grad_tmp, arr.dtype()), arr_grad);
33473351
}
33483352
}

test/legacy_test/test_take_along_axis_op.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,36 @@ def init_data(self):
7070
self.axis_type = "int64"
7171

7272

73+
class TestTakeAlongAxisDuplicatedIndices(TestTakeAlongAxisOp):
74+
def init_data(self):
75+
self.dtype = np.float32
76+
self.x_type = "float32"
77+
self.x_shape = (5, 6, 7)
78+
self.index_type = "int64"
79+
self.axis = 2
80+
dim_size = self.x_shape[self.axis]
81+
self.index = (
82+
np.asarray([-dim_size, -dim_size, dim_size - 1, dim_size - 1, 0])
83+
.astype(self.index_type)
84+
.reshape([5, 1, 1])
85+
)
86+
self.axis_type = "int64"
87+
88+
def test_check_output(self):
89+
self.check_output(
90+
check_cinn=self.check_cinn, check_pir=True, check_prim_pir=True
91+
)
92+
93+
def test_check_grad(self):
94+
self.check_grad(
95+
['Input'],
96+
'Result',
97+
check_cinn=self.check_cinn,
98+
check_pir=True,
99+
check_prim_pir=True,
100+
)
101+
102+
73103
class TestTakeAlongAxisFP16Op(TestTakeAlongAxisOp):
74104
def init_data(self):
75105
self.dtype = np.float16

0 commit comments

Comments
 (0)