Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/fluid/pir/serialize_deserialize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ include_directories(pir_save_load PRIVATE

link_directories(${PADDLE_BINARY_DIR}/third_party/install/yaml-cpp/lib/)
target_compile_definitions(yaml INTERFACE YAML_CPP_STATIC_DEFINE)
if(NOT WIN32 AND NOT APPLE)
if(LINUX)
link_libraries(stdc++fs)
endif()

add_definitions(-DPADDLE_ROOT="${PADDLE_SOURCE_DIR}")
add_definitions(
-DPATCH_PATH="../../../../../python/paddle/pir/serialize_deserialize/patch")

cc_library(
pir_save_load
Expand Down
6 changes: 1 addition & 5 deletions paddle/fluid/pir/serialize_deserialize/src/interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ bool ReadModule(const std::string& file_path,
data.at(BASE_CODE).at(PIRVERSION).template get<uint64_t>();
if (file_version != (uint64_t)pir_version) {
builder.SetFileVersion(file_version);
const char* paddle_root = PADDLE_ROOT;
VLOG(8) << "Paddle path: " << paddle_root;
std::filesystem::path patch_path = std::filesystem::path(paddle_root) /
"paddle" / "fluid" / "pir" /
"serialize_deserialize" / "patch";
std::filesystem::path patch_path = std::filesystem::path(PATCH_PATH);
VLOG(8) << "Patch path: " << patch_path;
builder.BuildPatch(patch_path.string());
}
Expand Down
17 changes: 7 additions & 10 deletions paddle/fluid/pir/serialize_deserialize/src/schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ std::string DialectIdMap::GetDecompressDialectId(const std::string& id) {

uint64_t GetPirVersion() {
VLOG(8) << "Get PIR Version: ";
const char* paddle_root = PADDLE_ROOT;
VLOG(8) << "Paddle path: " << paddle_root;
std::filesystem::path patch_path = std::filesystem::path(paddle_root) /
"paddle" / "fluid" / "pir" /
"serialize_deserialize" / "patch";
// const char* paddle_root = PADDLE_ROOT;
// VLOG(8) << "Paddle path: " << paddle_root;
// std::filesystem::path patch_path = std::filesystem::path(paddle_root) /
// "python" / "paddle" / "pir" /
// "serialize_deserialize" / "patch";
std::filesystem::path patch_path = std::filesystem::path(PATCH_PATH);
VLOG(8) << "Patch path: " << patch_path;
int version = 0;
for (auto& v : std::filesystem::directory_iterator(patch_path)) {
Expand All @@ -108,11 +109,7 @@ uint64_t GetPirVersion() {
return version;
}
uint64_t GetMaxReleasePirVersion() {
const char* paddle_root = PADDLE_ROOT;
VLOG(8) << "Paddle path: " << paddle_root;
std::filesystem::path patch_path = std::filesystem::path(paddle_root) /
"paddle" / "fluid" / "pir" /
"serialize_deserialize" / "patch";
std::filesystem::path patch_path = std::filesystem::path(PATCH_PATH);
VLOG(8) << "Patch path: " << patch_path;
int version = 0;
for (auto& v : std::filesystem::directory_iterator(patch_path)) {
Expand Down
28 changes: 14 additions & 14 deletions test/auto_parallel/hybrid_strategy/semi_auto_llama_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,20 +221,20 @@ def run_dy2static(self, tmp_ckpt_path):
break

# check pir dist_model save&load
# paddle.enable_static()
# model_file_path = os.path.join(
# tmp_ckpt_path,
# "rank_" + str(paddle.distributed.get_rank()) + ".pd_dist_model",
# )
# paddle.save(
# dist_model._engine._pir_dist_main_progs["train"], model_file_path
# )
# loaded_model = paddle.load(model_file_path)
# self.check_program_equal(
# dist_model._engine._pir_dist_main_progs["train"], loaded_model
# )
# paddle.disable_static()
# paddle.distributed.barrier()
paddle.enable_static()
model_file_path = os.path.join(
tmp_ckpt_path,
"rank_" + str(paddle.distributed.get_rank()) + ".pd_dist_model",
)
paddle.save(
dist_model._engine._pir_dist_main_progs["train"], model_file_path
)
loaded_model = paddle.load(model_file_path)
self.check_program_equal(
dist_model._engine._pir_dist_main_progs["train"], loaded_model
)
paddle.disable_static()
paddle.distributed.barrier()

time.sleep(10)

Expand Down
3 changes: 3 additions & 0 deletions test/cpp/pir/serialize_deserialize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ paddle_test(test_builtin_parameter SRCS test_builtin_parameter.cc)
paddle_test(save_load_version_compat_test SRCS save_load_version_compat_test.cc
DEPS test_dialect)

copy_if_different(${CMAKE_CURRENT_SOURCE_DIR}/patch
${CMAKE_CURRENT_BINARY_DIR}/patch)

if(WITH_ONNXRUNTIME AND WIN32)
# Copy onnxruntime for some c++ test in Windows, since the test will
# be build only in CI, so suppose the generator in Windows is Ninja.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,7 @@ TEST(save_load_version_compat, op_patch_test) {
const uint64_t pir_version = 0;
pir::PatchBuilder builder(pir_version);
builder.SetFileVersion(1);
std::string current_path = std::filesystem::current_path().string();
std::string paddle_root = "";
// For coverage CI
if (current_path.find("Paddle") == std::string::npos) {
paddle_root = current_path.substr(0, current_path.find("build") + 5);
} else {
paddle_root = current_path.substr(0, current_path.find("Paddle") + 6);
}
VLOG(8) << "Paddle path: " << paddle_root;
std::filesystem::path patch_path =
std::filesystem::path(paddle_root.c_str()) / "test" / "cpp" / "pir" /
"serialize_deserialize" / "patch";
std::filesystem::path patch_path("/patch");
VLOG(8) << "Patch path: " << patch_path;
builder.BuildPatch(patch_path.string());
}