Skip to content

Commit a882ae8

Browse files
committed
Support creating the compute graph in local view.
1 parent 9015a4d commit a882ae8

File tree

11 files changed

+409
-20
lines changed

11 files changed

+409
-20
lines changed

paddle/fluid/pir/dialect/distributed/ir/dist_api.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,30 @@ pir::Value reshard(const pir::Value& x,
6363
return reshard_op.result(0);
6464
}
6565

66+
// Wraps a local tensor value as a distributed tensor: assembles a
// TensorDistAttribute from the mesh / dims_mapping / partial_status
// description, then delegates to the attribute-based overload.
pir::Value dtensor_from_local(
    const pir::Value& x,
    const phi::distributed::ProcessMesh& process_mesh,
    const std::vector<int64_t>& dims_mapping,
    const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
  auto* ir_ctx = pir::IrContext::Instance();
  const auto dist_attr = TensorDistAttribute::get(
      ir_ctx, process_mesh, dims_mapping, partial_status);
  return dtensor_from_local(x, dist_attr);
}
76+
77+
// Attribute-based overload: emits a DtensorFromLocalOp via the global
// API builder and returns its single result.
pir::Value dtensor_from_local(const pir::Value& x,
                              const TensorDistAttribute& tensor_dist_attr) {
  auto& builder = *ApiBuilder::Instance().GetBuilder();
  auto op = builder.Build<DtensorFromLocalOp>(x, tensor_dist_attr);
  return op.result(0);
}
84+
85+
// Strips the distributed wrapper from `x` by emitting a DtensorToLocalOp;
// returns the op's single (local-view) result.
pir::Value dtensor_to_local(const pir::Value& x) {
  auto op = ApiBuilder::Instance().GetBuilder()->Build<DtensorToLocalOp>(x);
  return op.result(0);
}
89+
6690
std::vector<pir::Value> moe_sub_mesh_tensors(
6791
const pir::Value& input,
6892
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,

paddle/fluid/pir/dialect/distributed/ir/dist_api.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ pir::Value reshard(
4141
pir::Value reshard(const pir::Value& x,
4242
const TensorDistAttribute& tensor_dist_attr);
4343

44+
/// Wraps local tensor `x` as a distributed tensor placed on `process_mesh`
/// with the given `dims_mapping` and `partial_status`; the three pieces are
/// packed into a TensorDistAttribute internally.
pir::Value dtensor_from_local(
    const pir::Value& x,
    const phi::distributed::ProcessMesh& process_mesh,
    const std::vector<int64_t>& dims_mapping,
    const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {});
/// Overload taking a fully-formed dist attribute.
pir::Value dtensor_from_local(const pir::Value& x,
                              const TensorDistAttribute& tensor_dist_attr);

/// Strips the distributed wrapper from `x`, yielding its local tensor value.
pir::Value dtensor_to_local(const pir::Value& x);
53+
4454
std::vector<pir::Value> moe_sub_mesh_tensors(
4555
const pir::Value& input,
4656
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,

paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ void DistDialect::initialize() {
3636
RegisterTypes<DistDenseTensorType>();
3737
RegisterOps<ShardTensorOp,
3838
ReshardOp,
39+
DtensorFromLocalOp,
40+
DtensorToLocalOp,
3941
MoESubMeshTensorsOp,
4042
MoEGlobalMeshTensorOp>();
4143
}

paddle/fluid/pir/dialect/distributed/ir/dist_op.cc

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,198 @@ void ReshardOp::Build(pir::Builder& builder,
326326
::pir::PassStopGradientsDefaultly(argument);
327327
}
328328

329+
// Builds a DtensorFromLocalOp: takes a dense local tensor and produces a
// DistDenseTensorType whose global dims are inferred from the local dims
// plus `tensor_dist_attr`, and whose local dims equal the input's dims.
void DtensorFromLocalOp::Build(pir::Builder& builder,
                               pir::OperationArgument& argument,
                               pir::Value input,
                               TensorDistAttribute tensor_dist_attr) {
  VLOG(4) << "Start build DtensorFromLocalOp";

  // Only dense local tensors are supported as input.
  auto local_type = input.type().dyn_cast<paddle::dialect::DenseTensorType>();
  if (!local_type) {
    PADDLE_THROW(common::errors::Unimplemented(
        "Only support paddle::dialect::DenseTensorType"));
  }

  VLOG(4) << "Builder construction inputs";
  argument.AddInput(input);

  VLOG(4) << "Builder construction attributes";

  VLOG(4) << "Builder construction outputs";

  // Global shape = local shape scaled up along every sharded axis.
  const auto global_dims = InferGlobalDDim(local_type.dims(), tensor_dist_attr);
  auto* ctx = pir::IrContext::Instance();
  auto global_dense_type = dialect::DenseTensorType::get(ctx,
                                                         local_type.dtype(),
                                                         global_dims,
                                                         local_type.data_layout(),
                                                         local_type.lod(),
                                                         local_type.offset());

  // Result type carries the global dense type, the dist attribute, and the
  // input's dims as the local shape.
  auto out_type = paddle::dialect::DistDenseTensorType::get(
      ctx, global_dense_type, tensor_dist_attr, local_type.dims());
  argument.AddOutput(out_type);
  ::pir::PassStopGradientsDefaultly(argument);
}
369+
370+
// Yaml-info for dtensor_from_local: one anonymous input, no attributes,
// one anonymous output, default-constructed runtime info.
OpInfoTuple DtensorFromLocalOp::GetOpInfo() {
  std::vector<OpInputInfo> input_infos{OpInputInfo()};
  std::vector<OpOutputInfo> output_infos{OpOutputInfo()};
  return OpInfoTuple(
      input_infos, {}, output_infos, OpRunTimeInfo(), "dtensor_from_local");
}
377+
// Vjp of dtensor_from_local: reshard the output grad onto the forward
// output's placement if needed, then strip it back to a local tensor.
std::vector<std::vector<pir::Value>> DtensorFromLocalOp::Vjp(
    pir::Operation* op,
    const std::vector<std::vector<pir::Value>>& inputs,
    const std::vector<std::vector<pir::Value>>& outputs,
    const std::vector<std::vector<pir::Value>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients) {
  VLOG(6) << "Start call vjp for dtensor_from_local op.";

  // The op is single-input / single-output; check every nesting level
  // before indexing [0][0].
  PADDLE_ENFORCE_EQ(inputs.size(),
                    1,
                    common::errors::InvalidArgument(
                        "dtensor_from_local op's inputs' size should be 1"));
  PADDLE_ENFORCE_EQ(
      inputs[0].size(),
      1,
      common::errors::InvalidArgument(
          "dtensor_from_local op's inputs[0]'s size should be 1"));
  PADDLE_ENFORCE_EQ(outputs.size(),
                    1,
                    common::errors::InvalidArgument(
                        "dtensor_from_local op's outputs' size should be 1"));
  PADDLE_ENFORCE_EQ(
      outputs[0].size(),
      1,
      common::errors::InvalidArgument(
          "dtensor_from_local op's outputs[0]'s size should be 1"));

  auto forward_out_dist_type =
      outputs[0][0].type().dyn_cast<DistTypeInterface>();
  PADDLE_ENFORCE_NOT_NULL(
      forward_out_dist_type,
      common::errors::InvalidArgument("Currently, dtensor_from_local op's "
                                      "outputs type must be dist type."));

  PADDLE_ENFORCE_EQ(
      out_grads.size(),
      1,
      common::errors::InvalidArgument(
          "dtensor_from_local op's outputs grad size should be 1"));
  PADDLE_ENFORCE_EQ(
      out_grads[0].size(),
      1,
      common::errors::InvalidArgument(
          "dtensor_from_local op's outputs grad[0] size should be 1"));

  auto& ir_builder = *ApiBuilder::Instance().GetBuilder();
  pir::Value grad_value = out_grads[0][0];

  // If the incoming grad's type differs from the forward output's, reshard
  // it first so the local view below is taken on the right placement.
  if (grad_value.type() != outputs[0][0].type()) {
    grad_value =
        ir_builder
            .Build<ReshardOp>(grad_value,
                              forward_out_dist_type.tensor_dist_attr())
            ->result(0);
  }

  // d(local input) = local view of the (possibly resharded) output grad.
  auto to_local_op = ir_builder.Build<DtensorToLocalOp>(grad_value);

  VLOG(6) << "End call vjp for dtensor_from_local op.";

  return {std::vector<pir::Value>{to_local_op->result(0)}};
}
437+
438+
// Builds a DtensorToLocalOp: the input must carry a dist type, and the
// single output is that dist type's dense local view.
void DtensorToLocalOp::Build(pir::Builder& builder,
                             pir::OperationArgument& argument,
                             pir::Value input) {
  VLOG(4) << "Start build DtensorToLocalOp";

  VLOG(4) << "Builder construction inputs";
  argument.AddInput(input);

  VLOG(4) << "Builder construction attributes";

  VLOG(4) << "Builder construction outputs";

  auto dist_input_type = input.type().dyn_cast<DistTypeInterface>();
  if (!dist_input_type) {
    PADDLE_THROW(common::errors::Unimplemented(
        "The input of DtensorToLocalOp must be dist type."));
  }

  // Output type is the local (per-rank) tensor type of the dist input.
  argument.AddOutput(dist_input_type.local_type());
  ::pir::PassStopGradientsDefaultly(argument);
}
459+
460+
// Yaml-info for dtensor_to_local: one anonymous input, no attributes,
// one anonymous output, default-constructed runtime info.
OpInfoTuple DtensorToLocalOp::GetOpInfo() {
  std::vector<OpInputInfo> input_infos{OpInputInfo()};
  std::vector<OpOutputInfo> output_infos{OpOutputInfo()};
  return OpInfoTuple(
      input_infos, {}, output_infos, OpRunTimeInfo(), "dtensor_to_local");
}
467+
468+
// Vjp of dtensor_to_local: the gradient of "strip dist info" is
// "re-attach the same dist info" — wrap the incoming local grad with the
// forward input's dist attribute via DtensorFromLocalOp.
std::vector<std::vector<pir::Value>> DtensorToLocalOp::Vjp(
    pir::Operation* op,
    const std::vector<std::vector<pir::Value>>& inputs,
    const std::vector<std::vector<pir::Value>>& outputs,
    const std::vector<std::vector<pir::Value>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients) {
  VLOG(6) << "Start call vjp for dtensor_to_local op.";

  // The op is single-input / single-output; check every nesting level
  // before indexing [0][0].
  PADDLE_ENFORCE_EQ(inputs.size(),
                    1,
                    common::errors::InvalidArgument(
                        "dtensor_to_local op's inputs' size should be 1"));
  PADDLE_ENFORCE_EQ(inputs[0].size(),
                    1,
                    common::errors::InvalidArgument(
                        "dtensor_to_local op's inputs[0]'s size should be 1"));

  PADDLE_ENFORCE_EQ(outputs.size(),
                    1,
                    common::errors::InvalidArgument(
                        "dtensor_to_local op's outputs' size should be 1"));
  PADDLE_ENFORCE_EQ(outputs[0].size(),
                    1,
                    common::errors::InvalidArgument(
                        "dtensor_to_local op's outputs[0]'s size should be 1"));

  auto dist_type = inputs[0][0].type().dyn_cast<DistTypeInterface>();
  PADDLE_ENFORCE_NOT_NULL(
      dist_type,
      common::errors::InvalidArgument(
          "Currently, dtensor_to_local op's inputs type must be dist type."));

  // Fixed: these diagnostics previously said "dtensor_from_local"
  // (copy-pasted from DtensorFromLocalOp::Vjp).
  PADDLE_ENFORCE_EQ(out_grads.size(),
                    1,
                    common::errors::InvalidArgument(
                        "dtensor_to_local op's outputs grad size should be 1"));
  PADDLE_ENFORCE_EQ(
      out_grads[0].size(),
      1,
      common::errors::InvalidArgument(
          "dtensor_to_local op's outputs grad[0] size should be 1"));

  auto& builder = *ApiBuilder::Instance().GetBuilder();

  auto grad_op = builder.Build<DtensorFromLocalOp>(
      out_grads[0][0], dist_type.tensor_dist_attr());

  VLOG(6) << "End call vjp for dtensor_to_local op.";

  return {std::vector<pir::Value>{grad_op->result(0)}};
}
520+
329521
TEST_API void paddle::dialect::MoESubMeshTensorsOp::Build(
330522
pir::Builder& builder,
331523
pir::OperationArgument& argument,
@@ -699,5 +891,7 @@ std::vector<std::vector<pir::Value>> MoEGlobalMeshTensorOp::Vjp(
699891

700892
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp)
701893
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ReshardOp)
894+
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorFromLocalOp)
895+
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorToLocalOp)
702896
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::MoESubMeshTensorsOp)
703897
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::MoEGlobalMeshTensorOp)

paddle/fluid/pir/dialect/distributed/ir/dist_op.h

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,50 @@ class ReshardOp : public pir::Op<ReshardOp, VjpInterface, OpYamlInfoInterface> {
6363
void VerifySig();
6464
};
6565

66-
class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp, VjpInterface> {
66+
/// Op that wraps a dense local tensor as a distributed tensor described by
/// a TensorDistAttribute. Single input, single output, no attributes.
class DtensorFromLocalOp
    : public pir::Op<DtensorFromLocalOp, VjpInterface, OpYamlInfoInterface> {
 public:
  using Op::Op;
  static const char* name() { return "dist_op.dtensor_from_local"; }
  // The op carries no attributes.
  static constexpr const char** attributes_name = nullptr;
  static constexpr uint32_t attributes_num = 0;
  /// Infers the global shape from `input`'s dims and `tensor_dist_attr`
  /// and adds a DistDenseTensorType output.
  TEST_API static void Build(pir::Builder& builder,  // NOLINT
                             pir::OperationArgument& argument,  // NOLINT
                             pir::Value input,
                             TensorDistAttribute tensor_dist_attr);

  static OpInfoTuple GetOpInfo();
  /// Gradient: reshard the output grad to the forward output's placement
  /// (if needed), then convert it back to a local tensor.
  static std::vector<std::vector<pir::Value>> Vjp(
      pir::Operation* op,
      const std::vector<std::vector<pir::Value>>& inputs_,
      const std::vector<std::vector<pir::Value>>& outputs,
      const std::vector<std::vector<pir::Value>>& out_grads,
      const std::vector<std::vector<bool>>& stop_gradients);
};
86+
87+
/// Op that strips the distributed wrapper from a dist tensor, producing
/// its dense local view. Single input, single output, no attributes.
class DtensorToLocalOp
    : public pir::Op<DtensorToLocalOp, VjpInterface, OpYamlInfoInterface> {
 public:
  using Op::Op;
  static const char* name() { return "dist_op.dtensor_to_local"; }
  // The op carries no attributes.
  static constexpr const char** attributes_name = nullptr;
  static constexpr uint32_t attributes_num = 0;
  /// Requires `input` to carry a dist type; the output is that type's
  /// local (per-rank) tensor type.
  TEST_API static void Build(pir::Builder& builder,  // NOLINT
                             pir::OperationArgument& argument,  // NOLINT
                             pir::Value input);

  static OpInfoTuple GetOpInfo();
  /// Gradient: re-wrap the local grad with the forward input's dist
  /// attribute via DtensorFromLocalOp.
  static std::vector<std::vector<pir::Value>> Vjp(
      pir::Operation* op,
      const std::vector<std::vector<pir::Value>>& inputs_,
      const std::vector<std::vector<pir::Value>>& outputs,
      const std::vector<std::vector<pir::Value>>& out_grads,
      const std::vector<std::vector<bool>>& stop_gradients);

  // TODO(review): verification is not yet implemented for this op.
  // void VerifySig();
};
108+
109+
class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp> {
67110
public:
68111
using Op::Op;
69112
static const char* name() { return "dist_op.moe_sub_mesh_tensors"; }
@@ -120,5 +163,7 @@ class MoEGlobalMeshTensorOp
120163

121164
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp)
122165
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ReshardOp)
166+
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorFromLocalOp)
167+
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorToLocalOp)
123168
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::MoESubMeshTensorsOp)
124169
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::MoEGlobalMeshTensorOp)

paddle/fluid/pir/dialect/distributed/ir/dist_type.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,30 @@ common::DDim InferLocalDDim(const common::DDim& global_ddim,
6363
return local_ddim;
6464
}
6565

66+
// Infers the global (logical) shape from a local (per-rank) shape: each
// axis whose dims_mapping entry is not -1 is sharded across the mapped
// process-mesh dimension, so its global extent is the local extent times
// that mesh dimension's size.
common::DDim InferGlobalDDim(const common::DDim& local_ddim,
                             TensorDistAttribute dist_attr) {
  // Rank -1 (uninitialized) and rank 0 (scalar) shapes have no axes to
  // scale, so the global shape equals the local shape.
  if (local_ddim.size() == -1 || local_ddim.size() == 0) {
    return local_ddim;
  }
  const ProcessMeshAttribute& mesh_attr = dist_attr.process_mesh_attr();
  const auto& mesh_dim = mesh_attr.shape();
  const auto& dim_mapping = dist_attr.dims_mapping();
  // A dims_mapping entry per tensor axis is required.
  // (Fixed error-message typo: "but bot" -> "but got".)
  PADDLE_ENFORCE_EQ(local_ddim.size(),
                    dim_mapping.size(),
                    ::common::errors::PreconditionNotMet(
                        "The local ddim size must equal to dim_mapping's "
                        "size, but got %d vs %d",
                        local_ddim.size(),
                        dim_mapping.size()));
  common::DDim global_ddim(local_ddim);
  for (size_t i = 0; i < dim_mapping.size(); ++i) {
    if (dim_mapping[i] != -1) {
      // Axis i is split over mesh dim dim_mapping[i]; .at() bounds-checks
      // the mapping against the mesh rank.
      global_ddim[i] = local_ddim[i] * mesh_dim.at(dim_mapping[i]);
    }
  }
  return global_ddim;
}
89+
6690
pir::DenseTensorType DistDenseTensorType::local_type() const {
6791
return pir::DenseTensorType::get(pir::IrContext::Instance(),
6892
dtype(),

paddle/fluid/pir/dialect/distributed/ir/dist_type.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class DistDenseTensorTypeStorage;
2626

2727
common::DDim InferLocalDDim(const common::DDim& global_ddim,
2828
TensorDistAttribute dist_attr);
29+
30+
// Infers the global shape from a local (per-rank) shape: every axis whose
// dims_mapping entry is not -1 is multiplied by the size of the mapped
// process-mesh dimension. Inverse companion of InferLocalDDim above.
common::DDim InferGlobalDDim(const common::DDim& local_ddim,
                             TensorDistAttribute dist_attr);
2932
class DistDenseTensorType
3033
: public pir::Type::TypeBase<DistDenseTensorType,
3134
pir::Type,

0 commit comments

Comments
 (0)