Skip to content

Commit 28137ec

Browse files
[PIR save/load]Fix save combine memory (#69683)
* fix save combine memory * fix
1 parent 0780ed1 commit 28137ec

File tree

2 files changed

+49
-34
lines changed

2 files changed

+49
-34
lines changed

paddle/fluid/pir/serialize_deserialize/src/save_load_parameters.cc

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@ void SaveCombineFunction(const std::vector<const phi::DenseTensor*>& x,
115115

116116
MkDirRecursively(DirName(file_path).c_str());
117117
VLOG(6) << "save func save path: " << file_path;
118-
std::ostringstream ss;
118+
std::ofstream fout(file_path, std::ios::binary);
119+
PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
120+
true,
121+
common::errors::Unavailable(
122+
"Cannot open %s to save variables.", file_path));
119123
PADDLE_ENFORCE_GT(x.size(),
120124
0UL,
121125
common::errors::InvalidArgument(
@@ -134,18 +138,11 @@ void SaveCombineFunction(const std::vector<const phi::DenseTensor*>& x,
134138
auto out_dtype = save_as_fp16 ? phi::DataType::FLOAT16 : in_dtype;
135139
if (in_dtype != out_dtype) {
136140
auto out = CastTensorType(dev_ctx, tensor, out_dtype);
137-
paddle::framework::SerializeToStream(ss, out, *dev_ctx);
141+
paddle::framework::SerializeToStream(fout, out, *dev_ctx);
138142
} else {
139-
paddle::framework::SerializeToStream(ss, tensor, *dev_ctx);
143+
paddle::framework::SerializeToStream(fout, tensor, *dev_ctx);
140144
}
141145
}
142-
MkDirRecursively(DirName(file_path).c_str());
143-
std::ofstream fout(file_path, std::ios::binary);
144-
PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
145-
true,
146-
common::errors::Unavailable(
147-
"Cannot open %s to save variables.", file_path));
148-
fout << ss.str();
149146
fout.close();
150147
VLOG(6) << "save combine done ";
151148
}

paddle/phi/kernels/impl/save_combine_kernel_impl.h

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -58,29 +58,10 @@ inline void SaveToMemory(const std::string& file_path,
5858
}
5959

6060
template <typename T, typename Context>
61-
void SaveCombineTensorKernel(const Context& dev_ctx,
62-
const std::vector<const phi::DenseTensor*>& x,
63-
const std::string& file_path,
64-
bool overwrite,
65-
bool save_as_fp16,
66-
bool save_to_memory,
67-
phi::ExtendedTensor* out) {
68-
std::string* y = nullptr;
69-
if (out != nullptr) {
70-
auto raw_out = static_cast<RawTensor*>(out);
71-
y = raw_out->GetMutable<std::string>();
72-
}
73-
74-
bool is_present = FileExists(file_path);
75-
if (is_present && !overwrite) {
76-
PADDLE_THROW(common::errors::PreconditionNotMet(
77-
"%s exists! Cannot save_combine to it when overwrite is set to "
78-
"false.",
79-
file_path,
80-
overwrite));
81-
}
82-
83-
std::ostringstream ss;
61+
void SerializeCombineTensor(const Context& dev_ctx,
62+
const std::vector<const phi::DenseTensor*>& x,
63+
bool save_as_fp16,
64+
std::ostream& ss) {
8465
PADDLE_ENFORCE_GT(x.size(),
8566
0UL,
8667
common::errors::InvalidArgument(
@@ -114,8 +95,45 @@ void SaveCombineTensorKernel(const Context& dev_ctx,
11495
SerializeToStream(ss, tensor, dev_ctx);
11596
}
11697
}
98+
}
11799

118-
SaveToMemory(file_path, ss, save_to_memory, y);
100+
template <typename T, typename Context>
101+
void SaveCombineTensorKernel(const Context& dev_ctx,
102+
const std::vector<const phi::DenseTensor*>& x,
103+
const std::string& file_path,
104+
bool overwrite,
105+
bool save_as_fp16,
106+
bool save_to_memory,
107+
phi::ExtendedTensor* out) {
108+
std::string* y = nullptr;
109+
if (out != nullptr) {
110+
auto raw_out = static_cast<RawTensor*>(out);
111+
y = raw_out->GetMutable<std::string>();
112+
}
113+
114+
bool is_present = FileExists(file_path);
115+
if (is_present && !overwrite) {
116+
PADDLE_THROW(common::errors::PreconditionNotMet(
117+
"%s exists! Cannot save_combine to it when overwrite is set to "
118+
"false.",
119+
file_path,
120+
overwrite));
121+
}
122+
123+
if (save_to_memory) {
124+
std::ostringstream ss;
125+
SerializeCombineTensor<T>(dev_ctx, x, save_as_fp16, ss);
126+
SaveToMemory(file_path, ss, save_to_memory, y);
127+
} else {
128+
MkDirRecursively(DirName(file_path).c_str());
129+
std::ofstream fout(file_path, std::ios::binary);
130+
PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
131+
true,
132+
common::errors::Unavailable(
133+
"Cannot open %s to save variables.", file_path));
134+
SerializeCombineTensor<T>(dev_ctx, x, save_as_fp16, fout);
135+
fout.close();
136+
}
119137
}
120138

121139
template <typename T, typename Context>

0 commit comments

Comments
 (0)