-
Couldn't load subscription status.
- Fork 5.9k
【PTen】Add Scalar and ScalarArray in pten #37409
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a6218c4
4df29cf
81fb5d4
e213512
72aa162
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,69 +18,217 @@ limitations under the License. */ | |
| #include <limits> | ||
|
|
||
| #include "paddle/pten/api/ext/exception.h" | ||
|
|
||
| #include "paddle/pten/api/include/tensor.h" | ||
| namespace paddle { | ||
| namespace experimental { | ||
|
|
||
| class Scalar { | ||
| template <typename T> | ||
| class ScalarBase { | ||
| public: | ||
| // Constructor support implicit | ||
| Scalar(float val) : tag(Tag::HAS_F) { data_.f = val; } // NOLINT | ||
| ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT | ||
| data_.f64 = val; | ||
| } | ||
|
|
||
| ScalarBase(float val) : dtype_(DataType::FLOAT32) { // NOLINT | ||
| data_.f32 = val; | ||
| } | ||
|
|
||
| ScalarBase(float16 val) : dtype_(DataType::FLOAT16) { // NOLINT | ||
| data_.f16 = val; | ||
| } | ||
|
|
||
| Scalar(double val) : tag(Tag::HAS_D) { data_.d = val; } // NOLINT | ||
| ScalarBase(bfloat16 val) : dtype_(DataType::BFLOAT16) { // NOLINT | ||
| data_.bf16 = val; | ||
| } | ||
|
|
||
| Scalar(int32_t val) : tag(Tag::HAS_I32) { data_.i32 = val; } // NOLINT | ||
| ScalarBase(int64_t val) : dtype_(DataType::INT64) { // NOLINT | ||
| data_.i64 = val; | ||
| } | ||
|
|
||
| Scalar(int64_t val) : tag(Tag::HAS_I64) { data_.i64 = val; } // NOLINT | ||
| ScalarBase(int32_t val) : dtype_(DataType::INT32) { // NOLINT | ||
| data_.i32 = val; | ||
| } | ||
|
|
||
| Scalar(bool val) : tag(Tag::HAS_B) { data_.b = val; } // NOLINT | ||
| ScalarBase(int16_t val) : dtype_(DataType::INT16) { // NOLINT | ||
| data_.i16 = val; | ||
| } | ||
|
|
||
| Scalar(const std::string& str_value) : tag(Tag::HAS_D) { // NOLINT | ||
| ScalarBase(int8_t val) : dtype_(DataType::INT8) { // NOLINT | ||
| data_.i8 = val; | ||
| } | ||
|
|
||
| ScalarBase(uint64_t val) : dtype_(DataType::UINT64) { // NOLINT | ||
| data_.ui64 = val; | ||
| } | ||
|
|
||
| ScalarBase(uint32_t val) : dtype_(DataType::UINT32) { // NOLINT | ||
| data_.ui32 = val; | ||
| } | ||
|
|
||
| ScalarBase(uint16_t val) : dtype_(DataType::UINT16) { // NOLINT | ||
| data_.ui16 = val; | ||
| } | ||
|
|
||
| ScalarBase(uint8_t val) : dtype_(DataType::UINT8) { // NOLINT | ||
| data_.ui8 = val; | ||
| } | ||
|
|
||
| ScalarBase(bool val) : dtype_(DataType::BOOL) { // NOLINT | ||
| data_.b = val; | ||
| } | ||
|
|
||
| ScalarBase(complex64 val) : dtype_(DataType::COMPLEX64) { // NOLINT | ||
| data_.c64 = val; | ||
| } | ||
|
|
||
| ScalarBase(complex128 val) : dtype_(DataType::COMPLEX128) { // NOLINT | ||
| data_.c128 = val; | ||
| } | ||
|
|
||
| // The compatible method for fliud operators, | ||
| // and it will be removed in the future. | ||
| explicit ScalarBase(const std::string& str_value) | ||
| : dtype_(DataType::FLOAT64) { | ||
| if (str_value == "inf") { | ||
| data_.d = std::numeric_limits<double>::infinity(); | ||
| data_.f64 = std::numeric_limits<double>::infinity(); | ||
| } else if (str_value == "-inf") { | ||
| data_.d = -std::numeric_limits<double>::infinity(); | ||
| data_.f64 = -std::numeric_limits<double>::infinity(); | ||
| } else if (str_value == "nan") { | ||
| data_.d = std::numeric_limits<double>::quiet_NaN(); | ||
| data_.f64 = std::numeric_limits<double>::quiet_NaN(); | ||
| } else { | ||
| data_.d = std::stod(str_value); | ||
| data_.f64 = std::stod(str_value); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里str_value 会是非法字符串吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的str_value是为了适配原来fill_constant op的参数,目前不会出现非法字符串的情况,完成op的迁移后这个接口应该会去掉,已补充注释 |
||
| } | ||
| } | ||
|
|
||
| // The Tensor must have one dim | ||
| ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT | ||
| PD_CHECK( | ||
| tensor.numel() == 1, | ||
| "The Scalar only supports Tensor with 1 element, but now Tensor has `", | ||
| tensor.numel(), | ||
| "` element."); | ||
| switch (dtype_) { | ||
| case DataType::FLOAT32: | ||
| data_.f32 = tensor.template data<float>()[0]; | ||
| break; | ||
| case DataType::FLOAT64: | ||
| data_.f64 = tensor.template data<double>()[0]; | ||
| break; | ||
| case DataType::FLOAT16: | ||
| data_.f16 = tensor.template data<float16>()[0]; | ||
| break; | ||
| case DataType::BFLOAT16: | ||
| data_.bf16 = tensor.template data<bfloat16>()[0]; | ||
| break; | ||
| case DataType::INT32: | ||
| data_.i32 = tensor.template data<int32_t>()[0]; | ||
| break; | ||
| case DataType::INT64: | ||
| data_.i64 = tensor.template data<int64_t>()[0]; | ||
| break; | ||
| case DataType::INT16: | ||
| data_.i16 = tensor.template data<int16_t>()[0]; | ||
| break; | ||
| case DataType::INT8: | ||
| data_.i8 = tensor.template data<int8_t>()[0]; | ||
| break; | ||
| case DataType::UINT16: | ||
| data_.ui16 = tensor.template data<uint16_t>()[0]; | ||
| break; | ||
| case DataType::UINT8: | ||
| data_.ui8 = tensor.template data<uint8_t>()[0]; | ||
| break; | ||
| case DataType::BOOL: | ||
| data_.b = tensor.template data<bool>()[0]; | ||
| break; | ||
| case DataType::COMPLEX64: | ||
| data_.c64 = tensor.template data<complex64>()[0]; | ||
| break; | ||
| case DataType::COMPLEX128: | ||
| data_.c128 = tensor.template data<complex128>()[0]; | ||
| break; | ||
| default: | ||
| PD_THROW("Invalid tensor data type `", dtype_, "`."); | ||
| } | ||
| } | ||
|
|
||
| template <typename T> | ||
| inline T to() const { | ||
| switch (tag) { | ||
| case Tag::HAS_F: | ||
| return static_cast<T>(data_.f); | ||
| case Tag::HAS_D: | ||
| return static_cast<T>(data_.d); | ||
| case Tag::HAS_I32: | ||
| return static_cast<T>(data_.i32); | ||
| case Tag::HAS_I64: | ||
| return static_cast<T>(data_.i64); | ||
| case Tag::HAS_B: | ||
| return static_cast<T>(data_.b); | ||
| template <typename OtherT> | ||
| ScalarBase(const ScalarBase<OtherT>& other) { | ||
| CopyScalar(other, this); | ||
| } | ||
|
|
||
| template <typename RT> | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 可能有类型安全问题 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [TODO] 在后续方案的优化迭代中进行完善 |
||
| inline RT to() const { | ||
| switch (dtype_) { | ||
| case DataType::FLOAT32: | ||
| return static_cast<RT>(data_.f32); | ||
| case DataType::FLOAT64: | ||
| return static_cast<RT>(data_.f64); | ||
| case DataType::FLOAT16: | ||
| return static_cast<RT>(data_.f16); | ||
| case DataType::BFLOAT16: | ||
| return static_cast<RT>(data_.bf16); | ||
| case DataType::INT32: | ||
| return static_cast<RT>(data_.i32); | ||
| case DataType::INT64: | ||
| return static_cast<RT>(data_.i64); | ||
| case DataType::INT16: | ||
| return static_cast<RT>(data_.i16); | ||
| case DataType::INT8: | ||
| return static_cast<RT>(data_.i8); | ||
| case DataType::UINT16: | ||
| return static_cast<RT>(data_.ui16); | ||
| case DataType::UINT8: | ||
| return static_cast<RT>(data_.ui8); | ||
| case DataType::BOOL: | ||
| return static_cast<RT>(data_.b); | ||
| case DataType::COMPLEX64: | ||
| return static_cast<RT>(data_.c64); | ||
| case DataType::COMPLEX128: | ||
| return static_cast<RT>(data_.c128); | ||
| default: | ||
| PD_THROW("Invalid enum scalar type tag `", static_cast<int>(tag), "`."); | ||
| PD_THROW("Invalid enum scalar data type `", dtype_, "`."); | ||
| } | ||
| } | ||
|
|
||
| private: | ||
| enum class Tag { HAS_F, HAS_D, HAS_I32, HAS_I64, HAS_B }; | ||
| Tag tag; | ||
| template <typename T1, typename T2> | ||
| friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst); | ||
|
|
||
| private: | ||
| DataType dtype_; | ||
| union data { | ||
| float f; | ||
| double d; | ||
| bool b; | ||
| int8_t i8; | ||
| int16_t i16; | ||
| int32_t i32; | ||
| int64_t i64; | ||
| bool b; | ||
| uint8_t ui8; | ||
| uint16_t ui16; | ||
| uint32_t ui32; | ||
| uint64_t ui64; | ||
| bfloat16 bf16; | ||
| float16 f16; | ||
| float f32; | ||
| double f64; | ||
| complex64 c64; | ||
| complex128 c128; | ||
| } data_; | ||
| }; | ||
|
|
||
| template <typename T1, typename T2> | ||
| void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) { | ||
| dst->dtype_ = src.dtype_; | ||
| dst->data_.c128 = src.data_.c128; | ||
| } | ||
|
|
||
| using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>; | ||
|
|
||
| } // namespace experimental | ||
| } // namespace paddle | ||
|
|
||
| namespace pten { | ||
| using Scalar = paddle::experimental::Scalar; | ||
| class DenseTensor; | ||
| using Scalar = paddle::experimental::ScalarBase<DenseTensor>; | ||
| } // namespace pten | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为了与之前的设计统一,推理希望优先使用继承而非模板,如果限于排期可暂时放松设计,后续再进行修改或给出文档说明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
当前设计仅作为解决参数问题的临时方案,后续可在解决问题的前提下对方案进行调整完善或者重新设计