Skip to content

Commit 1c89f08

Browse files
[PIR]add patch header template (#68481)
* add patch header template * fix ci
1 parent be933c9 commit 1c89f08

File tree

14 files changed

+178
-141
lines changed

14 files changed

+178
-141
lines changed

paddle/fluid/pir/serialize_deserialize/CMakeLists.txt

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,28 @@ if(LINUX)
1111
link_libraries(stdc++fs)
1212
endif()
1313

14-
add_definitions(-DPADDLE_ROOT="${PADDLE_SOURCE_DIR}")
15-
add_definitions(
16-
-DPATCH_PATH="../../../../../python/paddle/pir/serialize_deserialize/patch")
14+
file(GLOB_RECURSE YAML_PATCH_FILES "*.yaml")
15+
# change pir version when new patches are added
16+
add_definitions(-DDEVELOP_VERSION=1)
17+
add_definitions(-DRELEASE_VERSION=1)
18+
set(TEMPLATE_FILE ${CMAKE_CURRENT_SOURCE_DIR}/patch/template.h.in)
19+
set(PATCH_HEADER ${CMAKE_CURRENT_BINARY_DIR}/patch/patch.h)
20+
21+
configure_file(${TEMPLATE_FILE} ${PATCH_HEADER} @ONLY)
22+
file(WRITE "${PATCH_HEADER}"
23+
"#include <map>\n#include <string>\n\n"
24+
"const std::map<std::string, std::string> yaml_files = {\n")
25+
26+
foreach(PATCH_FILE ${YAML_PATCH_FILES})
27+
get_filename_component(FILENAME "${PATCH_FILE}" NAME_WE)
28+
file(READ ${PATCH_FILE} FILE_CONTENT)
29+
set(CONTENT "R\"(${FILE_CONTENT})\"")
30+
file(APPEND "${PATCH_HEADER}" "{ \"${FILENAME}\", ${CONTENT} },\n")
31+
endforeach()
32+
33+
file(APPEND "${PATCH_HEADER}" "};\n")
1734

1835
cc_library(
1936
pir_save_load
20-
SRCS ${SERIALIZE_DESERIALIZE_CPP_SOURCES}
37+
SRCS ${SERIALIZE_DESERIALIZE_CPP_SOURCES} ${PATCH_HEADER}
2138
DEPS op_dialect phi json yaml)

paddle/fluid/pir/serialize_deserialize/include/patch_util.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ Json ParseAttrPatches(const YAML::Node &root);
3535

3636
Json ParseTypePatches(const YAML::Node &root);
3737

38-
Json YamlParser(const std::string &yaml_file);
38+
/* Yaml file is set to be empty by default. It's only used for testing. */
39+
Json YamlParser(const std::string &version, const std::string &yaml_file = "");
3940

4041
} // namespace pir

paddle/fluid/pir/serialize_deserialize/include/schema.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,17 @@ namespace pir {
8282
#define NULL_TYPE "NULL"
8383

8484
// special op compress
85-
8685
#define PARAMETEROP "p"
8786

87+
// actions for patch
88+
#define DELETE "DEL"
89+
#define ADD "ADD"
90+
#define UPDATE "UPD"
91+
#define NEW_NAME "NN"
92+
#define ADD_ATTRS "ADD_A"
93+
#define ADD_OPRESULTS_ATTRS "ADD_OA"
94+
#define PATCH "patch"
95+
8896
std::pair<std::string, std::string> GetContentSplitByDot(
8997
const std::string& str);
9098

@@ -109,6 +117,4 @@ class DialectIdMap {
109117
std::unordered_map<std::string, std::string> DecompressDialect;
110118
};
111119

112-
uint64_t GetPirVersion();
113-
uint64_t GetMaxReleasePirVersion();
114120
} // namespace pir

paddle/fluid/pir/serialize_deserialize/include/version_compat.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ class PatchBuilder {
3434
PatchBuilder& operator=(const PatchBuilder&) = delete;
3535
PatchBuilder& operator=(PatchBuilder&&);
3636

37-
void IR_API BuildPatch(const std::string& path,
38-
uint64_t pir_version,
39-
uint64_t max_version);
37+
/* Patch patch is set to empty by default. It is only used for testing.
38+
*/
39+
void IR_API BuildPatch(uint64_t pir_version,
40+
uint64_t max_version,
41+
const std::string& path = "");
4042
/* If file_version != pir_vefrsion, set file_version for finding patch yamls.
4143
*/
4244
void SetFileVersion(const uint64_t version) { file_version_ = version; }

python/paddle/pir/serialize_deserialize/patch/Readme.md renamed to paddle/fluid/pir/serialize_deserialize/patch/Readme.md

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,11 +140,24 @@ type_patches:
140140
```
141141
142142
## pir_version 配置说明
143-
### C++端版本号管理
144-
- 版本号管理在C++端,通过宏PIR_VERSION进行管理。
145-
- pir_version 定义PIR的版本迭代,每次PIR进行更新并新增patch文件后,版本号会顺序递增。与Paddle的主版本号解耦,可以独立迭代。
146-
- 定义GetPirVersion函数获取当前的版本号:在"paddle/fluid/pir/serialize_deserialize/patch"路径下进行yaml文件查询,如果存在"0.yaml"则为develop版本,pir_verison为0;否则查找到的yaml文件名最大值即为当前的pir_version。
147-
- ReadModule和WriteModule参数中的pir_version设为默认值,可以不用传递。pir_version 函数默认值为0,为develop版本下的值,进入函数后会获取当前的版本号。
143+
### C++端版本号管理与CMake配置
144+
- 版本号管理在C++端,在CMakeList.txt中配置。
145+
- PIR版本号定义PIR的版本迭代,版本号与yaml文件名强相关。每次PIR进行更新并新增patch文件后,patch文件名顺序递增,版本号同时顺序递增。与Paddle的主版本号解耦,可以独立迭代。
146+
```cmake
147+
# change pir version when new patches are added
148+
add_definitions(-DDEVELOP_VERSION=1)
149+
add_definitions(-DRELEASE_VERSION=1)
150+
```
151+
152+
```tree
153+
├─patch
154+
│ ├─0.yaml
155+
│ └─1.yaml
156+
```
157+
- RELEASE_VERSION 为已发布的版本中PIR版本号,即为patch yaml文件名的最大值。
158+
- DEVELOP_VERSION 为当前develop分支下的PIR版本号,若存在未发布的新增patch,配置在`0.yaml`中,且当前的develop pir 版本号为0。
159+
160+
- ReadModule和WriteModule参数中的pir_version设为默认值,可以不用传递。pir_version 函数默认值为-1,进入函数后会获取CMake中配置的当前的PIR版本号。
148161

149162
### Python端
150163
- Paddle的主版本号定义在Python端,与PIR version不产生关联。Python端不再需要获取和传入pir_version,直接使用默认值即可。
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
#include <map>
3+
#include <string>
4+
5+
const std::map<std::string, std::string> yaml_files = {
6+
@FILE_CONTENTS@
7+
};

paddle/fluid/pir/serialize_deserialize/src/interface.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ bool ReadModule(const std::string& file_path,
7272
std::ifstream f(file_path);
7373
Json data = Json::parse(f);
7474
if (pir_version < 0) {
75-
pir_version = GetPirVersion();
75+
pir_version = DEVELOP_VERSION;
7676
VLOG(6) << "pir_version is null, get pir_version: " << pir_version;
7777
}
7878

@@ -85,17 +85,15 @@ bool ReadModule(const std::string& file_path,
8585
if (file_version != (uint64_t)pir_version) {
8686
builder.SetFileVersion(file_version);
8787
// Set max_version to the max version number of release pir plus 1.
88-
auto max_version = GetMaxReleasePirVersion() + 1;
88+
auto max_version = RELEASE_VERSION + 1;
8989
// If pir_version_ is not 0, we will build patch from file_version_ to
9090
// pir_version_; If pir_version_ is 0, we will first build patch from
9191
// file_version_ to max_version, and then add 0.yaml to the end.
9292
auto version = pir_version == 0 ? max_version : pir_version;
9393
VLOG(6) << "file_version: " << file_version
9494
<< ", pir_version: " << pir_version
9595
<< ", final_version: " << version;
96-
std::filesystem::path patch_path = std::filesystem::path(PATCH_PATH);
97-
VLOG(8) << "Patch path: " << patch_path;
98-
builder.BuildPatch(patch_path.string(), version, max_version);
96+
builder.BuildPatch(version, max_version);
9997
}
10098
} else {
10199
PADDLE_THROW(common::errors::InvalidArgument("Invalid model file."));

paddle/fluid/pir/serialize_deserialize/src/ir_deserialize.cc

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,8 @@ pir::Operation* ProgramReader::ReadOp(Json* op_json) {
240240
VLOG(8) << op_name << " has been patched: " << *op_json;
241241
// Apply patch to op name
242242
// This happens when changing an op into another dialect
243-
if (op_patch.contains("NEW_NAME")) {
244-
std::string new_name =
245-
op_patch.at("NEW_NAME").template get<std::string>();
243+
if (op_patch.contains(NEW_NAME)) {
244+
std::string new_name = op_patch.at(NEW_NAME).template get<std::string>();
246245
VLOG(8) << "change op name from " << op_name << " to " << new_name;
247246
op_name = new_name;
248247
op_json->at(ID) = op_name;
@@ -334,8 +333,8 @@ pir::AttributeMap ProgramReader::ReadAttributesMap(
334333
const std::unordered_map<std::string, Json>& attr_patch) {
335334
pir::AttributeMap attributes;
336335
// Add new attribute from patch
337-
if (attr_patch.count("A_ADD")) {
338-
for (auto& attr_json : attr_patch.at("A_ADD")) {
336+
if (attr_patch.count(ADD_ATTRS)) {
337+
for (auto& attr_json : attr_patch.at(ADD_ATTRS)) {
339338
attrs_json->insert(attrs_json->end(), attr_json);
340339
}
341340
VLOG(8) << "attr has been added: " << *attrs_json;
@@ -358,8 +357,8 @@ pir::AttributeMap ProgramReader::ReadAttributesMap(
358357
}
359358
VLOG(6) << "Finish Read pir::AttributeMap.";
360359
// Add new opresult attribute from patch
361-
if (attr_patch.count("OA_ADD")) {
362-
for (auto& attr_json : attr_patch.at("OA_ADD")) {
360+
if (attr_patch.count(ADD_OPRESULTS_ATTRS)) {
361+
for (auto& attr_json : attr_patch.at(ADD_OPRESULTS_ATTRS)) {
363362
opresult_attrs_json->insert(opresult_attrs_json->end(), attr_json);
364363
}
365364
VLOG(8) << "opresult attr has been added: " << *opresult_attrs_json;

paddle/fluid/pir/serialize_deserialize/src/patch_util.cc

Lines changed: 41 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <vector>
2121
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
2222
#include "paddle/fluid/pir/serialize_deserialize/include/schema.h"
23+
#include "paddle/fluid/pir/serialize_deserialize/patch/patch.h"
2324
#include "paddle/phi/common/data_type.h"
2425
#include "paddle/pir/include/core/builtin_attribute.h"
2526
#include "paddle/pir/include/core/builtin_type.h"
@@ -262,8 +263,8 @@ Json ParseOpPairPatches(const YAML::Node &root) {
262263
VLOG(8) << "Op_pair_name: " << name;
263264
j_patch["op_pair"].push_back(op_name);
264265
}
265-
j_patch["patch"] = Json::object();
266-
j_patch["patch"]["op_pair"] = j_patch["op_pair"];
266+
j_patch[PATCH] = Json::object();
267+
j_patch[PATCH]["op_pair"] = j_patch["op_pair"];
267268
// parse actions
268269
auto actions = node["actions"];
269270
for (size_t j = 0; j < actions.size(); j++) {
@@ -278,20 +279,20 @@ Json ParseOpPairPatches(const YAML::Node &root) {
278279
Json j_add_out;
279280
j_add_out[ID] = out_id;
280281
j_add_out[TYPE_TYPE] = BuildTypeJsonPatch(action);
281-
j_patch["patch"][OPRESULTS]["ADD"].push_back(j_add_out);
282+
j_patch[PATCH][OPRESULTS][ADD].push_back(j_add_out);
282283
Json j_add_in;
283284
j_add_in[ID] = in_id;
284-
j_patch["patch"][OPOPERANDS]["ADD"].push_back(j_add_in);
285+
j_patch[PATCH][OPOPERANDS][ADD].push_back(j_add_in);
285286
} else if (action_name == "delete_value") {
286287
VLOG(8) << "Patch for deleting values.";
287288
int out_id = action["object"][0].as<int>();
288289
int in_id = action["object"][1].as<int>();
289290
Json j_del_out;
290291
j_del_out[ID] = out_id;
291-
j_patch["patch"][OPRESULTS]["DELETE"].push_back(j_del_out);
292+
j_patch[PATCH][OPRESULTS][DELETE].push_back(j_del_out);
292293
Json j_del_in;
293294
j_del_in[ID] = in_id;
294-
j_patch["patch"][OPOPERANDS]["DELETE"].push_back(j_del_in);
295+
j_patch[PATCH][OPOPERANDS][DELETE].push_back(j_del_in);
295296
}
296297
}
297298
json_patch.push_back(j_patch);
@@ -314,7 +315,7 @@ Json ParseOpPatches(const YAML::Node &root) {
314315
VLOG(8) << "Parse patches for " << op_name;
315316
Json j_patch;
316317
j_patch["op_name"] = op_name;
317-
j_patch["patch"] = Json::object();
318+
j_patch[PATCH] = Json::object();
318319
// parse actions
319320
auto actions = node["actions"];
320321

@@ -335,10 +336,10 @@ Json ParseOpPatches(const YAML::Node &root) {
335336
j_attr[ATTR_TYPE] = BuildAttrJsonPatch(action);
336337
if (action_name == "add_attr") {
337338
Json j_add = Json::object();
338-
j_add["ADD"] = j_attr;
339-
j_patch["patch"][ATTRS].push_back(j_add);
339+
j_add[ADD] = j_attr;
340+
j_patch[PATCH][ATTRS].push_back(j_add);
340341
} else {
341-
j_patch["patch"][ATTRS].push_back(j_attr);
342+
j_patch[PATCH][ATTRS].push_back(j_attr);
342343
}
343344
} else if (action_name == "add_output_attr" ||
344345
action_name == "modify_output_attr" ||
@@ -350,10 +351,10 @@ Json ParseOpPatches(const YAML::Node &root) {
350351
j_attr[ATTR_TYPE] = BuildAttrJsonPatch(action);
351352
if (action_name == "add_output_attr") {
352353
Json j_add = Json::object();
353-
j_add["ADD"] = j_attr;
354-
j_patch["patch"][OPRESULTS_ATTRS].push_back(j_add);
354+
j_add[ADD] = j_attr;
355+
j_patch[PATCH][OPRESULTS_ATTRS].push_back(j_add);
355356
} else {
356-
j_patch["patch"][OPRESULTS_ATTRS].push_back(j_attr);
357+
j_patch[PATCH][OPRESULTS_ATTRS].push_back(j_attr);
357358
}
358359
} else if (action_name == "modify_attr_name" ||
359360
action_name == "modify_output_attr_name") {
@@ -362,35 +363,35 @@ Json ParseOpPatches(const YAML::Node &root) {
362363
std::string new_name = action["default"].as<std::string>();
363364
Json j_attr;
364365
j_attr[NAME] = old_name;
365-
j_attr["NEW_NAME"] = new_name;
366+
j_attr[NEW_NAME] = new_name;
366367
std::string col =
367368
action_name == "modify_attr_name" ? ATTRS : OPRESULTS_ATTRS;
368-
j_patch["patch"][col].push_back(j_attr);
369+
j_patch[PATCH][col].push_back(j_attr);
369370
} else if (action_name == "delete_input") {
370371
VLOG(8) << "Patch for delete_input";
371372
Json j_input;
372373
int op_id = action["object"].as<int>();
373374
j_input[ID] = op_id;
374-
j_patch["patch"][OPOPERANDS]["DELETE"].push_back(j_input);
375+
j_patch[PATCH][OPOPERANDS][DELETE].push_back(j_input);
375376
} else if (action_name == "add_output") {
376377
VLOG(8) << "Patch for add_output";
377378
Json j_output;
378379
int op_id = action["object"].as<int>();
379380
j_output[ID] = op_id;
380381
j_output[TYPE_TYPE] = BuildTypeJsonPatch(action);
381-
j_patch["patch"][OPRESULTS]["ADD"].push_back(j_output);
382+
j_patch[PATCH][OPRESULTS][ADD].push_back(j_output);
382383
} else if (action_name == "modify_output_type") {
383384
VLOG(8) << "Patch for modify_output_type";
384385
int op_id = action["object"].as<int>();
385386
Json j_type;
386387
j_type[ID] = op_id;
387388
j_type[TYPE_TYPE] = BuildTypeJsonPatch(action);
388-
j_patch["patch"][OPRESULTS]["UPDATE"].push_back(j_type);
389+
j_patch[PATCH][OPRESULTS][UPDATE].push_back(j_type);
389390
} else if (action_name == "modify_name") {
390391
VLOG(8) << "Patch for modify_name";
391392
std::string op_name = action["default"].as<std::string>();
392393
GetCompressOpName(&op_name);
393-
j_patch["patch"]["NEW_NAME"] = op_name;
394+
j_patch[PATCH][NEW_NAME] = op_name;
394395
}
395396
}
396397
json_patch.push_back(j_patch);
@@ -410,15 +411,15 @@ Json ParseTypePatches(const YAML::Node &root) {
410411
VLOG(8) << "Type name after compressing: " << type_name;
411412
Json j_patch;
412413
j_patch["type_name"] = type_name;
413-
j_patch["patch"] = Json::object();
414+
j_patch[PATCH] = Json::object();
414415
auto actions = node["actions"];
415416
for (size_t j = 0; j < actions.size(); j++) {
416417
YAML::Node action = actions[j];
417418
std::string action_name = action["action"].as<std::string>();
418419
if (action_name == "modify_name") {
419-
j_patch["patch"]["NEW_NAME"] = GetTypeName(action);
420+
j_patch[PATCH][NEW_NAME] = GetTypeName(action);
420421
} else if (action_name == "delete_type") {
421-
j_patch["patch"]["NEW_NAME"] = "";
422+
j_patch[PATCH][NEW_NAME] = "";
422423
}
423424
}
424425
json_patch.push_back(j_patch);
@@ -439,31 +440,36 @@ Json ParseAttrPatches(const YAML::Node &root) {
439440
VLOG(8) << attr_name;
440441
Json j_patch;
441442
j_patch["attr_name"] = attr_name;
442-
j_patch["patch"] = Json::object();
443+
j_patch[PATCH] = Json::object();
443444
auto actions = node["actions"];
444445
for (size_t j = 0; j < actions.size(); j++) {
445446
YAML::Node action = actions[j];
446447
std::string action_name = action["action"].as<std::string>();
447448
if (action_name == "modify_name") {
448-
j_patch["patch"]["NEW_NAME"] = GetAttrName(action);
449+
j_patch[PATCH][NEW_NAME] = GetAttrName(action);
449450
} else if (action_name == "delete_attr") {
450-
j_patch["patch"]["NEW_NAME"] = "";
451+
j_patch[PATCH][NEW_NAME] = "";
451452
}
452453
}
453454
json_patch.push_back(j_patch);
454455
}
455456
return json_patch;
456457
}
457458

458-
Json YamlParser(const std::string &yaml_file) {
459+
Json YamlParser(const std::string &version, const std::string &yaml_file) {
460+
YAML::Node root;
459461
std::ifstream fin;
460-
VLOG(8) << yaml_file;
461-
fin.open(yaml_file);
462-
if (!fin) {
463-
VLOG(8) << yaml_file << " is not fin and return empty. ";
464-
return Json::object();
462+
if (yaml_file.empty()) {
463+
root = YAML::Load(yaml_files.at(version));
464+
} else {
465+
VLOG(8) << yaml_file;
466+
fin.open(yaml_file);
467+
if (!fin) {
468+
VLOG(8) << yaml_file << " is not fin and return empty. ";
469+
return Json::object();
470+
}
471+
root = YAML::Load(fin);
465472
}
466-
YAML::Node root = YAML::Load(fin);
467473
Json json_patch;
468474
if (!root.IsDefined()) {
469475
VLOG(8) << "Not defined";
@@ -484,7 +490,9 @@ Json YamlParser(const std::string &yaml_file) {
484490
Yaml attr_patch = root["attr_patches"];
485491
json_patch["attr_patches"] = ParseAttrPatches(attr_patch);
486492
VLOG(8) << "Finish attr json_patch: " << json_patch;
487-
fin.close();
493+
if (fin) {
494+
fin.close();
495+
}
488496
return json_patch;
489497
}
490498
} // namespace pir

0 commit comments

Comments
 (0)