-
Notifications
You must be signed in to change notification settings - Fork 5.8k
Open
Description
需求描述 Feature Description
Paddle/paddle/fluid/pybind/eager_utils.cc
Lines 1318 to 1359 in bcf7513
paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj, | |
const std::string& op_type, | |
ssize_t arg_pos) { | |
PyTypeObject* type = obj->ob_type; | |
auto type_name = std::string(type->tp_name); | |
VLOG(4) << "type_name: " << type_name; | |
if (type_name == "numpy.ndarray" && PySequence_Check(obj)) { | |
PyObject* item = nullptr; | |
item = PySequence_GetItem(obj, 0); | |
if (PyObject_CheckFloatOrToFloat(&item)) { | |
float value = static_cast<float>(PyFloat_AsDouble(item)); | |
return paddle::experimental::Scalar(value); | |
} else { | |
PADDLE_THROW(platform::errors::InvalidArgument( | |
"%s(): argument (position %d) is numpy.ndarry, the inner elements " | |
"must be " | |
"numpy.float32/float64 now, but got %s", | |
op_type, | |
arg_pos + 1, | |
type_name)); // NOLINT | |
} | |
} else if (type_name == "numpy.float64") { | |
double value = CastPyArg2Double(obj, op_type, arg_pos); | |
return paddle::experimental::Scalar(value); | |
} else if (type_name == "numpy.float32") { | |
float value = CastPyArg2Float(obj, op_type, arg_pos); | |
return paddle::experimental::Scalar(value); | |
} else if (type_name == "numpy.int64") { | |
int64_t value = CastPyArg2Long(obj, op_type, arg_pos); | |
return paddle::experimental::Scalar(value); | |
} else if (type_name == "numpy.int32" || type_name == "numpy.intc") { | |
int value = CastPyArg2Int(obj, op_type, arg_pos); | |
return paddle::experimental::Scalar(value); | |
} else { | |
PADDLE_THROW(platform::errors::InvalidArgument( | |
"%s(): argument (position %d) must be " | |
"numpy.float32/float64, numpy.int32/int64, but got %s", | |
op_type, | |
arg_pos + 1, | |
type_name)); // NOLINT | |
} | |
} |
测试样例:
>>> paddle.to_tensor([1], dtype="float16") + numpy.array([1], dtype="float16")[0]
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: (InvalidArgument) __add__(): argument (position 1) must be numpy.float32/float64, numpy.int32/int64, but got numpy.float16 (at /paddle/paddle/fluid/pybind/eager_utils.cc:1280)
替代实现 Alternatives
No response