@@ -58,29 +58,10 @@ inline void SaveToMemory(const std::string& file_path,
5858}
5959
6060template <typename T, typename Context>
61- void SaveCombineTensorKernel (const Context& dev_ctx,
62- const std::vector<const phi::DenseTensor*>& x,
63- const std::string& file_path,
64- bool overwrite,
65- bool save_as_fp16,
66- bool save_to_memory,
67- phi::ExtendedTensor* out) {
68- std::string* y = nullptr ;
69- if (out != nullptr ) {
70- auto raw_out = static_cast <RawTensor*>(out);
71- y = raw_out->GetMutable <std::string>();
72- }
73-
74- bool is_present = FileExists (file_path);
75- if (is_present && !overwrite) {
76- PADDLE_THROW (common::errors::PreconditionNotMet (
77- " %s exists! Cannot save_combine to it when overwrite is set to "
78- " false." ,
79- file_path,
80- overwrite));
81- }
82-
83- std::ostringstream ss;
61+ void SerializeCombineTensor (const Context& dev_ctx,
62+ const std::vector<const phi::DenseTensor*>& x,
63+ bool save_as_fp16,
64+ std::ostream& ss) {
8465 PADDLE_ENFORCE_GT (x.size (),
8566 0UL ,
8667 common::errors::InvalidArgument (
@@ -114,8 +95,45 @@ void SaveCombineTensorKernel(const Context& dev_ctx,
11495 SerializeToStream (ss, tensor, dev_ctx);
11596 }
11697 }
98+ }
11799
118- SaveToMemory (file_path, ss, save_to_memory, y);
100+ template <typename T, typename Context>
101+ void SaveCombineTensorKernel (const Context& dev_ctx,
102+ const std::vector<const phi::DenseTensor*>& x,
103+ const std::string& file_path,
104+ bool overwrite,
105+ bool save_as_fp16,
106+ bool save_to_memory,
107+ phi::ExtendedTensor* out) {
108+ std::string* y = nullptr ;
109+ if (out != nullptr ) {
110+ auto raw_out = static_cast <RawTensor*>(out);
111+ y = raw_out->GetMutable <std::string>();
112+ }
113+
114+ bool is_present = FileExists (file_path);
115+ if (is_present && !overwrite) {
116+ PADDLE_THROW (common::errors::PreconditionNotMet (
117+ " %s exists! Cannot save_combine to it when overwrite is set to "
118+ " false." ,
119+ file_path,
120+ overwrite));
121+ }
122+
123+ if (save_to_memory) {
124+ std::ostringstream ss;
125+ SerializeCombineTensor<T>(dev_ctx, x, save_as_fp16, ss);
126+ SaveToMemory (file_path, ss, save_to_memory, y);
127+ } else {
128+ MkDirRecursively (DirName (file_path).c_str ());
129+ std::ofstream fout (file_path, std::ios::binary);
130+ PADDLE_ENFORCE_EQ (static_cast <bool >(fout),
131+ true ,
132+ common::errors::Unavailable (
133+ " Cannot open %s to save variables." , file_path));
134+ SerializeCombineTensor<T>(dev_ctx, x, save_as_fp16, fout);
135+ fout.close ();
136+ }
119137}
120138
121139template <typename T, typename Context>
0 commit comments