Skip to content

请求CastNumpy2Scalar支持fp16、bf16等常见数据类型 #48574

@ZHUI

Description

@ZHUI

需求描述 Feature Description

// Converts a NumPy scalar, or the first element of a numpy.ndarray, into a
// paddle::experimental::Scalar. Dispatch is done on the Python type name
// (tp_name) of `obj`.
//
// Parameters:
//   obj      Python object to convert.
//   op_type  Operator name; used only to build error messages.
//   arg_pos  Zero-based argument position; reported one-based in errors.
//
// Returns a Scalar built from:
//   - float   for numpy.ndarray (element 0 only) and numpy.float32,
//   - double  for numpy.float64,
//   - int64_t for numpy.int64,
//   - int     for numpy.int32 / numpy.intc.
//
// Raises InvalidArgument (through PADDLE_THROW) when the ndarray's first
// element cannot be read or is not float-convertible, and for every other
// type name (numpy.float16 included).
paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
                                              const std::string& op_type,
                                              ssize_t arg_pos) {
  PyTypeObject* type = obj->ob_type;
  auto type_name = std::string(type->tp_name);
  VLOG(4) << "type_name: " << type_name;
  if (type_name == "numpy.ndarray" && PySequence_Check(obj)) {
    // PySequence_GetItem returns a new reference, or nullptr with a Python
    // error set (for instance when index 0 does not exist).
    PyObject* item = PySequence_GetItem(obj, 0);
    if (item == nullptr) {
      // Drop the pending Python error so that only the exception thrown
      // below is reported.
      PyErr_Clear();
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) is numpy.ndarray, but its first "
          "element could not be read",
          op_type,
          arg_pos + 1));  // NOLINT
    }
    // Record the element's type name up front: the message below is about the
    // inner element, and `item` is released before the throw.
    const std::string item_type_name(item->ob_type->tp_name);
    // NOTE(review): `item` is passed by address, so the helper may replace it
    // with a converted object. The Py_DECREF calls below assume whatever is
    // left in `item` is an owned reference -- confirm against the helper.
    if (PyObject_CheckFloatOrToFloat(&item)) {
      float value = static_cast<float>(PyFloat_AsDouble(item));
      Py_DECREF(item);  // release the reference instead of leaking it
      return paddle::experimental::Scalar(value);
    } else {
      Py_DECREF(item);
      PADDLE_THROW(platform::errors::InvalidArgument(
          "%s(): argument (position %d) is numpy.ndarry, the inner elements "
          "must be "
          "numpy.float32/float64 now, but got %s",
          op_type,
          arg_pos + 1,
          item_type_name));  // NOLINT
    }
  } else if (type_name == "numpy.float64") {
    double value = CastPyArg2Double(obj, op_type, arg_pos);
    return paddle::experimental::Scalar(value);
  } else if (type_name == "numpy.float32") {
    float value = CastPyArg2Float(obj, op_type, arg_pos);
    return paddle::experimental::Scalar(value);
  } else if (type_name == "numpy.int64") {
    int64_t value = CastPyArg2Long(obj, op_type, arg_pos);
    return paddle::experimental::Scalar(value);
  } else if (type_name == "numpy.int32" || type_name == "numpy.intc") {
    int value = CastPyArg2Int(obj, op_type, arg_pos);
    return paddle::experimental::Scalar(value);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument (position %d) must be "
        "numpy.float32/float64, numpy.int32/int64, but got %s",
        op_type,
        arg_pos + 1,
        type_name));  // NOLINT
  }
}

测试样例:

>>> paddle.to_tensor([1], dtype="float16") + numpy.array([1], dtype="float16")[0]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: (InvalidArgument) __add__(): argument (position 1) must be numpy.float32/float64, numpy.int32/int64, but got numpy.float16 (at /paddle/paddle/fluid/pybind/eager_utils.cc:1280)

替代实现 Alternatives

No response

Metadata

Metadata

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions