Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ class ProcessMeshAttribute : public pir::AttrBase<ProcessMeshAttribute,
const std::vector<int64_t>& shape,
const std::vector<int64_t>& process_ids,
const std::vector<std::string>& dim_names);

static std::string name() { return "a_process_mesh"; }
};

class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
Expand Down Expand Up @@ -98,6 +100,8 @@ class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
dims_mapping,
partial_status);
}

static std::string name() { return "a_tensor_dist"; }
};

class OperationDistAttribute : public pir::AttrBase<OperationDistAttribute,
Expand Down Expand Up @@ -128,6 +132,8 @@ class OperationDistAttribute : public pir::AttrBase<OperationDistAttribute,
const std::vector<Attribute>& results) {
return get(ctx, ProcessMeshAttribute::get(ctx, mesh), operands, results);
}

static std::string name() { return "a_op_dist"; }
};

} // namespace dialect
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/distributed/ir/dist_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class DistDenseTensorType
using Base::Base;
using LoD = pir::DenseTensorTypeStorage::LoD;

static std::string name() { return "t_dist_dtensor"; }

pir::DenseTensorType dense_tensor_type() const;
TensorDistAttribute tensor_dist_attr() const;
const common::DDim& global_ddim() const { return dense_tensor_type().dims(); }
Expand Down
144 changes: 142 additions & 2 deletions paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@

#include "paddle/common/layout.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/serialize_deserialize/include/schema.h"
#include "paddle/fluid/pir/serialize_deserialize/include/third_party.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/utils/flat_hash_map.h"

namespace pir {
#define DECOMPRESS_DIALECT_ID(name) \
Expand Down Expand Up @@ -54,6 +57,14 @@ class AttrTypeReader {
static pir::Type ReadPaddleOperatorType(const std::string type_name,
Json* type_json,
pir::IrContext* ctx);

static pir::Type ReadPaddleDistType(const std::string type_name,
Json* type_json,
pir::IrContext* ctx);

static pir::Attribute ReadPaddleDistAttr(const std::string attr_name,
Json* attr_json,
pir::IrContext* ctx);
};

template <typename T>
Expand Down Expand Up @@ -180,13 +191,16 @@ pir::Type parseType(Json* type_json) {
}

pir::IrContext* ctx = pir::IrContext::Instance();
std::pair<std::string, std::string> name = getContentSplitByDot(type_name);
std::pair<std::string, std::string> name = GetContentSplitByDot(type_name);

if (DECOMPRESS_DIALECT_ID(name.first) == pir::BuiltinDialect::name()) {
return AttrTypeReader::ReadBuiltInType(name.second, type_json, ctx);
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
paddle::dialect::OperatorDialect::name()) {
return AttrTypeReader::ReadPaddleOperatorType(name.second, type_json, ctx);
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
paddle::dialect::DistDialect::name()) {
return AttrTypeReader::ReadPaddleDistType(name.second, type_json, ctx);
} else {
PADDLE_ENFORCE(
false,
Expand All @@ -209,13 +223,16 @@ pir::TypeAttribute deserializeAttrFromJson<pir::TypeAttribute, pir::Type>(
pir::Attribute parseAttr(Json* attr_json) {
std::string attr_name = attr_json->at(ID).template get<std::string>();
pir::IrContext* ctx = pir::IrContext::Instance();
std::pair<std::string, std::string> name = getContentSplitByDot(attr_name);
std::pair<std::string, std::string> name = GetContentSplitByDot(attr_name);

if (DECOMPRESS_DIALECT_ID(name.first) == pir::BuiltinDialect::name()) {
return AttrTypeReader::ReadBuiltInAttr(name.second, attr_json, ctx);
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
paddle::dialect::OperatorDialect::name()) {
return AttrTypeReader::ReadPaddleOperatorAttr(name.second, attr_json, ctx);
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
paddle::dialect::DistDialect::name()) {
return AttrTypeReader::ReadPaddleDistAttr(name.second, attr_json, ctx);
} else {
PADDLE_ENFORCE(
false,
Expand All @@ -228,6 +245,68 @@ pir::Attribute parseAttr(Json* attr_json) {
return pir::Attribute();
}

// Rebuilds a ProcessMeshAttribute from its serialized JSON form.
//
// The DATA entry of `attr_json` is read as a three-element array:
//   [0] shape       -> std::vector<int64_t>
//   [1] process_ids -> std::vector<int64_t>
//   [2] dim_names   -> std::vector<std::string>
//
// `attr_json` is only read from; `ctx` is forwarded to
// ProcessMeshAttribute::get.
paddle::dialect::ProcessMeshAttribute deserializeProcessMeshAttr(
    Json* attr_json, pir::IrContext* ctx) {
  // Bind by const reference instead of copying: the payload is only read
  // here, so a deep copy of the JSON sub-tree buys nothing.
  const Json& data_json = attr_json->at(DATA);
  VLOG(8) << "deserialize shape";
  std::vector<int64_t> shape =
      data_json.at(0).template get<std::vector<int64_t>>();
  VLOG(8) << "deserialize process_ids";
  std::vector<int64_t> process_ids =
      data_json.at(1).template get<std::vector<int64_t>>();
  VLOG(8) << "deserialize dim_names";
  std::vector<std::string> dim_names =
      data_json.at(2).template get<std::vector<std::string>>();
  return paddle::dialect::ProcessMeshAttribute::get(
      ctx, shape, process_ids, dim_names);
}

// Rebuilds a TensorDistAttribute from its serialized JSON form.
//
// The DATA entry of `attr_json` is read as a three-element array:
//   [0] serialized ProcessMeshAttribute (handed to deserializeProcessMeshAttr)
//   [1] dims_mapping   -> std::vector<int64_t>
//   [2] partial_status -> sequence of [key, reduce_type] pairs, collected
//       into a flat_hash_map<int64_t, phi::ReduceType>
//
// `ctx` is forwarded to the attribute factories.
paddle::dialect::TensorDistAttribute deserializeTensorDistAttr(
    Json* attr_json, pir::IrContext* ctx) {
  // Reference rather than copy. It stays non-const because
  // deserializeProcessMeshAttr takes a mutable Json*.
  Json& data_json = attr_json->at(DATA);
  VLOG(8) << "deserialize ProcessMeshAttr";
  paddle::dialect::ProcessMeshAttribute mesh =
      deserializeProcessMeshAttr(&(data_json.at(0)), ctx);
  VLOG(8) << "deserialize dims_mapping";
  std::vector<int64_t> dims_mapping =
      data_json.at(1).template get<std::vector<int64_t>>();
  VLOG(8) << "deserialize partial_status";
  paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status;
  const Json& map_json = data_json.at(2);
  for (const auto& item : map_json) {
    // at() instead of operator[]: a pair with fewer than two elements is then
    // handled by the checked accessor instead of an unchecked const index.
    // NOTE(review): assumes Json is nlohmann::json -- confirm in third_party.h.
    const int64_t key = item.at(0).template get<int64_t>();
    partial_status[key] = static_cast<phi::ReduceType>(item.at(1));
  }
  return paddle::dialect::TensorDistAttribute::get(
      ctx, mesh, dims_mapping, partial_status);
}

// Rebuilds an OperationDistAttribute from its serialized JSON form.
//
// The DATA entry of `attr_json` is read as a three-element array:
//   [0] serialized ProcessMeshAttribute (handed to deserializeProcessMeshAttr)
//   [1] operands -> sequence of serialized attributes, each run through
//       parseAttr
//   [2] results  -> sequence of serialized attributes, each run through
//       parseAttr
//
// `ctx` is forwarded to the attribute factories.
paddle::dialect::OperationDistAttribute deserializeOperationDistAttr(
    Json* attr_json, pir::IrContext* ctx) {
  // References instead of deep copies of each JSON sub-tree. They stay
  // non-const because the helpers below take a mutable Json*.
  Json& data_json = attr_json->at(DATA);
  paddle::dialect::ProcessMeshAttribute mesh =
      deserializeProcessMeshAttr(&(data_json.at(0)), ctx);

  Json& operands_json = data_json.at(1);
  std::vector<Attribute> operands;
  operands.reserve(operands_json.size());
  for (auto& item : operands_json) {
    operands.push_back(parseAttr(&item));
  }

  // Looked up only after the operands are parsed, matching the original
  // order of evaluation.
  Json& results_json = data_json.at(2);
  std::vector<Attribute> results;
  results.reserve(results_json.size());
  for (auto& item : results_json) {
    results.push_back(parseAttr(&item));
  }
  return paddle::dialect::OperationDistAttribute::get(
      ctx, mesh, operands, results);
}

pir::Attribute AttrTypeReader::ReadBuiltInAttr(const std::string attr_name,
Json* attr_json,
pir::IrContext* ctx) {
Expand Down Expand Up @@ -319,6 +398,27 @@ pir::Attribute AttrTypeReader::ReadPaddleOperatorAttr(
return pir::Attribute();
}

// Dispatches on `attr_name` to the matching dist-dialect attribute
// deserializer. The name is compared against the static name() of
// ProcessMeshAttribute, TensorDistAttribute and OperationDistAttribute, in
// that order; any other name is reported through PADDLE_ENFORCE as an
// InvalidArgument error.
pir::Attribute AttrTypeReader::ReadPaddleDistAttr(const std::string attr_name,
                                                  Json* attr_json,
                                                  pir::IrContext* ctx) {
  if (attr_name == paddle::dialect::ProcessMeshAttribute::name()) {
    VLOG(8) << "Parse ProcessMeshAttribute .";
    return pir::deserializeProcessMeshAttr(attr_json, ctx);
  }

  if (attr_name == paddle::dialect::TensorDistAttribute::name()) {
    VLOG(8) << "Parse TensorDistAttribute .";
    return pir::deserializeTensorDistAttr(attr_json, ctx);
  }

  if (attr_name == paddle::dialect::OperationDistAttribute::name()) {
    VLOG(8) << "Parse OperationDistAttribute .";
    return pir::deserializeOperationDistAttr(attr_json, ctx);
  }

  // No known dist attribute matched.
  PADDLE_ENFORCE(
      false,
      phi::errors::InvalidArgument(
          "Unknown Attr %s for parse paddle dist dialect attr", attr_name));
  // Keeps a value-returning path after the enforce macro.
  return pir::Attribute();
}

template <typename T>
T deserializeTypeFromJsonIncludeParseType(Json* type_json,
pir::IrContext* ctx) {
Expand Down Expand Up @@ -430,6 +530,30 @@ deserializeTypeFromJsonIncludeParseType<paddle::dialect::SparseCsrTensorType>(
non_zero_elements);
}

// Specialization that rebuilds a DistDenseTensorType from its serialized
// JSON form.
//
// The DATA entry of `type_json` is read as a three-element array:
//   [0] serialized pir::DenseTensorType
//   [1] serialized TensorDistAttribute
//   [2] local dims -> std::vector<int64_t>, converted with phi::make_ddim
//
// `ctx` is forwarded to the nested deserializers and to
// DistDenseTensorType::get.
template <>
paddle::dialect::DistDenseTensorType
deserializeTypeFromJsonIncludeParseType<paddle::dialect::DistDenseTensorType>(
    Json* type_json, pir::IrContext* ctx) {
  // Reference instead of a deep copy of the payload. It stays non-const
  // because the nested deserializers take a mutable Json*.
  Json& data_json = type_json->at(DATA);

  // Element 0: the underlying dense tensor type.
  pir::DenseTensorType dense_tensor_type =
      deserializeTypeFromJsonIncludeParseType<pir::DenseTensorType>(
          &(data_json.at(0)), ctx);

  // Element 1: the tensor's distributed attribute.
  paddle::dialect::TensorDistAttribute tensor_dist_attr =
      deserializeTensorDistAttr(&(data_json.at(1)), ctx);

  // Element 2: the local dims.
  const std::vector<int64_t> dims =
      data_json.at(2).template get<std::vector<int64_t>>();
  phi::DDim local_ddim = phi::make_ddim(dims);

  return paddle::dialect::DistDenseTensorType::get(
      ctx, dense_tensor_type, tensor_dist_attr, local_ddim);
}

pir::Type AttrTypeReader::ReadBuiltInType(const std::string type_name,
Json* type_json,
pir::IrContext* ctx) {
Expand Down Expand Up @@ -516,4 +640,20 @@ pir::Type AttrTypeReader::ReadPaddleOperatorType(const std::string type_name,
}
}

// Dispatches on `type_name` to the matching dist-dialect type deserializer.
// Only DistDenseTensorType is handled; any other name is reported through
// PADDLE_ENFORCE as an InvalidArgument error.
pir::Type AttrTypeReader::ReadPaddleDistType(const std::string type_name,
                                             Json* type_json,
                                             pir::IrContext* ctx) {
  if (type_name == paddle::dialect::DistDenseTensorType::name()) {
    VLOG(8) << "Parse paddle::dialect::DistDenseTensorType ... ";
    return pir::deserializeTypeFromJsonIncludeParseType<
        paddle::dialect::DistDenseTensorType>(type_json, ctx);
  } else {
    // The message names the dist dialect: this reader serves the dist
    // dialect, not the operator dialect.
    PADDLE_ENFORCE(false,
                   phi::errors::InvalidArgument(
                       "Unknown Type %s for parse paddle dist dialect type",
                       type_name));
    // Keeps a value-returning path after the enforce macro.
    return pir::Type();
  }
}

} // namespace pir
8 changes: 7 additions & 1 deletion paddle/fluid/pir/serialize_deserialize/include/schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "glog/logging.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/pir/include/core/builtin_dialect.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h"

namespace pir {
/**
* IMPORTANT!!!
Expand Down Expand Up @@ -57,6 +60,7 @@ namespace pir {
// which is json array with json object(NAME and ATTR_TYPE)
#define ATTRS "A"
#define OPRESULTS_ATTRS "OA"
#define DIST_ATTRS "DA"

// value's key:
// value's type which should be pir::Type's json object(ID or ID and DATA).
Expand All @@ -78,9 +82,11 @@ namespace pir {

#define PARAMETEROP "p"

std::pair<std::string, std::string> getContentSplitByDot(
std::pair<std::string, std::string> GetContentSplitByDot(
const std::string& str);

std::vector<std::string> GetOpDistAttr();

void GetCompressOpName(std::string* op_name);

void GetDecompressOpName(std::string* op_name);
Expand Down
Loading