Skip to content

Commit 3fd43fb

Browse files
authored
[AutoParalle] Support save load for pir dist_program (#67045)
* add serialize * add serialize * add deserialize * refine parameter op * fix bug * fix bug * fix bug * fix bug * add ut * refine * refine * refine * refine * refine
1 parent 7eb71dd commit 3fd43fb

File tree

9 files changed

+389
-7
lines changed

9 files changed

+389
-7
lines changed

paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class ProcessMeshAttribute : public pir::AttrBase<ProcessMeshAttribute,
6161
const std::vector<int64_t>& shape,
6262
const std::vector<int64_t>& process_ids,
6363
const std::vector<std::string>& dim_names);
64+
65+
static std::string name() { return "a_process_mesh"; }
6466
};
6567

6668
class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
@@ -98,6 +100,8 @@ class TensorDistAttribute : public pir::AttrBase<TensorDistAttribute,
98100
dims_mapping,
99101
partial_status);
100102
}
103+
104+
static std::string name() { return "a_tensor_dist"; }
101105
};
102106

103107
class OperationDistAttribute : public pir::AttrBase<OperationDistAttribute,
@@ -128,6 +132,8 @@ class OperationDistAttribute : public pir::AttrBase<OperationDistAttribute,
128132
const std::vector<Attribute>& results) {
129133
return get(ctx, ProcessMeshAttribute::get(ctx, mesh), operands, results);
130134
}
135+
136+
static std::string name() { return "a_op_dist"; }
131137
};
132138

133139
} // namespace dialect

paddle/fluid/pir/dialect/distributed/ir/dist_type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class DistDenseTensorType
3636
using Base::Base;
3737
using LoD = pir::DenseTensorTypeStorage::LoD;
3838

39+
static std::string name() { return "t_dist_dtensor"; }
40+
3941
pir::DenseTensorType dense_tensor_type() const;
4042
TensorDistAttribute tensor_dist_attr() const;
4143
const common::DDim& global_ddim() const { return dense_tensor_type().dims(); }

paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,16 @@
1919

2020
#include "paddle/common/layout.h"
2121
#include "paddle/fluid/framework/data_layout.h"
22+
#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
23+
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
2224
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
2325
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
2426
#include "paddle/fluid/pir/serialize_deserialize/include/schema.h"
2527
#include "paddle/fluid/pir/serialize_deserialize/include/third_party.h"
2628
#include "paddle/phi/common/data_type.h"
2729
#include "paddle/pir/include/core/builtin_attribute.h"
2830
#include "paddle/pir/include/core/builtin_type.h"
31+
#include "paddle/utils/flat_hash_map.h"
2932

3033
namespace pir {
3134
#define DECOMPRESS_DIALECT_ID(name) \
@@ -54,6 +57,14 @@ class AttrTypeReader {
5457
static pir::Type ReadPaddleOperatorType(const std::string type_name,
5558
Json* type_json,
5659
pir::IrContext* ctx);
60+
61+
static pir::Type ReadPaddleDistType(const std::string type_name,
62+
Json* type_json,
63+
pir::IrContext* ctx);
64+
65+
static pir::Attribute ReadPaddleDistAttr(const std::string attr_name,
66+
Json* attr_json,
67+
pir::IrContext* ctx);
5768
};
5869

5970
template <typename T>
@@ -180,13 +191,16 @@ pir::Type parseType(Json* type_json) {
180191
}
181192

182193
pir::IrContext* ctx = pir::IrContext::Instance();
183-
std::pair<std::string, std::string> name = getContentSplitByDot(type_name);
194+
std::pair<std::string, std::string> name = GetContentSplitByDot(type_name);
184195

185196
if (DECOMPRESS_DIALECT_ID(name.first) == pir::BuiltinDialect::name()) {
186197
return AttrTypeReader::ReadBuiltInType(name.second, type_json, ctx);
187198
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
188199
paddle::dialect::OperatorDialect::name()) {
189200
return AttrTypeReader::ReadPaddleOperatorType(name.second, type_json, ctx);
201+
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
202+
paddle::dialect::DistDialect::name()) {
203+
return AttrTypeReader::ReadPaddleDistType(name.second, type_json, ctx);
190204
} else {
191205
PADDLE_ENFORCE(
192206
false,
@@ -209,13 +223,16 @@ pir::TypeAttribute deserializeAttrFromJson<pir::TypeAttribute, pir::Type>(
209223
pir::Attribute parseAttr(Json* attr_json) {
210224
std::string attr_name = attr_json->at(ID).template get<std::string>();
211225
pir::IrContext* ctx = pir::IrContext::Instance();
212-
std::pair<std::string, std::string> name = getContentSplitByDot(attr_name);
226+
std::pair<std::string, std::string> name = GetContentSplitByDot(attr_name);
213227

214228
if (DECOMPRESS_DIALECT_ID(name.first) == pir::BuiltinDialect::name()) {
215229
return AttrTypeReader::ReadBuiltInAttr(name.second, attr_json, ctx);
216230
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
217231
paddle::dialect::OperatorDialect::name()) {
218232
return AttrTypeReader::ReadPaddleOperatorAttr(name.second, attr_json, ctx);
233+
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
234+
paddle::dialect::DistDialect::name()) {
235+
return AttrTypeReader::ReadPaddleDistAttr(name.second, attr_json, ctx);
219236
} else {
220237
PADDLE_ENFORCE(
221238
false,
@@ -228,6 +245,68 @@ pir::Attribute parseAttr(Json* attr_json) {
228245
return pir::Attribute();
229246
}
230247

248+
// ProcessMesh includes: std::vector<int64_t>& shape, std::vector<int64_t>&
249+
// process_ids, std::vector<std::string>& dim_names
250+
paddle::dialect::ProcessMeshAttribute deserializeProcessMeshAttr(
251+
Json* attr_json, pir::IrContext* ctx) {
252+
Json data_json = attr_json->at(DATA);
253+
VLOG(8) << "deserialize shape";
254+
std::vector<int64_t> shape =
255+
data_json.at(0).template get<std::vector<int64_t>>();
256+
VLOG(8) << "deserialize process_ids";
257+
std::vector<int64_t> process_ids =
258+
data_json.at(1).template get<std::vector<int64_t>>();
259+
VLOG(8) << "deserialize dim_names";
260+
std::vector<std::string> dim_names =
261+
data_json.at(2).template get<std::vector<std::string>>();
262+
return paddle::dialect::ProcessMeshAttribute::get(
263+
ctx, shape, process_ids, dim_names);
264+
}
265+
266+
// TensorDistAttribute includes: ProcessMeshAttribute mesh_attr,
267+
// std::vector<int64_t> dims_mapping, flat_hash_map<int64_t, phi::ReduceType>
268+
// partial_status;
269+
paddle::dialect::TensorDistAttribute deserializeTensorDistAttr(
270+
Json* attr_json, pir::IrContext* ctx) {
271+
Json data_json = attr_json->at(DATA);
272+
VLOG(8) << "deserialize ProcessMeshAttr";
273+
paddle::dialect::ProcessMeshAttribute mesh =
274+
deserializeProcessMeshAttr(&(data_json.at(0)), ctx);
275+
VLOG(8) << "deserialize dims_mapping";
276+
std::vector<int64_t> dims_mapping =
277+
data_json.at(1).template get<std::vector<int64_t>>();
278+
VLOG(8) << "deserialize partial_status";
279+
paddle::flat_hash_map<int64_t, phi::ReduceType> partial_status;
280+
Json map_json = data_json.at(2);
281+
for (const auto& item : map_json) {
282+
partial_status[item[0]] = static_cast<phi::ReduceType>(item[1]);
283+
}
284+
return paddle::dialect::TensorDistAttribute::get(
285+
ctx, mesh, dims_mapping, partial_status);
286+
}
287+
288+
// OperationDistAttribute includes: ProcessMeshAttribute mesh_attr,
289+
// std::vector<pir::Attribute> operands, std::vector<pir::Attribute> results;
290+
paddle::dialect::OperationDistAttribute deserializeOperationDistAttr(
291+
Json* attr_json, pir::IrContext* ctx) {
292+
Json data_json = attr_json->at(DATA);
293+
paddle::dialect::ProcessMeshAttribute mesh =
294+
deserializeProcessMeshAttr(&(data_json.at(0)), ctx);
295+
std::vector<Attribute> operands;
296+
Json operands_json = data_json.at(1);
297+
for (auto& item : operands_json) {
298+
operands.push_back(parseAttr(&item));
299+
}
300+
301+
std::vector<Attribute> results;
302+
Json results_json = data_json.at(2);
303+
for (auto& item : results_json) {
304+
results.push_back(parseAttr(&item));
305+
}
306+
return paddle::dialect::OperationDistAttribute::get(
307+
ctx, mesh, operands, results);
308+
}
309+
231310
pir::Attribute AttrTypeReader::ReadBuiltInAttr(const std::string attr_name,
232311
Json* attr_json,
233312
pir::IrContext* ctx) {
@@ -319,6 +398,27 @@ pir::Attribute AttrTypeReader::ReadPaddleOperatorAttr(
319398
return pir::Attribute();
320399
}
321400

401+
pir::Attribute AttrTypeReader::ReadPaddleDistAttr(const std::string attr_name,
402+
Json* attr_json,
403+
pir::IrContext* ctx) {
404+
if (attr_name == paddle::dialect::ProcessMeshAttribute::name()) {
405+
VLOG(8) << "Parse ProcessMeshAttribute .";
406+
return pir::deserializeProcessMeshAttr(attr_json, ctx);
407+
} else if (attr_name == paddle::dialect::TensorDistAttribute::name()) {
408+
VLOG(8) << "Parse TensorDistAttribute .";
409+
return pir::deserializeTensorDistAttr(attr_json, ctx);
410+
} else if (attr_name == paddle::dialect::OperationDistAttribute::name()) {
411+
VLOG(8) << "Parse OperationDistAttribute .";
412+
return pir::deserializeOperationDistAttr(attr_json, ctx);
413+
} else {
414+
PADDLE_ENFORCE(
415+
false,
416+
phi::errors::InvalidArgument(
417+
"Unknown Attr %s for parse paddle dist dialect attr", attr_name));
418+
}
419+
return pir::Attribute();
420+
}
421+
322422
template <typename T>
323423
T deserializeTypeFromJsonIncludeParseType(Json* type_json,
324424
pir::IrContext* ctx) {
@@ -430,6 +530,30 @@ deserializeTypeFromJsonIncludeParseType<paddle::dialect::SparseCsrTensorType>(
430530
non_zero_elements);
431531
}
432532

533+
template <>
534+
paddle::dialect::DistDenseTensorType
535+
deserializeTypeFromJsonIncludeParseType<paddle::dialect::DistDenseTensorType>(
536+
Json* type_json, pir::IrContext* ctx) {
537+
Json data_json = type_json->at(DATA);
538+
539+
// deserialize pir::DenseTensorType dense_tensor_type;
540+
pir::DenseTensorType dense_tensor_type =
541+
deserializeTypeFromJsonIncludeParseType<pir::DenseTensorType>(
542+
&(data_json.at(0)), ctx);
543+
544+
// deserialize TensorDistAttribute tensor_dist_attr;
545+
paddle::dialect::TensorDistAttribute tensor_dist_attr =
546+
deserializeTensorDistAttr(&(data_json.at(1)), ctx);
547+
548+
// deserialize common::DDim local_ddim;
549+
std::vector<int64_t> dims =
550+
data_json.at(2).template get<std::vector<int64_t>>();
551+
phi::DDim local_ddim = phi::make_ddim(dims);
552+
553+
return paddle::dialect::DistDenseTensorType::get(
554+
ctx, dense_tensor_type, tensor_dist_attr, local_ddim);
555+
}
556+
433557
pir::Type AttrTypeReader::ReadBuiltInType(const std::string type_name,
434558
Json* type_json,
435559
pir::IrContext* ctx) {
@@ -516,4 +640,20 @@ pir::Type AttrTypeReader::ReadPaddleOperatorType(const std::string type_name,
516640
}
517641
}
518642

643+
pir::Type AttrTypeReader::ReadPaddleDistType(const std::string type_name,
644+
Json* type_json,
645+
pir::IrContext* ctx) {
646+
if (type_name == paddle::dialect::DistDenseTensorType::name()) {
647+
VLOG(8) << "Parse paddle::dialect::DistDenseTensorType ... ";
648+
return pir::deserializeTypeFromJsonIncludeParseType<
649+
paddle::dialect::DistDenseTensorType>(type_json, ctx);
650+
} else {
651+
PADDLE_ENFORCE(false,
652+
phi::errors::InvalidArgument(
653+
"Unknown Type %s for parse paddleoperator dialect type",
654+
type_name));
655+
return pir::Type();
656+
}
657+
}
658+
519659
} // namespace pir

paddle/fluid/pir/serialize_deserialize/include/schema.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14+
1415
#pragma once
1516
#include "glog/logging.h"
17+
#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h"
1618
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
1719
#include "paddle/pir/include/core/builtin_dialect.h"
1820
#include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h"
21+
1922
namespace pir {
2023
/**
2124
* IMPORTANT!!!
@@ -57,6 +60,7 @@ namespace pir {
5760
// which is json array with json object(NAME and ATTR_TYPE)
5861
#define ATTRS "A"
5962
#define OPRESULTS_ATTRS "OA"
63+
#define DIST_ATTRS "DA"
6064

6165
// value's key:
6266
// value's type which should be pir::Type's json object(ID or ID and DATA).
@@ -78,9 +82,11 @@ namespace pir {
7882

7983
#define PARAMETEROP "p"
8084

81-
std::pair<std::string, std::string> getContentSplitByDot(
85+
std::pair<std::string, std::string> GetContentSplitByDot(
8286
const std::string& str);
8387

88+
std::vector<std::string> GetOpDistAttr();
89+
8490
void GetCompressOpName(std::string* op_name);
8591

8692
void GetDecompressOpName(std::string* op_name);

0 commit comments

Comments
 (0)