1919
2020#include " paddle/common/layout.h"
2121#include " paddle/fluid/framework/data_layout.h"
22+ #include " paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
23+ #include " paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
2224#include " paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
2325#include " paddle/fluid/pir/dialect/operator/ir/op_type.h"
2426#include " paddle/fluid/pir/serialize_deserialize/include/schema.h"
2527#include " paddle/fluid/pir/serialize_deserialize/include/third_party.h"
2628#include " paddle/phi/common/data_type.h"
2729#include " paddle/pir/include/core/builtin_attribute.h"
2830#include " paddle/pir/include/core/builtin_type.h"
31+ #include " paddle/utils/flat_hash_map.h"
2932
3033namespace pir {
3134#define DECOMPRESS_DIALECT_ID (name ) \
@@ -54,6 +57,14 @@ class AttrTypeReader {
5457 static pir::Type ReadPaddleOperatorType (const std::string type_name,
5558 Json* type_json,
5659 pir::IrContext* ctx);
60+
61+ static pir::Type ReadPaddleDistType (const std::string type_name,
62+ Json* type_json,
63+ pir::IrContext* ctx);
64+
65+ static pir::Attribute ReadPaddleDistAttr (const std::string attr_name,
66+ Json* attr_json,
67+ pir::IrContext* ctx);
5768};
5869
5970template <typename T>
@@ -180,13 +191,16 @@ pir::Type parseType(Json* type_json) {
180191 }
181192
182193 pir::IrContext* ctx = pir::IrContext::Instance ();
183- std::pair<std::string, std::string> name = getContentSplitByDot (type_name);
194+ std::pair<std::string, std::string> name = GetContentSplitByDot (type_name);
184195
185196 if (DECOMPRESS_DIALECT_ID (name.first ) == pir::BuiltinDialect::name ()) {
186197 return AttrTypeReader::ReadBuiltInType (name.second , type_json, ctx);
187198 } else if (DECOMPRESS_DIALECT_ID (name.first ) ==
188199 paddle::dialect::OperatorDialect::name ()) {
189200 return AttrTypeReader::ReadPaddleOperatorType (name.second , type_json, ctx);
201+ } else if (DECOMPRESS_DIALECT_ID (name.first ) ==
202+ paddle::dialect::DistDialect::name ()) {
203+ return AttrTypeReader::ReadPaddleDistType (name.second , type_json, ctx);
190204 } else {
191205 PADDLE_ENFORCE (
192206 false ,
@@ -209,13 +223,16 @@ pir::TypeAttribute deserializeAttrFromJson<pir::TypeAttribute, pir::Type>(
209223pir::Attribute parseAttr (Json* attr_json) {
210224 std::string attr_name = attr_json->at (ID).template get <std::string>();
211225 pir::IrContext* ctx = pir::IrContext::Instance ();
212- std::pair<std::string, std::string> name = getContentSplitByDot (attr_name);
226+ std::pair<std::string, std::string> name = GetContentSplitByDot (attr_name);
213227
214228 if (DECOMPRESS_DIALECT_ID (name.first ) == pir::BuiltinDialect::name ()) {
215229 return AttrTypeReader::ReadBuiltInAttr (name.second , attr_json, ctx);
216230 } else if (DECOMPRESS_DIALECT_ID (name.first ) ==
217231 paddle::dialect::OperatorDialect::name ()) {
218232 return AttrTypeReader::ReadPaddleOperatorAttr (name.second , attr_json, ctx);
233+ } else if (DECOMPRESS_DIALECT_ID (name.first ) ==
234+ paddle::dialect::DistDialect::name ()) {
235+ return AttrTypeReader::ReadPaddleDistAttr (name.second , attr_json, ctx);
219236 } else {
220237 PADDLE_ENFORCE (
221238 false ,
@@ -228,6 +245,68 @@ pir::Attribute parseAttr(Json* attr_json) {
228245 return pir::Attribute ();
229246}
230247
248+ // ProcessMesh includes: std::vector<int64_t>& shape, std::vector<int64_t>&
249+ // process_ids, std::vector<std::string>& dim_names
250+ paddle::dialect::ProcessMeshAttribute deserializeProcessMeshAttr (
251+ Json* attr_json, pir::IrContext* ctx) {
252+ Json data_json = attr_json->at (DATA);
253+ VLOG (8 ) << " deserialize shape" ;
254+ std::vector<int64_t > shape =
255+ data_json.at (0 ).template get <std::vector<int64_t >>();
256+ VLOG (8 ) << " deserialize process_ids" ;
257+ std::vector<int64_t > process_ids =
258+ data_json.at (1 ).template get <std::vector<int64_t >>();
259+ VLOG (8 ) << " deserialize dim_names" ;
260+ std::vector<std::string> dim_names =
261+ data_json.at (2 ).template get <std::vector<std::string>>();
262+ return paddle::dialect::ProcessMeshAttribute::get (
263+ ctx, shape, process_ids, dim_names);
264+ }
265+
266+ // TensorDistAttribute includes: ProcessMeshAttribute mesh_attr,
267+ // std::vector<int64_t> dims_mapping, flat_hash_map<int64_t, phi::ReduceType>
268+ // partial_status;
269+ paddle::dialect::TensorDistAttribute deserializeTensorDistAttr (
270+ Json* attr_json, pir::IrContext* ctx) {
271+ Json data_json = attr_json->at (DATA);
272+ VLOG (8 ) << " deserialize ProcessMeshAttr" ;
273+ paddle::dialect::ProcessMeshAttribute mesh =
274+ deserializeProcessMeshAttr (&(data_json.at (0 )), ctx);
275+ VLOG (8 ) << " deserialize dims_mapping" ;
276+ std::vector<int64_t > dims_mapping =
277+ data_json.at (1 ).template get <std::vector<int64_t >>();
278+ VLOG (8 ) << " deserialize partial_status" ;
279+ paddle::flat_hash_map<int64_t , phi::ReduceType> partial_status;
280+ Json map_json = data_json.at (2 );
281+ for (const auto & item : map_json) {
282+ partial_status[item[0 ]] = static_cast <phi::ReduceType>(item[1 ]);
283+ }
284+ return paddle::dialect::TensorDistAttribute::get (
285+ ctx, mesh, dims_mapping, partial_status);
286+ }
287+
288+ // OperationDistAttribute includes: ProcessMeshAttribute mesh_attr,
289+ // std::vector<pir::Attribute> operands, std::vector<pir::Attribute> results;
290+ paddle::dialect::OperationDistAttribute deserializeOperationDistAttr (
291+ Json* attr_json, pir::IrContext* ctx) {
292+ Json data_json = attr_json->at (DATA);
293+ paddle::dialect::ProcessMeshAttribute mesh =
294+ deserializeProcessMeshAttr (&(data_json.at (0 )), ctx);
295+ std::vector<Attribute> operands;
296+ Json operands_json = data_json.at (1 );
297+ for (auto & item : operands_json) {
298+ operands.push_back (parseAttr (&item));
299+ }
300+
301+ std::vector<Attribute> results;
302+ Json results_json = data_json.at (2 );
303+ for (auto & item : results_json) {
304+ results.push_back (parseAttr (&item));
305+ }
306+ return paddle::dialect::OperationDistAttribute::get (
307+ ctx, mesh, operands, results);
308+ }
309+
231310pir::Attribute AttrTypeReader::ReadBuiltInAttr (const std::string attr_name,
232311 Json* attr_json,
233312 pir::IrContext* ctx) {
@@ -319,6 +398,27 @@ pir::Attribute AttrTypeReader::ReadPaddleOperatorAttr(
319398 return pir::Attribute ();
320399}
321400
401+ pir::Attribute AttrTypeReader::ReadPaddleDistAttr (const std::string attr_name,
402+ Json* attr_json,
403+ pir::IrContext* ctx) {
404+ if (attr_name == paddle::dialect::ProcessMeshAttribute::name ()) {
405+ VLOG (8 ) << " Parse ProcessMeshAttribute ." ;
406+ return pir::deserializeProcessMeshAttr (attr_json, ctx);
407+ } else if (attr_name == paddle::dialect::TensorDistAttribute::name ()) {
408+ VLOG (8 ) << " Parse TensorDistAttribute ." ;
409+ return pir::deserializeTensorDistAttr (attr_json, ctx);
410+ } else if (attr_name == paddle::dialect::OperationDistAttribute::name ()) {
411+ VLOG (8 ) << " Parse OperationDistAttribute ." ;
412+ return pir::deserializeOperationDistAttr (attr_json, ctx);
413+ } else {
414+ PADDLE_ENFORCE (
415+ false ,
416+ phi::errors::InvalidArgument (
417+ " Unknown Attr %s for parse paddle dist dialect attr" , attr_name));
418+ }
419+ return pir::Attribute ();
420+ }
421+
322422template <typename T>
323423T deserializeTypeFromJsonIncludeParseType (Json* type_json,
324424 pir::IrContext* ctx) {
@@ -430,6 +530,30 @@ deserializeTypeFromJsonIncludeParseType<paddle::dialect::SparseCsrTensorType>(
430530 non_zero_elements);
431531}
432532
533+ template <>
534+ paddle::dialect::DistDenseTensorType
535+ deserializeTypeFromJsonIncludeParseType<paddle::dialect::DistDenseTensorType>(
536+ Json* type_json, pir::IrContext* ctx) {
537+ Json data_json = type_json->at (DATA);
538+
539+ // deserialize pir::DenseTensorType dense_tensor_type;
540+ pir::DenseTensorType dense_tensor_type =
541+ deserializeTypeFromJsonIncludeParseType<pir::DenseTensorType>(
542+ &(data_json.at (0 )), ctx);
543+
544+ // deserialize TensorDistAttribute tensor_dist_attr;
545+ paddle::dialect::TensorDistAttribute tensor_dist_attr =
546+ deserializeTensorDistAttr (&(data_json.at (1 )), ctx);
547+
548+ // deserialize common::DDim local_ddim;
549+ std::vector<int64_t > dims =
550+ data_json.at (2 ).template get <std::vector<int64_t >>();
551+ phi::DDim local_ddim = phi::make_ddim (dims);
552+
553+ return paddle::dialect::DistDenseTensorType::get (
554+ ctx, dense_tensor_type, tensor_dist_attr, local_ddim);
555+ }
556+
433557pir::Type AttrTypeReader::ReadBuiltInType (const std::string type_name,
434558 Json* type_json,
435559 pir::IrContext* ctx) {
@@ -516,4 +640,20 @@ pir::Type AttrTypeReader::ReadPaddleOperatorType(const std::string type_name,
516640 }
517641}
518642
643+ pir::Type AttrTypeReader::ReadPaddleDistType (const std::string type_name,
644+ Json* type_json,
645+ pir::IrContext* ctx) {
646+ if (type_name == paddle::dialect::DistDenseTensorType::name ()) {
647+ VLOG (8 ) << " Parse paddle::dialect::DistDenseTensorType ... " ;
648+ return pir::deserializeTypeFromJsonIncludeParseType<
649+ paddle::dialect::DistDenseTensorType>(type_json, ctx);
650+ } else {
651+ PADDLE_ENFORCE (false ,
652+ phi::errors::InvalidArgument (
653+ " Unknown Type %s for parse paddleoperator dialect type" ,
654+ type_name));
655+ return pir::Type ();
656+ }
657+ }
658+
519659} // namespace pir
0 commit comments