Skip to content

Commit 3ad5f26

Browse files
authored
fix the bug of scatter_grad spmd (#67283)
1 parent 44fbb36 commit 3ad5f26

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

paddle/phi/infermeta/spmd_rules/scatter.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,12 +178,17 @@ SpmdInfo ScatterGradInferSpmd(const DistMetaTensor& index,
178178
// the batch axis of index, updates, out_grad must be replicated
179179
std::vector<int64_t> index_dims_mapping(index_dims_mapping_src);
180180
index_dims_mapping[0] = -1;
181+
std::vector<int64_t> updates_dims_mapping(updates_dims_mapping_src);
182+
updates_dims_mapping[0] = -1;
181183
std::vector<int64_t> out_grad_dims_mapping(out_grad_dims_mapping_src);
182184
out_grad_dims_mapping[0] = -1;
183185

184186
TensorDistAttr index_dist_attr_dst =
185187
CopyTensorDistAttrForOutput(index_dist_attr_src);
186188
index_dist_attr_dst.set_dims_mapping(index_dims_mapping);
189+
TensorDistAttr updates_dist_attr_dst =
190+
CopyTensorDistAttrForOutput(updates_dist_attr_src);
191+
updates_dist_attr_dst.set_dims_mapping(updates_dims_mapping);
187192
TensorDistAttr out_grad_dist_attr_dst =
188193
CopyTensorDistAttrForOutput(out_grad_dist_attr_src);
189194
out_grad_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping);
@@ -199,7 +204,7 @@ SpmdInfo ScatterGradInferSpmd(const DistMetaTensor& index,
199204
TensorDistAttr updates_grad_dist_attr =
200205
PADDLE_GET_CONST(TensorDistAttr, spmd_info.second[0]);
201206

202-
return {{index_dist_attr_dst, updates_dist_attr_src, out_grad_dist_attr_dst},
207+
return {{index_dist_attr_dst, updates_dist_attr_dst, out_grad_dist_attr_dst},
203208
{x_grad_dist_attr, updates_grad_dist_attr}};
204209
}
205210

test/cpp/auto_parallel/spmd_rule_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,7 +1699,7 @@ TEST(ScatterGradInferSpmd, Ctor) {
16991699
std::vector<int64_t>({-1, -1, 1}));
17001700

17011701
// [0], [0, -1, 1], [-1, 0, 1] -->
1702-
// inputs: [-1], [0, -1, 1], [-1, 0, 1]
1702+
// inputs: [-1], [-1, -1, 1], [-1, 0, 1]
17031703
// x_grad: [-1, 0, 1], updates_grad: [-1, 0, 1]
17041704
index_dist_attr.set_dims_mapping({0});
17051705
updates_dist_attr.set_dims_mapping({0, -1, 1});
@@ -1716,7 +1716,7 @@ TEST(ScatterGradInferSpmd, Ctor) {
17161716

17171717
EXPECT_EQ(get_dims_mapping(spmdinfo.first[0]), std::vector<int64_t>({-1}));
17181718
EXPECT_EQ(get_dims_mapping(spmdinfo.first[1]),
1719-
std::vector<int64_t>({0, -1, 1}));
1719+
std::vector<int64_t>({-1, -1, 1}));
17201720
EXPECT_EQ(get_dims_mapping(spmdinfo.first[2]),
17211721
std::vector<int64_t>({-1, 0, 1}));
17221722
EXPECT_EQ(get_dims_mapping(spmdinfo.second[0]),

0 commit comments

Comments
 (0)