24 changes: 24 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.cc
@@ -63,6 +63,30 @@ pir::Value reshard(const pir::Value& x,
return reshard_op.result(0);
}

pir::Value dtensor_from_local(
const pir::Value& x,
const phi::distributed::ProcessMesh& process_mesh,
const std::vector<int64_t>& dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& partial_status) {
pir::IrContext* ctx = pir::IrContext::Instance();
TensorDistAttribute tensor_dist_attr =
TensorDistAttribute::get(ctx, process_mesh, dims_mapping, partial_status);
return dtensor_from_local(x, tensor_dist_attr);
}

pir::Value dtensor_from_local(const pir::Value& x,
const TensorDistAttribute& tensor_dist_attr) {
return ApiBuilder::Instance()
.GetBuilder()
->Build<DtensorFromLocalOp>(x, tensor_dist_attr)
.result(0);
}

pir::Value dtensor_to_local(const pir::Value& x) {
return ApiBuilder::Instance().GetBuilder()->Build<DtensorToLocalOp>(x).result(
0);
}

std::vector<pir::Value> moe_sub_mesh_tensors(
const pir::Value& input,
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_api.h
@@ -41,6 +41,16 @@ pir::Value reshard(
pir::Value reshard(const pir::Value& x,
const TensorDistAttribute& tensor_dist_attr);

pir::Value dtensor_from_local(
const pir::Value& x,
const phi::distributed::ProcessMesh& process_mesh,
const std::vector<int64_t>& dims_mapping,
const flat_hash_map<int64_t, phi::ReduceType>& partial_status = {});
pir::Value dtensor_from_local(const pir::Value& x,
const TensorDistAttribute& tensor_dist_attr);

pir::Value dtensor_to_local(const pir::Value& x);

std::vector<pir::Value> moe_sub_mesh_tensors(
const pir::Value& input,
const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
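For orientation, here is a minimal usage sketch of the two new builder APIs. It is a hedged sketch, not code from this PR: the mesh, `dims_mapping`, and `local_value` are illustrative placeholders, an active `ApiBuilder` insertion point is assumed (as with the other dist_api helpers), and the functions are assumed to live in namespace `paddle::dialect` per the file path.

```cpp
// Hedged sketch: wrap a rank-local DenseTensor value into a distributed
// tensor and unwrap it again. Mesh, dims_mapping, and local_value are
// illustrative assumptions, not taken from the PR.
#include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h"

pir::Value RoundTrip(const pir::Value& local_value) {
  // A 1-D mesh of two ranks; shard tensor axis 0 across it, replicate axis 1.
  phi::distributed::ProcessMesh mesh(/*shape=*/{2},
                                     /*process_ids=*/{0, 1},
                                     /*dim_names=*/{"x"});
  std::vector<int64_t> dims_mapping = {0, -1};  // assumes a rank-2 input

  pir::Value dist =
      paddle::dialect::dtensor_from_local(local_value, mesh, dims_mapping);
  return paddle::dialect::dtensor_to_local(dist);  // back to the local view
}
```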
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc
@@ -36,6 +36,8 @@ void DistDialect::initialize() {
RegisterTypes<DistDenseTensorType>();
RegisterOps<ShardTensorOp,
ReshardOp,
DtensorFromLocalOp,
DtensorToLocalOp,
MoESubMeshTensorsOp,
MoEGlobalMeshTensorOp>();
}
194 changes: 194 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_op.cc
@@ -326,6 +326,198 @@ void ReshardOp::Build(pir::Builder& builder,
::pir::PassStopGradientsDefaultly(argument);
}

void DtensorFromLocalOp::Build(pir::Builder& builder,
pir::OperationArgument& argument,
pir::Value input,
TensorDistAttribute tensor_dist_attr) {
VLOG(4) << "Start build DtensorFromLocalOp";

paddle::dialect::DenseTensorType local_tensor_type;
if (input.type().isa<paddle::dialect::DenseTensorType>()) {
local_tensor_type =
input.type().dyn_cast<paddle::dialect::DenseTensorType>();
} else {
PADDLE_THROW(common::errors::Unimplemented(
"Only support paddle::dialect::DenseTensorType"));
}

VLOG(4) << "Builder construction inputs";
argument.AddInput(input);

VLOG(4) << "Builder construction attributes";

VLOG(4) << "Builder construction outputs";

auto global_ddim =
InferGlobalDDim(local_tensor_type.dims(), tensor_dist_attr);
auto global_tensor =
dialect::DenseTensorType::get(pir::IrContext::Instance(),
local_tensor_type.dtype(),
global_ddim,
local_tensor_type.data_layout(),
local_tensor_type.lod(),
local_tensor_type.offset());

pir::Type out_dist_tensor_type =
paddle::dialect::DistDenseTensorType::get(pir::IrContext::Instance(),
global_tensor,
tensor_dist_attr,
local_tensor_type.dims());
argument.AddOutput(out_dist_tensor_type);
::pir::PassStopGradientsDefaultly(argument);
}

OpInfoTuple DtensorFromLocalOp::GetOpInfo() {
return OpInfoTuple({OpInputInfo()},
{},
{OpOutputInfo()},
OpRunTimeInfo(),
"dtensor_from_local");
}
std::vector<std::vector<pir::Value>> DtensorFromLocalOp::Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
VLOG(6) << "Start call vjp for dtensor_from_local op.";
PADDLE_ENFORCE_EQ(inputs.size(),
1,
common::errors::InvalidArgument(
"dtensor_from_local op's inputs' size should be 1"));
PADDLE_ENFORCE_EQ(
inputs[0].size(),
1,
common::errors::InvalidArgument(
"dtensor_from_local op's inputs[0]'s size should be 1"));

PADDLE_ENFORCE_EQ(outputs.size(),
1,
common::errors::InvalidArgument(
"dtensor_from_local op's outputs' size should be 1"));
PADDLE_ENFORCE_EQ(
outputs[0].size(),
1,
common::errors::InvalidArgument(
"dtensor_from_local op's outputs[0]'s size should be 1"));
auto dist_type = outputs[0][0].type().dyn_cast<DistTypeInterface>();

PADDLE_ENFORCE_NOT_NULL(
dist_type,
common::errors::InvalidArgument("Currently, dtensor_from_local op's "
"outputs type must be dist type."));

PADDLE_ENFORCE_EQ(
out_grads.size(),
1,
common::errors::InvalidArgument(
"dtensor_from_local op's outputs grad size should be 1"));

PADDLE_ENFORCE_EQ(
out_grads[0].size(),
1,
common::errors::InvalidArgument(
"dtensor_from_local op's outputs grad[0] size should be 1"));

auto& builder = *ApiBuilder::Instance().GetBuilder();

auto out_grad = out_grads[0][0];

if (out_grad.type() != outputs[0][0].type()) {
out_grad = builder.Build<ReshardOp>(out_grad, dist_type.tensor_dist_attr())
->result(0);
}

auto grad_op = builder.Build<DtensorToLocalOp>(out_grad);

VLOG(6) << "End call vjp for dtensor_from_local op.";

return {std::vector<pir::Value>{grad_op->result(0)}};
}

void DtensorToLocalOp::Build(pir::Builder& builder,
pir::OperationArgument& argument,
pir::Value input) {
VLOG(4) << "Start build DtensorToLocalOp";

VLOG(4) << "Builder construction inputs";
argument.AddInput(input);

VLOG(4) << "Builder construction attributes";

VLOG(4) << "Builder construction outputs";

auto dist_type = input.type().dyn_cast<DistTypeInterface>();
if (!dist_type) {
PADDLE_THROW(common::errors::Unimplemented(
"The input of DtensorToLocalOp must be dist type."));
}

argument.AddOutput(dist_type.local_type());
::pir::PassStopGradientsDefaultly(argument);
}

OpInfoTuple DtensorToLocalOp::GetOpInfo() {
return OpInfoTuple({OpInputInfo()},
{},
{OpOutputInfo()},
OpRunTimeInfo(),
"dtensor_to_local");
}

std::vector<std::vector<pir::Value>> DtensorToLocalOp::Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
VLOG(6) << "Start call vjp for dtensor_to_local op.";
PADDLE_ENFORCE_EQ(inputs.size(),
1,
common::errors::InvalidArgument(
"dtensor_to_local op's inputs' size should be 1"));
PADDLE_ENFORCE_EQ(inputs[0].size(),
1,
common::errors::InvalidArgument(
"dtensor_to_local op's inputs[0]'s size should be 1"));

PADDLE_ENFORCE_EQ(outputs.size(),
1,
common::errors::InvalidArgument(
"dtensor_to_local op's outputs' size should be 1"));
PADDLE_ENFORCE_EQ(outputs[0].size(),
1,
common::errors::InvalidArgument(
"dtensor_to_local op's outputs[0]'s size should be 1"));
auto dist_type = inputs[0][0].type().dyn_cast<DistTypeInterface>();

PADDLE_ENFORCE_NOT_NULL(
dist_type,
common::errors::InvalidArgument(
"Currently, dtensor_to_local op's inputs type must be dist type."));

PADDLE_ENFORCE_EQ(
out_grads.size(),
1,
common::errors::InvalidArgument(
"dtensor_from_local op's outputs grad size should be 1"));

PADDLE_ENFORCE_EQ(
out_grads[0].size(),
1,
common::errors::InvalidArgument(
"dtensor_from_local op's outputs grad[0] size should be 1"));

auto& builder = *ApiBuilder::Instance().GetBuilder();

auto grad_op = builder.Build<DtensorFromLocalOp>(
out_grads[0][0], dist_type.tensor_dist_attr());

VLOG(6) << "End call vjp for dtensor_from_local op.";

return {std::vector<pir::Value>{grad_op->result(0)}};
}

TEST_API void paddle::dialect::MoESubMeshTensorsOp::Build(
pir::Builder& builder,
pir::OperationArgument& argument,
@@ -699,5 +891,7 @@ std::vector<std::vector<pir::Value>> MoEGlobalMeshTensorOp::Vjp(

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ReshardOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorFromLocalOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorToLocalOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::MoESubMeshTensorsOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::MoEGlobalMeshTensorOp)
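The two Vjp implementations above make the ops mutual inverses under differentiation. A pseudocode summary of the pairing, for intuition only (not an exact Paddle API):

```cpp
// grad of dtensor_from_local(x, attr), given output grad g:
//   if (g.type != output.type) g = reshard(g, attr);  // align dist attr first
//   return dtensor_to_local(g);                       // unwrap to local view
//
// grad of dtensor_to_local(x), given output grad g:
//   return dtensor_from_local(g, dist_attr_of(x));    // rewrap with x's attr
```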
47 changes: 46 additions & 1 deletion paddle/fluid/pir/dialect/distributed/ir/dist_op.h
@@ -63,7 +63,50 @@ class ReshardOp : public pir::Op<ReshardOp, VjpInterface, OpYamlInfoInterface> {
void VerifySig();
};

class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp, VjpInterface> {
class DtensorFromLocalOp
: public pir::Op<DtensorFromLocalOp, VjpInterface, OpYamlInfoInterface> {
public:
using Op::Op;
static const char* name() { return "dist_op.dtensor_from_local"; }
static constexpr const char** attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
TEST_API static void Build(pir::Builder& builder, // NOLINT
pir::OperationArgument& argument, // NOLINT
pir::Value input,
TensorDistAttribute tensor_dist_attr);

static OpInfoTuple GetOpInfo();
static std::vector<std::vector<pir::Value>> Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs_,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients);
};

class DtensorToLocalOp
: public pir::Op<DtensorToLocalOp, VjpInterface, OpYamlInfoInterface> {
public:
using Op::Op;
static const char* name() { return "dist_op.dtensor_to_local"; }
static constexpr const char** attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
TEST_API static void Build(pir::Builder& builder, // NOLINT
pir::OperationArgument& argument, // NOLINT
pir::Value input);

static OpInfoTuple GetOpInfo();
static std::vector<std::vector<pir::Value>> Vjp(
pir::Operation* op,
const std::vector<std::vector<pir::Value>>& inputs_,
const std::vector<std::vector<pir::Value>>& outputs,
const std::vector<std::vector<pir::Value>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients);

};

class MoESubMeshTensorsOp : public pir::Op<MoESubMeshTensorsOp> {
public:
using Op::Op;
static const char* name() { return "dist_op.moe_sub_mesh_tensors"; }
@@ -120,5 +163,7 @@ class MoEGlobalMeshTensorOp

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ReshardOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorFromLocalOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DtensorToLocalOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::MoESubMeshTensorsOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::MoEGlobalMeshTensorOp)
24 changes: 24 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_type.cc
@@ -63,6 +63,30 @@ common::DDim InferLocalDDim(const common::DDim& global_ddim,
return local_ddim;
}

common::DDim InferGlobalDDim(const common::DDim& local_ddim,
TensorDistAttribute dist_attr) {
if (local_ddim.size() == -1 || local_ddim.size() == 0) {
return local_ddim;
}
const ProcessMeshAttribute& mesh_attr = dist_attr.process_mesh_attr();
auto& mesh_dim = mesh_attr.shape();
auto& dim_mapping = dist_attr.dims_mapping();
PADDLE_ENFORCE_EQ(local_ddim.size(),
dim_mapping.size(),
::common::errors::PreconditionNotMet(
"The local ddim size must equal to dim_mapping's "
"size, but bot %d vs %d",
local_ddim.size(),
dim_mapping.size()));
common::DDim global_ddim(local_ddim);
for (size_t i = 0; i < dim_mapping.size(); ++i) {
if (dim_mapping[i] != -1) {
global_ddim[i] = local_ddim[i] * mesh_dim.at(dim_mapping[i]);
}
}
return global_ddim;
}

pir::DenseTensorType DistDenseTensorType::local_type() const {
return pir::DenseTensorType::get(pir::IrContext::Instance(),
dtype(),
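To make the inference rule in `InferGlobalDDim` concrete, here is a self-contained restatement using plain `std::vector` in place of `common::DDim` (illustrative only, not Paddle API): sharded axes are scaled by the size of the mesh dimension they map to, while replicated axes (`-1`) pass through unchanged.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified restatement of InferGlobalDDim, for intuition; not Paddle API.
std::vector<int64_t> InferGlobalShape(
    const std::vector<int64_t>& local,
    const std::vector<int64_t>& mesh_shape,
    const std::vector<int64_t>& dims_mapping) {
  assert(local.size() == dims_mapping.size());
  std::vector<int64_t> global = local;
  for (size_t i = 0; i < dims_mapping.size(); ++i) {
    if (dims_mapping[i] != -1) {  // axis i is sharded over mesh dim dims_mapping[i]
      global[i] = local[i] * mesh_shape.at(dims_mapping[i]);
    }
  }
  return global;  // e.g. InferGlobalShape({3, 8}, {2, 4}, {1, -1}) == {12, 8}
}
```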
3 changes: 3 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_type.h
@@ -26,6 +26,9 @@ class DistDenseTensorTypeStorage;

common::DDim InferLocalDDim(const common::DDim& global_ddim,
TensorDistAttribute dist_attr);

common::DDim InferGlobalDDim(const common::DDim& local_ddim,
TensorDistAttribute dist_attr);
class DistDenseTensorType
: public pir::Type::TypeBase<DistDenseTensorType,
pir::Type,