|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | #include "paddle/fluid/pybind/pir.h" |
| 16 | +#include "paddle/fluid/pybind/pir_utils.h" |
16 | 17 |
|
17 | 18 | #include <Python.h> |
18 | 19 | #include <algorithm> |
@@ -207,35 +208,6 @@ std::string GetValueInfo(Value v) { |
207 | 208 | return ss.str(); |
208 | 209 | } |
209 | 210 |
|
210 | | -phi::DataType GetTensorDtype(Type type) { |
211 | | - if (!type) { |
212 | | - PADDLE_THROW( |
213 | | - common::errors::InvalidArgument("The type of value is nullptr.")); |
214 | | - } |
215 | | - if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) { |
216 | | - return dialect::TransToPhiDataType(dense_tensor_type.dtype()); |
217 | | - } else if (auto sparse_coo_tensor_type = |
218 | | - type.dyn_cast<SparseCooTensorType>()) { |
219 | | - return dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype()); |
220 | | - } else if (auto sparse_csr_tensor_type = |
221 | | - type.dyn_cast<SparseCsrTensorType>()) { |
222 | | - return dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype()); |
223 | | - } else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) { |
224 | | - return dialect::TransToPhiDataType(select_rows.dtype()); |
225 | | - } else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) { |
226 | | - return dialect::TransToPhiDataType(dense_array.dtype()); |
227 | | - } else { |
228 | | - PADDLE_THROW(common::errors::InvalidArgument( |
229 | | - "Currently, we can only get phi::DataType from DenseTensorType and " |
230 | | - "SelectedRowsType, DenseTensorArrayType,SparseCooTensorType or " |
231 | | - "SparseCsrTensorType.")); |
232 | | - } |
233 | | -} |
234 | | - |
235 | | -phi::DataType GetValueDtype(Value value) { |
236 | | - return GetTensorDtype(value.type()); |
237 | | -} |
238 | | - |
239 | 211 | py::object Clone(const Program &self, IrMapping *p_mapper = nullptr) { |
240 | 212 | IrMapping mapper; |
241 | 213 | if (p_mapper == nullptr) { |
@@ -271,7 +243,7 @@ pir::Value AppendDataOp(pir::Block *block, |
271 | 243 | paddle::dialect::IntArrayAttribute::get( |
272 | 244 | ctx, phi::IntArray(phi::vectorize(GetValueDims(value))))}, |
273 | 245 | {"dtype", |
274 | | - paddle::dialect::DataTypeAttribute::get(ctx, GetValueDtype(value))}, |
| 246 | + paddle::dialect::DataTypeAttribute::get(ctx, pir::GetValueDtype(value))}, |
275 | 247 | {"place", PlaceAttribute::get(ctx, phi::Place())}}; |
276 | 248 | std::vector<pir::Type> output_types{value.type()}; |
277 | 249 | pir::Operation *operation = |
@@ -1369,7 +1341,7 @@ void BindValue(py::module *m) { |
1369 | 1341 | }) |
1370 | 1342 | .def_property( |
1371 | 1343 | "dtype", |
1372 | | - [](Value self) { return GetValueDtype(self); }, |
| 1344 | + [](Value self) { return pir::GetValueDtype(self); }, |
1373 | 1345 | [](Value self, phi::DataType dtype) { |
1374 | 1346 | PADDLE_THROW(common::errors::InvalidArgument( |
1375 | 1347 | "can't set dtype when building static graph")); |
@@ -2241,7 +2213,7 @@ static void inline CreateVariableIfNotExist( |
2241 | 2213 | phi::DeviceContextPool &pool = phi::DeviceContextPool::Instance(); |
2242 | 2214 | const phi::DeviceContext *dev_ctx = nullptr; |
2243 | 2215 | dev_ctx = pool.Get(exe->GetPlace()); |
2244 | | - dev_ctx->Alloc(tensor_temp, GetValueDtype(value)); |
| 2216 | + dev_ctx->Alloc(tensor_temp, pir::GetValueDtype(value)); |
2245 | 2217 | } |
2246 | 2218 | } |
2247 | 2219 | return; |
|
0 commit comments