Skip to content

Commit a4d7ff5

Browse files
[AutoParallel] Refine SubMeshDim func for global_and_sub_mesh reshard (#68115)
1 parent cd7c853 commit a4d7ff5

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

paddle/phi/core/distributed/auto_parallel/process_mesh.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,12 +217,17 @@ int SubMeshDim(const ProcessMesh &global_mesh, const ProcessMesh &sub_mesh) {
217217
return -1;
218218
}
219219

220-
auto it = std::find(sub_shape.begin(), sub_shape.end(), 1);
221-
if (it == sub_shape.end()) {
220+
// e.g.
221+
// global_mesh: shape = [1,2], process_ids = [0,1]; sub_mesh: shape = [1, 1],
222+
// process_ids = [0] global_mesh: shape = [2,2], process_ids = [0,1,2,3];
223+
// sub_mesh: shape = [2, 1], process_ids = [0, 2]
224+
auto it =
225+
std::mismatch(sub_shape.begin(), sub_shape.end(), global_shape.begin());
226+
if (it.first == sub_shape.end()) {
222227
return -1;
223228
}
224229

225-
sub_dim = it - sub_shape.begin();
230+
sub_dim = it.first - sub_shape.begin();
226231
std::vector<ProcessMesh> sub_meshes = SplitMesh(global_mesh, sub_dim);
227232
if (std::find(sub_meshes.begin(), sub_meshes.end(), sub_mesh) !=
228233
sub_meshes.end()) {

0 commit comments

Comments
 (0)