Skip to content

Commit f04aa61

Browse files
authored
Fix dtensor_to_local backward bugs (#71232)
* Fix dtensor_to_local backward bugs * Fix CI errors
1 parent 4135912 commit f04aa61

File tree

21 files changed

+205
-98
lines changed

21 files changed

+205
-98
lines changed

paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ paddle::Tensor reshard_ad_function(
5757
const paddle::Tensor& tensor,
5858
const phi::distributed::TensorDistAttr dist_attr);
5959

60-
paddle::Tensor dtensor_to_local_ad_function(const paddle::Tensor& input);
60+
paddle::Tensor dtensor_to_local_ad_function(
61+
const paddle::Tensor& input,
62+
const phi::distributed::ProcessMesh& processmesh,
63+
const phi::distributed::Placements& placements);
6164

6265
paddle::Tensor dtensor_from_local_ad_function(
6366
const paddle::Tensor& input,

paddle/fluid/eager/api/manual/eager_manual/forwards/dtensor_to_local_fwd_func.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
#include "paddle/fluid/eager/api/utils/global_utils.h"
1919
#include "paddle/phi/core/platform/profiler/event_tracing.h"
2020

21-
paddle::Tensor dtensor_to_local_ad_function(const paddle::Tensor& input) {
21+
paddle::Tensor dtensor_to_local_ad_function(
22+
const paddle::Tensor& input,
23+
const phi::distributed::ProcessMesh& process_mesh,
24+
const phi::distributed::Placements& placements) {
2225
#ifdef PADDLE_WITH_DISTRIBUTE
2326
VLOG(3) << "Running AD API: "
2427
<< "dtensor_to_local dygraph";
@@ -49,6 +52,11 @@ paddle::Tensor dtensor_to_local_ad_function(const paddle::Tensor& input) {
4952

5053
// Set TensorWrappers for Forward Inputs if needed
5154
grad_node->SetTensorWrapperNoNeedBuffer_Input(input);
55+
56+
phi::distributed::TensorDistAttr grad_dist_attr =
57+
ToTensorDistAttr(process_mesh, placements, input.dims());
58+
59+
grad_node->SetGradDistAttr(grad_dist_attr);
5260
}
5361

5462
// Forward API Call

paddle/fluid/eager/api/manual/eager_manual/nodes/dtensor_to_local_node.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,7 @@ DtensorToLocalGradNode::operator()(
4747

4848
// Collect GradIn Tensors, Attrs and Recovered TensorWrappers
4949
auto input = egr::EagerUtils::RecoverTensorWrapper(&this->input_);
50-
const auto& dist_attr =
51-
std::static_pointer_cast<phi::distributed::DistTensor>(input.impl())
52-
->dist_attr();
50+
5351
auto& grad_out = hooked_grad[0][0];
5452
// Prepare Grad function call
5553

@@ -82,7 +80,7 @@ DtensorToLocalGradNode::operator()(
8280

8381
// Backward call dtensor_to_local_func function
8482
auto dist_grad_ptr = std::make_shared<phi::distributed::DistTensor>(
85-
grad_out.dims(), dist_attr);
83+
grad_out.dims(), grad_dist_attr_);
8684

8785
*(dist_grad_ptr->unsafe_mutable_value()) =
8886
*(static_cast<phi::DenseTensor*>(grad_out.impl().get()));

paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -489,9 +489,15 @@ class DtensorToLocalGradNode : public egr::GradNodeBase {
489489
input_ = egr::TensorWrapper(input, true);
490490
}
491491

492+
void SetGradDistAttr(const phi::distributed::TensorDistAttr& dist_attr) {
493+
grad_dist_attr_ = dist_attr;
494+
}
495+
492496
private:
493497
// TensorWrappers
494498
egr::TensorWrapper input_;
499+
500+
phi::distributed::TensorDistAttr grad_dist_attr_;
495501
};
496502

497503
class DtensorFromLocalGradNode : public egr::GradNodeBase {

paddle/fluid/pir/dialect/distributed/ir/dist_api.cc

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,23 @@ pir::Value dtensor_from_local(const pir::Value& x,
8585
.result(0);
8686
}
8787

88-
pir::Value dtensor_to_local(const pir::Value& x) {
89-
return ApiBuilder::Instance().GetBuilder()->Build<DtensorToLocalOp>(x).result(
90-
0);
88+
pir::Value dtensor_to_local(
89+
const pir::Value& x,
90+
const phi::distributed::ProcessMesh& process_mesh,
91+
const std::vector<int64_t>& dims_mapping,
92+
const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
93+
pir::IrContext* ctx = pir::IrContext::Instance();
94+
TensorDistAttribute grad_dist_attr =
95+
TensorDistAttribute::get(ctx, process_mesh, dims_mapping, partial_status);
96+
return dtensor_to_local(x, grad_dist_attr);
97+
}
98+
99+
pir::Value dtensor_to_local(const pir::Value& x,
100+
const TensorDistAttribute& grad_dist_attr) {
101+
return ApiBuilder::Instance()
102+
.GetBuilder()
103+
->Build<DtensorToLocalOp>(x, grad_dist_attr)
104+
.result(0);
91105
}
92106

93107
std::vector<pir::Value> moe_sub_mesh_tensors(

paddle/fluid/pir/dialect/distributed/ir/dist_api.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@ pir::Value dtensor_from_local(
5050
pir::Value dtensor_from_local(const pir::Value& x,
5151
const TensorDistAttribute& tensor_dist_attr);
5252

53-
pir::Value dtensor_to_local(const pir::Value& x);
53+
pir::Value dtensor_to_local(
54+
const pir::Value& x,
55+
const phi::distributed::ProcessMesh& process_mesh,
56+
const std::vector<int64_t>& dims_mapping,
57+
const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {});
58+
pir::Value dtensor_to_local(const pir::Value& x,
59+
const TensorDistAttribute& grad_dist_attr);
5460

5561
std::vector<pir::Value> moe_sub_mesh_tensors(
5662
const pir::Value& input,

paddle/fluid/pir/dialect/distributed/ir/dist_op.cc

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,8 @@ std::vector<std::vector<pir::Value>> DtensorFromLocalOp::Vjp(
413413
->result(0);
414414
}
415415

416-
auto grad_op = builder.Build<DtensorToLocalOp>(out_grad);
416+
auto grad_op = builder.Build<DtensorToLocalOp>(
417+
out_grad, dist_type.tensor_dist_attr() /*unused*/);
417418

418419
VLOG(6) << "End call vjp for dtensor_from_local op.";
419420

@@ -422,13 +423,23 @@ std::vector<std::vector<pir::Value>> DtensorFromLocalOp::Vjp(
422423

423424
void DtensorToLocalOp::Build(pir::Builder& builder,
424425
pir::OperationArgument& argument,
425-
pir::Value input) {
426+
pir::Value input,
427+
TensorDistAttribute grad_dist_attr) {
426428
VLOG(4) << "Start build DtensorToLocalOp";
429+
paddle::dialect::DistDenseTensorType input_tensor_type;
430+
if (input.type().isa<paddle::dialect::DistDenseTensorType>()) {
431+
input_tensor_type =
432+
input.type().dyn_cast<paddle::dialect::DistDenseTensorType>();
433+
} else {
434+
PADDLE_THROW(common::errors::Unimplemented(
435+
"Only support paddle::dialect::DistDenseTensorType"));
436+
}
427437

428438
VLOG(4) << "Builder construction inputs";
429439
argument.AddInput(input);
430440

431441
VLOG(4) << "Builder construction attributes";
442+
argument.AddAttribute("grad_dist_attr", grad_dist_attr);
432443

433444
VLOG(4) << "Builder construction outputs";
434445

@@ -494,9 +505,11 @@ std::vector<std::vector<pir::Value>> DtensorToLocalOp::Vjp(
494505
"dtensor_from_local op's outputs grad[0] size should be 1"));
495506

496507
auto& builder = *ApiBuilder::Instance().GetBuilder();
508+
const auto& grad_dist_attr =
509+
op->attribute<paddle::dialect::TensorDistAttribute>("grad_dist_attr");
497510

498-
auto grad_op = builder.Build<DtensorFromLocalOp>(
499-
out_grads[0][0], dist_type.tensor_dist_attr());
511+
auto grad_op =
512+
builder.Build<DtensorFromLocalOp>(out_grads[0][0], grad_dist_attr);
500513

501514
VLOG(6) << "End call vjp for dtensor_from_local op.";
502515

paddle/fluid/pir/dialect/distributed/ir/dist_op.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class DtensorToLocalOp
9494
static constexpr uint32_t attributes_num = 0;
9595
TEST_API static void Build(pir::Builder& builder, // NOLINT
9696
pir::OperationArgument& argument, // NOLINT
97-
pir::Value input);
97+
pir::Value input,
98+
TensorDistAttribute grad_dist_attr);
9899

99100
static OpInfoTuple GetOpInfo();
100101
static std::vector<std::vector<pir::Value>> Vjp(

paddle/fluid/pir/dialect/operator/utils/utils.cc

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
#include "paddle/common/errors.h"
2020
#include "paddle/fluid/framework/phi_utils.h"
21+
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
2122
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
2223
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
2324
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
@@ -57,43 +58,43 @@ enum class AttrType {
5758
DOUBLE,
5859

5960
ARRAY,
61+
STRING,
62+
TENSOR_NAME,
63+
DATA_TYPE,
6064
INT_ARRAY,
65+
PLACE,
66+
TensorDist,
6167

6268
SCALAR,
63-
DATA_TYPE,
6469
DATA_LAYOUT,
65-
PLACE,
66-
67-
STRING,
68-
69-
TENSOR_NAME,
70-
7170
NUM_ATTR_TYPES,
7271
};
7372

7473
static inline AttrType GetAttributeType(const pir::Attribute& attr) {
7574
if (attr.isa<pir::BoolAttribute>()) {
7675
return AttrType::BOOL;
77-
} else if (attr.isa<pir::FloatAttribute>()) {
78-
return AttrType::FLOAT;
79-
} else if (attr.isa<pir::DoubleAttribute>()) {
80-
return AttrType::DOUBLE;
8176
} else if (attr.isa<pir::Int32Attribute>()) {
8277
return AttrType::INT32;
8378
} else if (attr.isa<pir::Int64Attribute>()) {
8479
return AttrType::INT64;
80+
} else if (attr.isa<pir::FloatAttribute>()) {
81+
return AttrType::FLOAT;
82+
} else if (attr.isa<pir::DoubleAttribute>()) {
83+
return AttrType::DOUBLE;
8584
} else if (attr.isa<pir::ArrayAttribute>()) {
8685
return AttrType::ARRAY;
8786
} else if (attr.isa<pir::StrAttribute>()) {
8887
return AttrType::STRING;
89-
} else if (attr.isa<paddle::dialect::IntArrayAttribute>()) {
90-
return AttrType::INT_ARRAY;
88+
} else if (attr.isa<pir::TensorNameAttribute>()) {
89+
return AttrType::TENSOR_NAME;
9190
} else if (attr.isa<paddle::dialect::DataTypeAttribute>()) {
9291
return AttrType::DATA_TYPE;
92+
} else if (attr.isa<paddle::dialect::IntArrayAttribute>()) {
93+
return AttrType::INT_ARRAY;
9394
} else if (attr.isa<paddle::dialect::PlaceAttribute>()) {
9495
return AttrType::PLACE;
95-
} else if (attr.isa<pir::TensorNameAttribute>()) {
96-
return AttrType::TENSOR_NAME;
96+
} else if (attr.isa<paddle::dialect::TensorDistAttribute>()) {
97+
return AttrType::TensorDist;
9798
} else {
9899
PADDLE_THROW(common::errors::Unimplemented(
99100
"Unsupported ir Attribute type when casting it into "
@@ -110,14 +111,6 @@ static std::function<T(const pir::Attribute& attr)> GetAttrCast(
110111
[](const pir::Attribute& attr) {
111112
return T{attr.dyn_cast<pir::BoolAttribute>().data()};
112113
}},
113-
{AttrType::FLOAT,
114-
[](const pir::Attribute& attr) {
115-
return T{attr.dyn_cast<pir::FloatAttribute>().data()};
116-
}},
117-
{AttrType::DOUBLE,
118-
[](const pir::Attribute& attr) {
119-
return T{attr.dyn_cast<pir::DoubleAttribute>().data()};
120-
}},
121114
{AttrType::INT32,
122115
[](const pir::Attribute& attr) {
123116
return T{attr.dyn_cast<pir::Int32Attribute>().data()};
@@ -126,28 +119,13 @@ static std::function<T(const pir::Attribute& attr)> GetAttrCast(
126119
[](const pir::Attribute& attr) {
127120
return T{attr.dyn_cast<pir::Int64Attribute>().data()};
128121
}},
129-
{AttrType::INT_ARRAY,
130-
[](const pir::Attribute& attr) {
131-
return T{attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
132-
.data()
133-
.GetData()};
134-
}},
135-
{AttrType::STRING,
136-
[](const pir::Attribute& attr) {
137-
return T{attr.dyn_cast<pir::StrAttribute>().AsString()};
138-
}},
139-
{AttrType::DATA_TYPE,
140-
[](const pir::Attribute& attr) {
141-
return T{
142-
attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data()};
143-
}},
144-
{AttrType::PLACE,
122+
{AttrType::FLOAT,
145123
[](const pir::Attribute& attr) {
146-
return T{attr.dyn_cast<paddle::dialect::PlaceAttribute>().data()};
124+
return T{attr.dyn_cast<pir::FloatAttribute>().data()};
147125
}},
148-
{AttrType::TENSOR_NAME,
126+
{AttrType::DOUBLE,
149127
[](const pir::Attribute& attr) {
150-
return T{attr.dyn_cast<pir::TensorNameAttribute>().data()};
128+
return T{attr.dyn_cast<pir::DoubleAttribute>().data()};
151129
}},
152130
{AttrType::ARRAY,
153131
[](const pir::Attribute& attr) {
@@ -211,7 +189,33 @@ static std::function<T(const pir::Attribute& attr)> GetAttrCast(
211189
"vector."));
212190
}
213191
}},
214-
};
192+
{AttrType::STRING,
193+
[](const pir::Attribute& attr) {
194+
return T{attr.dyn_cast<pir::StrAttribute>().AsString()};
195+
}},
196+
197+
{AttrType::TENSOR_NAME,
198+
[](const pir::Attribute& attr) {
199+
return T{attr.dyn_cast<pir::TensorNameAttribute>().data()};
200+
}},
201+
{AttrType::DATA_TYPE,
202+
[](const pir::Attribute& attr) {
203+
return T{
204+
attr.dyn_cast<paddle::dialect::DataTypeAttribute>().data()};
205+
}},
206+
{AttrType::INT_ARRAY,
207+
[](const pir::Attribute& attr) {
208+
return T{attr.dyn_cast<paddle::dialect::IntArrayAttribute>()
209+
.data()
210+
.GetData()};
211+
}},
212+
{AttrType::PLACE,
213+
[](const pir::Attribute& attr) {
214+
return T{attr.dyn_cast<paddle::dialect::PlaceAttribute>().data()};
215+
}},
216+
{AttrType::TensorDist, [](const pir::Attribute& attr) {
217+
return T{attr.dyn_cast<paddle::dialect::TensorDistAttribute>()};
218+
}}};
215219
return kAttrCastMap[attr_type];
216220
}
217221

paddle/fluid/pybind/auto_parallel_py.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -803,9 +803,13 @@ void BindAutoParallel(py::module *m) {
803803

804804
m->def(
805805
"dtensor_to_local",
806-
[](py::handle py_tensor) {
806+
[](py::handle py_tensor,
807+
py::handle py_process_mesh,
808+
py::handle py_placements) {
807809
auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
808-
return dtensor_to_local_ad_function(tensor);
810+
auto process_mesh = CastPyArg2ProcessMesh(py_process_mesh.ptr(), 1);
811+
auto placements = CastPyArg2VectorOfPlacement(py_placements.ptr(), 2);
812+
return dtensor_to_local_ad_function(tensor, process_mesh, placements);
809813
},
810814
py::return_value_policy::reference);
811815

0 commit comments

Comments
 (0)