|
13 | 13 | // limitations under the License. |
14 | 14 |
|
15 | 15 | #include "paddle/fluid/pybind/pir.h" |
| 16 | +#include "paddle/fluid/pybind/pir_utils.h" |
16 | 17 |
|
17 | 18 | #include <Python.h> |
18 | 19 | #include <algorithm> |
@@ -210,35 +211,6 @@ std::string GetValueInfo(Value v) { |
210 | 211 | return ss.str(); |
211 | 212 | } |
212 | 213 |
|
213 | | -phi::DataType GetTensorDtype(Type type) { |
214 | | - if (!type) { |
215 | | - PADDLE_THROW( |
216 | | - common::errors::InvalidArgument("The type of value is nullptr.")); |
217 | | - } |
218 | | - if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) { |
219 | | - return dialect::TransToPhiDataType(dense_tensor_type.dtype()); |
220 | | - } else if (auto sparse_coo_tensor_type = |
221 | | - type.dyn_cast<SparseCooTensorType>()) { |
222 | | - return dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype()); |
223 | | - } else if (auto sparse_csr_tensor_type = |
224 | | - type.dyn_cast<SparseCsrTensorType>()) { |
225 | | - return dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype()); |
226 | | - } else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) { |
227 | | - return dialect::TransToPhiDataType(select_rows.dtype()); |
228 | | - } else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) { |
229 | | - return dialect::TransToPhiDataType(dense_array.dtype()); |
230 | | - } else { |
231 | | - PADDLE_THROW(common::errors::InvalidArgument( |
232 | | - "Currently, we can only get phi::DataType from DenseTensorType and " |
233 | | - "SelectedRowsType, DenseTensorArrayType,SparseCooTensorType or " |
234 | | - "SparseCsrTensorType.")); |
235 | | - } |
236 | | -} |
237 | | - |
238 | | -phi::DataType GetValueDtype(Value value) { |
239 | | - return GetTensorDtype(value.type()); |
240 | | -} |
241 | | - |
242 | 214 | py::object Clone(const Program &self, IrMapping *p_mapper = nullptr) { |
243 | 215 | IrMapping mapper; |
244 | 216 | if (p_mapper == nullptr) { |
@@ -274,7 +246,7 @@ pir::Value AppendDataOp(pir::Block *block, |
274 | 246 | paddle::dialect::IntArrayAttribute::get( |
275 | 247 | ctx, phi::IntArray(phi::vectorize(GetValueDims(value))))}, |
276 | 248 | {"dtype", |
277 | | - paddle::dialect::DataTypeAttribute::get(ctx, GetValueDtype(value))}, |
| 249 | + paddle::dialect::DataTypeAttribute::get(ctx, pir::GetValueDtype(value))}, |
278 | 250 | {"place", PlaceAttribute::get(ctx, phi::Place())}}; |
279 | 251 | std::vector<pir::Type> output_types{value.type()}; |
280 | 252 | pir::Operation *operation = |
@@ -1427,7 +1399,7 @@ void BindValue(py::module *m) { |
1427 | 1399 | }) |
1428 | 1400 | .def_property( |
1429 | 1401 | "dtype", |
1430 | | - [](Value self) { return GetValueDtype(self); }, |
| 1402 | + [](Value self) { return pir::GetValueDtype(self); }, |
1431 | 1403 | [](Value self, phi::DataType dtype) { |
1432 | 1404 | PADDLE_THROW(common::errors::InvalidArgument( |
1433 | 1405 | "can't set dtype when building static graph")); |
@@ -2299,7 +2271,7 @@ static void inline CreateVariableIfNotExist( |
2299 | 2271 | phi::DeviceContextPool &pool = phi::DeviceContextPool::Instance(); |
2300 | 2272 | const phi::DeviceContext *dev_ctx = nullptr; |
2301 | 2273 | dev_ctx = pool.Get(exe->GetPlace()); |
2302 | | - dev_ctx->Alloc(tensor_temp, GetValueDtype(value)); |
| 2274 | + dev_ctx->Alloc(tensor_temp, pir::GetValueDtype(value)); |
2303 | 2275 | } |
2304 | 2276 | } |
2305 | 2277 | return; |
|
0 commit comments