
Commit 88677aa

pkuzyc authored and Enigmatisms committed
fix the bug of setting grad's placements in dtensor_to_local (PaddlePaddle#71264)
1 parent 0ee2319 commit 88677aa

File tree: 9 files changed (+59 / -23 lines)

paddle/fluid/eager/api/manual/eager_manual/forwards/dtensor_to_local_fwd_func.cc

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ paddle::Tensor dtensor_to_local_ad_function(
         ToTensorDistAttr(process_mesh, placements, input.dims());
 
     grad_node->SetGradDistAttr(grad_dist_attr);
+    grad_node->SetGradProcessMesh(process_mesh);
+    grad_node->SetGradPlacements(placements);
   }
 
   // Forward API Call
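For reference, a minimal sketch (not part of the commit) of the kind of process mesh and placements the forward pass now caches on the grad node, written against Paddle's public auto-parallel API; the 2x2 mesh and tensor shape are illustrative, and the script assumes a 4-rank launch (e.g. python -m paddle.distributed.launch --devices 0,1,2,3).

# Illustrative only: build a dist tensor whose process mesh and placements
# correspond to what SetGradProcessMesh / SetGradPlacements would receive.
import paddle
import paddle.distributed as dist

# 2x2 process mesh over ranks 0..3, with named mesh dims.
mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

# Shard tensor dim 0 along mesh dim "dp", replicate along mesh dim "mp".
placements = [dist.Shard(0), dist.Replicate()]

x = paddle.randn([8, 4])
dx = dist.shard_tensor(x, mesh, placements)

# Global shape vs. per-rank local shape under this sharding: [8, 4] vs. [4, 4].
# (_local_shape is the same helper attribute api.py uses below.)
print(dx.shape, dx._local_shape)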

paddle/fluid/eager/api/manual/eager_manual/nodes/dtensor_to_local_node.cc

Lines changed: 6 additions & 3 deletions
@@ -78,12 +78,15 @@ DtensorToLocalGradNode::operator()(
     VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str);
   }
 
+  std::shared_ptr<phi::DenseTensor> grad_out_ptr =
+      std::static_pointer_cast<phi::DenseTensor>(grad_out.impl());
   // Backward call dtensor_to_local_func function
   auto dist_grad_ptr = std::make_shared<phi::distributed::DistTensor>(
-      grad_out.dims(), grad_dist_attr_);
+      grad_out_ptr,
+      out_metas[0][0].DistTensorGlobalDims(),
+      grad_process_mesh_,
+      grad_placements_);
 
-  *(dist_grad_ptr->unsafe_mutable_value()) =
-      *(static_cast<phi::DenseTensor*>(grad_out.impl().get()));
   grad_input.set_impl(dist_grad_ptr);
 
   VLOG(5) << "Finish C++ API: dtensor_to_local_func";
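The backward pass now builds the DistTensor directly from the local gradient plus the global dims, process mesh, and placements, instead of constructing it from the dist attr and overwriting its local value. Below is a rough, self-contained sketch of how a global shape follows from a local shape and placements, which is the role DistTensorGlobalDims / _cal_global_shape play here; this is a hypothetical helper with placements modeled as plain tuples, not Paddle's implementation.

# Hypothetical helper: a shard placement on mesh axis i multiplies the
# sharded tensor dim by the mesh size along axis i; replicate placements
# leave the shape unchanged.
from typing import List, Sequence, Tuple

Placement = Tuple[str, int]  # ("shard", tensor_dim) or ("replicate", -1)


def infer_global_shape(
    local_shape: Sequence[int],
    mesh_shape: Sequence[int],
    placements: Sequence[Placement],
) -> List[int]:
    global_shape = list(local_shape)
    for mesh_axis, (kind, tensor_dim) in enumerate(placements):
        if kind == "shard":
            global_shape[tensor_dim] *= mesh_shape[mesh_axis]
    return global_shape


# A [4, 4] local shard on a 2x2 mesh, sharded on tensor dim 0 along mesh
# axis 0 and replicated along mesh axis 1 -> global shape [8, 4].
print(infer_global_shape([4, 4], [2, 2], [("shard", 0), ("replicate", -1)]))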

paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h

Lines changed: 10 additions & 0 deletions
@@ -493,11 +493,21 @@ class DtensorToLocalGradNode : public egr::GradNodeBase {
     grad_dist_attr_ = dist_attr;
   }
 
+  void SetGradPlacements(const phi::distributed::Placements& placements) {
+    grad_placements_ = placements;
+  }
+
+  void SetGradProcessMesh(const phi::distributed::ProcessMesh& process_mesh) {
+    grad_process_mesh_ = process_mesh;
+  }
+
  private:
   // TensorWrappers
   egr::TensorWrapper input_;
 
   phi::distributed::TensorDistAttr grad_dist_attr_;
+  phi::distributed::Placements grad_placements_;
+  phi::distributed::ProcessMesh grad_process_mesh_;
 };
 
 class DtensorFromLocalGradNode : public egr::GradNodeBase {
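A compact Python analogue (purely illustrative, hypothetical names) of the pattern this header change enables: the forward pass stores the mesh and placements on the grad node, and the backward pass reads them back when wrapping the local gradient.

class DtensorToLocalGradSketch:
    """Toy stand-in for DtensorToLocalGradNode; not Paddle code."""

    def set_grad_process_mesh(self, process_mesh):
        self._grad_process_mesh = process_mesh

    def set_grad_placements(self, placements):
        self._grad_placements = placements

    def set_grad_global_dims(self, global_dims):
        self._grad_global_dims = global_dims

    def backward(self, local_grad):
        # Conceptually: DistTensor(local_grad, global_dims, mesh, placements).
        return (local_grad, self._grad_global_dims,
                self._grad_process_mesh, self._grad_placements)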

paddle/fluid/framework/new_executor/instruction/instruction_util.cc

Lines changed: 6 additions & 1 deletion
@@ -173,7 +173,12 @@ phi::DeviceContext* ParseDeviceContext(pir::Operation* op,
       op_name.compare(paddle::dialect::AllGatherOp::name()) == 0 ||
       op_name.compare(paddle::dialect::MpAllreduceSum_Op::name()) == 0 ||
       op_name.compare(paddle::dialect::CIdentity_Op::name()) == 0 ||
-      op_name.compare(paddle::dialect::CConcatOp::name()) == 0) {
+      op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
+      op_name.compare(paddle::dialect::CConcatOp::name()) == 0 ||
+      op_name.compare(paddle::dialect::AllGatherOp::name()) == 0 ||
+      op_name.compare(paddle::dialect::AllToAllOp::name()) == 0 ||
+      op_name.compare(
+          paddle::dialect::CSoftmaxWithCrossEntropyOp::name()) == 0) {
     if (phi::is_gpu_place(place) && execution_stream == kDefaultStream) {
       if (origin_dev_ctx != nullptr) {
         // set stream

paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc

Lines changed: 2 additions & 0 deletions
@@ -1523,6 +1523,7 @@ std::unordered_map<std::string, std::set<std::string>> GetNoNeedBufferValues(
           no_need_buffer_vars.insert(name);
         } else {
           no_need_buffer_vars.erase(name);
+          break;
         }
       }
     }
@@ -1535,6 +1536,7 @@ std::unordered_map<std::string, std::set<std::string>> GetNoNeedBufferValues(
           no_need_buffer_vars.insert(name);
         } else {
           no_need_buffer_vars.erase(name);
+          break;
        }
      }
    }
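One reading of the added break: once a use is found that does need the operand's buffer, the name is erased from the set and scanning stops, so a later no-need use of the same name cannot re-insert it. A toy sketch (not Paddle code) of that insert/erase/break pattern:

def collect_no_need_buffer(uses_by_var):
    """uses_by_var maps a variable name to a list of flags, one per use,
    where True means that use needs the buffer."""
    no_need_buffer_vars = set()
    for name, uses in uses_by_var.items():
        for needs_buffer in uses:
            if not needs_buffer:
                no_need_buffer_vars.add(name)
            else:
                no_need_buffer_vars.discard(name)
                break  # a single buffer-needing use is decisive


    return no_need_buffer_vars


# "x" has a no-need use, then a use that needs the buffer, then another
# no-need use; without the break, "x" would be re-added to the set.
print(collect_no_need_buffer({"x": [False, True, False], "y": [False]}))
# -> {'y'}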

paddle/fluid/pir/dialect/distributed/ir/dist_op.h

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ class DtensorToLocalOp
   // void VerifySig();
 };
 
-class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp> {
+class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp, VjpInterface> {
  public:
   using Op::Op;
   static const char* name() { return "dist_op.moe_sub_mesh_tensors"; }

python/paddle/distributed/auto_parallel/api.py

Lines changed: 11 additions & 8 deletions
@@ -495,6 +495,7 @@ def _cal_global_shape(local_shape, mesh, placements):
 def moe_global_mesh_tensor(
     local_tensor_list, mesh, placements, local_mesh_dim=-1
 ):
+    placements = copy.deepcopy(placements)
     local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         mesh, placements, local_mesh_dim
     )
@@ -548,16 +549,17 @@ def moe_global_mesh_tensor(
         global_dims = _cal_global_shape(
             local_tensor._local_shape, mesh, placements
         )
-        return paddle.jit.dy2static.py_layer.StaticPyLayer(
-            _moe_global_mesh_tensor
-        ).apply(
+        dist_tensor = paddle._C_ops.moe_global_mesh_tensor(
             local_tensor_list,
             local_mesh_list,
             local_placements,
             mesh,
             placements,
             global_dims,
         )
+        dist_tensor.stop_gradient = local_tensor_list[0].stop_gradient
+        dist_tensor.persistable = local_tensor_list[0].persistable
+        return dist_tensor
     else:
         raise NotImplementedError(
             "dtensor_from_local_list() are only supported in dynamic and pir mode."
@@ -691,6 +693,7 @@ def moe_sub_mesh_tensors(
     """
     Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
     """
+    global_placements = copy.deepcopy(global_placements)
     local_mesh_list, local_placements = _get_sub_meshes_and_local_placements(
         global_mesh, global_placements, local_mesh_dim
     )
@@ -705,17 +708,17 @@ def moe_sub_mesh_tensors(
             global_placements,
         )
     elif paddle.framework.in_pir_mode():
-
-        return paddle.jit.dy2static.py_layer.StaticPyLayer(
-            _moe_sub_mesh_tensors
-        ).apply(
+        local_tensors = paddle._C_ops.moe_sub_mesh_tensors(
             dist_tensor,
             local_mesh_list,
             local_placements,
-            local_mesh_dim,
             global_mesh,
             global_placements,
         )
+        for local_tensor in local_tensors:
+            local_tensor.stop_gradient = dist_tensor.stop_gradient
+            local_tensor.persistable = dist_tensor.persistable
+        return local_tensors
     else:
         raise NotImplementedError(
             "moe_sub_mesh_tensors is only supported in dynamic mode."

python/paddle/distributed/auto_parallel/placement_type.py

Lines changed: 4 additions & 2 deletions
@@ -86,8 +86,10 @@ def to_dim_map(placements, tensor_dims):
         if placement.is_shard():
             shard_dim = cast(Shard, placement).get_dim()
             if dim_map[shard_dim] > -1:
-                raise Exception(
-                    "Tensor dim {shard_dim} is already sharded on mesh dim {dim_map[shard_dim]}"
+                import logging
+
+                logging.warning(
+                    f"Tensor dim {shard_dim} is already sharded on mesh dim {dim_map[shard_dim]}."
                 )
 
             dim_map[shard_dim] = i
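A simplified, self-contained version of the idea behind to_dim_map after this change: a tensor dim that is already sharded on some mesh dim now triggers a warning and is overwritten instead of raising. This is a sketch, not Paddle's exact implementation:

import logging


class Shard:
    def __init__(self, dim):
        self._dim = dim

    def is_shard(self):
        return True

    def get_dim(self):
        return self._dim


def to_dim_map(placements, tensor_ndim):
    """Map each tensor dim to the mesh dim it is sharded on (-1 = not sharded)."""
    dim_map = [-1] * tensor_ndim
    for i, placement in enumerate(placements):
        if placement.is_shard():
            shard_dim = placement.get_dim()
            if dim_map[shard_dim] > -1:
                # After this commit: warn instead of raising.
                logging.warning(
                    f"Tensor dim {shard_dim} is already sharded on "
                    f"mesh dim {dim_map[shard_dim]}."
                )
            dim_map[shard_dim] = i
    return dim_map


# Tensor dim 0 sharded on mesh dim 0 and again on mesh dim 1: the second
# assignment now just logs a warning and overwrites the mapping.
print(to_dim_map([Shard(0), Shard(0)], tensor_ndim=2))  # [1, -1]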

test/auto_parallel/pir/test_moe_api.py

Lines changed: 17 additions & 8 deletions
@@ -127,14 +127,23 @@ def check_results(
         local_meshes,
         local_dims_mapping,
     ):
-        # local_tensors_from_dtensor op
-        self.check_dist_attr(ops[4], local_meshes, local_dims_mapping)
-        # dtensor_from_local_list op
-        self.check_dist_attr(ops[5], [global_mesh], global_dims_mapping)
-        # grad op for dtensor_from_local_list
-        self.check_dist_attr(ops[10], local_meshes, local_dims_mapping)
-        # grad op for local_tensors_from_dtensor op
-        self.check_dist_attr(ops[11], [global_mesh], global_dims_mapping)
+        op_names = [
+            "dist_op.moe_sub_mesh_tensors",
+            "dist_op.moe_global_mesh_tensor",
+        ]
+        ops_to_check = [op for op in ops if op.name() in op_names]
+        # moe_sub_mesh_tensors op
+        self.check_dist_attr(ops_to_check[0], local_meshes, local_dims_mapping)
+        # moe_global_mesh_tensor op
+        self.check_dist_attr(
+            ops_to_check[1], [global_mesh], global_dims_mapping
+        )
+        # grad op for moe_global_mesh_tensor
+        self.check_dist_attr(ops_to_check[2], local_meshes, local_dims_mapping)
+        # grad op for moe_sub_mesh_tensors op
+        self.check_dist_attr(
+            ops_to_check[3], [global_mesh], global_dims_mapping
+        )
 
 
 if __name__ == "__main__":
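The test now selects the ops to check by name rather than by hard-coded index, so it stays valid when unrelated ops are added to or removed from the program. A generic sketch of that lookup pattern with stand-in op objects (the real test gets `ops` from the PIR program, and each op exposes .name()):

class FakeOp:
    """Stand-in for a PIR op exposing .name(); illustration only."""

    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


op_names = [
    "dist_op.moe_sub_mesh_tensors",
    "dist_op.moe_global_mesh_tensor",
]
ops = [FakeOp(n) for n in (
    "builtin.parameter",              # unrelated op
    "dist_op.moe_sub_mesh_tensors",
    "pd_op.full",                     # unrelated op
    "dist_op.moe_global_mesh_tensor",
)]

# Indexing into the filtered list is robust to unrelated ops shifting position.
ops_to_check = [op for op in ops if op.name() in op_names]
assert [op.name() for op in ops_to_check] == op_names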

0 commit comments