Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions paddle/fluid/pir/serialize_deserialize/src/save_load_parameters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ void SaveCombineFunction(const std::vector<const phi::DenseTensor*>& x,

MkDirRecursively(DirName(file_path).c_str());
VLOG(6) << "save func save path: " << file_path;
std::ostringstream ss;
std::ofstream fout(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
true,
common::errors::Unavailable(
"Cannot open %s to save variables.", file_path));
PADDLE_ENFORCE_GT(x.size(),
0UL,
common::errors::InvalidArgument(
Expand All @@ -134,18 +138,11 @@ void SaveCombineFunction(const std::vector<const phi::DenseTensor*>& x,
auto out_dtype = save_as_fp16 ? phi::DataType::FLOAT16 : in_dtype;
if (in_dtype != out_dtype) {
auto out = CastTensorType(dev_ctx, tensor, out_dtype);
paddle::framework::SerializeToStream(ss, out, *dev_ctx);
paddle::framework::SerializeToStream(fout, out, *dev_ctx);
} else {
paddle::framework::SerializeToStream(ss, tensor, *dev_ctx);
paddle::framework::SerializeToStream(fout, tensor, *dev_ctx);
}
}
MkDirRecursively(DirName(file_path).c_str());
std::ofstream fout(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
true,
common::errors::Unavailable(
"Cannot open %s to save variables.", file_path));
fout << ss.str();
fout.close();
VLOG(6) << "save combine done ";
}
Expand Down
66 changes: 42 additions & 24 deletions paddle/phi/kernels/impl/save_combine_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,29 +58,10 @@ inline void SaveToMemory(const std::string& file_path,
}

template <typename T, typename Context>
void SaveCombineTensorKernel(const Context& dev_ctx,
const std::vector<const phi::DenseTensor*>& x,
const std::string& file_path,
bool overwrite,
bool save_as_fp16,
bool save_to_memory,
phi::ExtendedTensor* out) {
std::string* y = nullptr;
if (out != nullptr) {
auto raw_out = static_cast<RawTensor*>(out);
y = raw_out->GetMutable<std::string>();
}

bool is_present = FileExists(file_path);
if (is_present && !overwrite) {
PADDLE_THROW(common::errors::PreconditionNotMet(
"%s exists! Cannot save_combine to it when overwrite is set to "
"false.",
file_path,
overwrite));
}

std::ostringstream ss;
void SerializeCombineTensor(const Context& dev_ctx,
const std::vector<const phi::DenseTensor*>& x,
bool save_as_fp16,
std::ostream& ss) {
PADDLE_ENFORCE_GT(x.size(),
0UL,
common::errors::InvalidArgument(
Expand Down Expand Up @@ -114,8 +95,45 @@ void SaveCombineTensorKernel(const Context& dev_ctx,
SerializeToStream(ss, tensor, dev_ctx);
}
}
}

SaveToMemory(file_path, ss, save_to_memory, y);
template <typename T, typename Context>
void SaveCombineTensorKernel(const Context& dev_ctx,
const std::vector<const phi::DenseTensor*>& x,
const std::string& file_path,
bool overwrite,
bool save_as_fp16,
bool save_to_memory,
phi::ExtendedTensor* out) {
std::string* y = nullptr;
if (out != nullptr) {
auto raw_out = static_cast<RawTensor*>(out);
y = raw_out->GetMutable<std::string>();
}

bool is_present = FileExists(file_path);
if (is_present && !overwrite) {
PADDLE_THROW(common::errors::PreconditionNotMet(
"%s exists! Cannot save_combine to it when overwrite is set to "
"false.",
file_path,
overwrite));
}

if (save_to_memory) {
std::ostringstream ss;
SerializeCombineTensor<T>(dev_ctx, x, save_as_fp16, ss);
SaveToMemory(file_path, ss, save_to_memory, y);
} else {
MkDirRecursively(DirName(file_path).c_str());
std::ofstream fout(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fout),
true,
common::errors::Unavailable(
"Cannot open %s to save variables.", file_path));
SerializeCombineTensor<T>(dev_ctx, x, save_as_fp16, fout);
fout.close();
}
}

template <typename T, typename Context>
Expand Down