@@ -18,69 +18,217 @@ limitations under the License. */
1818#include  < limits> 
1919
2020#include  " paddle/pten/api/ext/exception.h" 
21- 
21+ # include   " paddle/pten/api/include/tensor.h " 
2222namespace  paddle  {
2323namespace  experimental  {
2424
25- class  Scalar  {
25+ template  <typename  T>
26+ class  ScalarBase  {
2627 public: 
2728  //  Constructor support implicit
28-   Scalar (float  val) : tag(Tag::HAS_F) { data_.f  = val; }  //  NOLINT
29+   ScalarBase (double  val) : dtype_(DataType::FLOAT64) {  //  NOLINT
30+     data_.f64  = val;
31+   }
32+ 
33+   ScalarBase (float  val) : dtype_(DataType::FLOAT32) {  //  NOLINT
34+     data_.f32  = val;
35+   }
36+ 
37+   ScalarBase (float16 val) : dtype_(DataType::FLOAT16) {  //  NOLINT
38+     data_.f16  = val;
39+   }
2940
30-   Scalar (double  val) : tag(Tag::HAS_D) { data_.d  = val; }  //  NOLINT
41+   ScalarBase (bfloat16 val) : dtype_(DataType::BFLOAT16) {  //  NOLINT
42+     data_.bf16  = val;
43+   }
3144
32-   Scalar (int32_t  val) : tag(Tag::HAS_I32) { data_.i32  = val; }  //  NOLINT
45+   ScalarBase (int64_t  val) : dtype_(DataType::INT64) {  //  NOLINT
46+     data_.i64  = val;
47+   }
3348
34-   Scalar (int64_t  val) : tag(Tag::HAS_I64) { data_.i64  = val; }  //  NOLINT
49+   ScalarBase (int32_t  val) : dtype_(DataType::INT32) {  //  NOLINT
50+     data_.i32  = val;
51+   }
3552
36-   Scalar (bool  val) : tag(Tag::HAS_B) { data_.b  = val; }  //  NOLINT
53+   ScalarBase (int16_t  val) : dtype_(DataType::INT16) {  //  NOLINT
54+     data_.i16  = val;
55+   }
3756
38-   Scalar (const  std::string& str_value) : tag(Tag::HAS_D) {  //  NOLINT
57+   ScalarBase (int8_t  val) : dtype_(DataType::INT8) {  //  NOLINT
58+     data_.i8  = val;
59+   }
60+ 
61+   ScalarBase (uint64_t  val) : dtype_(DataType::UINT64) {  //  NOLINT
62+     data_.ui64  = val;
63+   }
64+ 
65+   ScalarBase (uint32_t  val) : dtype_(DataType::UINT32) {  //  NOLINT
66+     data_.ui32  = val;
67+   }
68+ 
69+   ScalarBase (uint16_t  val) : dtype_(DataType::UINT16) {  //  NOLINT
70+     data_.ui16  = val;
71+   }
72+ 
73+   ScalarBase (uint8_t  val) : dtype_(DataType::UINT8) {  //  NOLINT
74+     data_.ui8  = val;
75+   }
76+ 
77+   ScalarBase (bool  val) : dtype_(DataType::BOOL) {  //  NOLINT
78+     data_.b  = val;
79+   }
80+ 
81+   ScalarBase (complex64 val) : dtype_(DataType::COMPLEX64) {  //  NOLINT
82+     data_.c64  = val;
83+   }
84+ 
85+   ScalarBase (complex128 val) : dtype_(DataType::COMPLEX128) {  //  NOLINT
86+     data_.c128  = val;
87+   }
88+ 
89+   //  The compatible method for fliud operators,
90+   //  and it will be removed in the future.
91+   explicit  ScalarBase (const  std::string& str_value)
92+       : dtype_(DataType::FLOAT64) {
3993    if  (str_value == " inf" 
40-       data_.d  = std::numeric_limits<double >::infinity ();
94+       data_.f64  = std::numeric_limits<double >::infinity ();
4195    } else  if  (str_value == " -inf" 
42-       data_.d  = -std::numeric_limits<double >::infinity ();
96+       data_.f64  = -std::numeric_limits<double >::infinity ();
4397    } else  if  (str_value == " nan" 
44-       data_.d  = std::numeric_limits<double >::quiet_NaN ();
98+       data_.f64  = std::numeric_limits<double >::quiet_NaN ();
4599    } else  {
46-       data_.d  = std::stod (str_value);
100+       data_.f64  = std::stod (str_value);
101+     }
102+   }
103+ 
104+   //  The Tensor must have one dim
105+   ScalarBase (const  T& tensor) : dtype_(tensor.dtype()) {  //  NOLINT
106+     PD_CHECK (
107+         tensor.numel () == 1 ,
108+         " The Scalar only supports Tensor with 1 element, but now Tensor has `" 
109+         tensor.numel (),
110+         " ` element." 
111+     switch  (dtype_) {
112+       case  DataType::FLOAT32:
113+         data_.f32  = tensor.template  data <float >()[0 ];
114+         break ;
115+       case  DataType::FLOAT64:
116+         data_.f64  = tensor.template  data <double >()[0 ];
117+         break ;
118+       case  DataType::FLOAT16:
119+         data_.f16  = tensor.template  data <float16>()[0 ];
120+         break ;
121+       case  DataType::BFLOAT16:
122+         data_.bf16  = tensor.template  data <bfloat16>()[0 ];
123+         break ;
124+       case  DataType::INT32:
125+         data_.i32  = tensor.template  data <int32_t >()[0 ];
126+         break ;
127+       case  DataType::INT64:
128+         data_.i64  = tensor.template  data <int64_t >()[0 ];
129+         break ;
130+       case  DataType::INT16:
131+         data_.i16  = tensor.template  data <int16_t >()[0 ];
132+         break ;
133+       case  DataType::INT8:
134+         data_.i8  = tensor.template  data <int8_t >()[0 ];
135+         break ;
136+       case  DataType::UINT16:
137+         data_.ui16  = tensor.template  data <uint16_t >()[0 ];
138+         break ;
139+       case  DataType::UINT8:
140+         data_.ui8  = tensor.template  data <uint8_t >()[0 ];
141+         break ;
142+       case  DataType::BOOL:
143+         data_.b  = tensor.template  data <bool >()[0 ];
144+         break ;
145+       case  DataType::COMPLEX64:
146+         data_.c64  = tensor.template  data <complex64>()[0 ];
147+         break ;
148+       case  DataType::COMPLEX128:
149+         data_.c128  = tensor.template  data <complex128>()[0 ];
150+         break ;
151+       default :
152+         PD_THROW (" Invalid tensor data type `" " `." 
47153    }
48154  }
49155
50-   template  <typename  T>
51-   inline  T to () const  {
52-     switch  (tag) {
53-       case  Tag::HAS_F:
54-         return  static_cast <T>(data_.f );
55-       case  Tag::HAS_D:
56-         return  static_cast <T>(data_.d );
57-       case  Tag::HAS_I32:
58-         return  static_cast <T>(data_.i32 );
59-       case  Tag::HAS_I64:
60-         return  static_cast <T>(data_.i64 );
61-       case  Tag::HAS_B:
62-         return  static_cast <T>(data_.b );
156+   template  <typename  OtherT>
157+   ScalarBase (const  ScalarBase<OtherT>& other) {
158+     CopyScalar (other, this );
159+   }
160+ 
161+   template  <typename  RT>
162+   inline  RT to () const  {
163+     switch  (dtype_) {
164+       case  DataType::FLOAT32:
165+         return  static_cast <RT>(data_.f32 );
166+       case  DataType::FLOAT64:
167+         return  static_cast <RT>(data_.f64 );
168+       case  DataType::FLOAT16:
169+         return  static_cast <RT>(data_.f16 );
170+       case  DataType::BFLOAT16:
171+         return  static_cast <RT>(data_.bf16 );
172+       case  DataType::INT32:
173+         return  static_cast <RT>(data_.i32 );
174+       case  DataType::INT64:
175+         return  static_cast <RT>(data_.i64 );
176+       case  DataType::INT16:
177+         return  static_cast <RT>(data_.i16 );
178+       case  DataType::INT8:
179+         return  static_cast <RT>(data_.i8 );
180+       case  DataType::UINT16:
181+         return  static_cast <RT>(data_.ui16 );
182+       case  DataType::UINT8:
183+         return  static_cast <RT>(data_.ui8 );
184+       case  DataType::BOOL:
185+         return  static_cast <RT>(data_.b );
186+       case  DataType::COMPLEX64:
187+         return  static_cast <RT>(data_.c64 );
188+       case  DataType::COMPLEX128:
189+         return  static_cast <RT>(data_.c128 );
63190      default :
64-         PD_THROW (" Invalid enum scalar type tag  `" static_cast < int >(tag) , " `." 
191+         PD_THROW (" Invalid enum scalar data  type `" dtype_ , " `." 
65192    }
66193  }
67194
68195 private: 
69-   enum   class   Tag  { HAS_F, HAS_D, HAS_I32, HAS_I64, HAS_B }; 
70-   Tag tag ;
196+   template  < typename  T1,  typename  T2> 
197+   friend   void   CopyScalar ( const  ScalarBase<T1>& src, ScalarBase<T2>* dst) ;
71198
199+  private: 
200+   DataType dtype_;
72201  union  data {
73-     float  f;
74-     double  d;
202+     bool  b;
203+     int8_t  i8 ;
204+     int16_t  i16 ;
75205    int32_t  i32 ;
76206    int64_t  i64 ;
77-     bool  b;
207+     uint8_t  ui8;
208+     uint16_t  ui16;
209+     uint32_t  ui32;
210+     uint64_t  ui64;
211+     bfloat16 bf16 ;
212+     float16 f16 ;
213+     float  f32 ;
214+     double  f64 ;
215+     complex64 c64;
216+     complex128 c128;
78217  } data_;
79218};
80219
220+ template  <typename  T1, typename  T2>
221+ void  CopyScalar (const  ScalarBase<T1>& src, ScalarBase<T2>* dst) {
222+   dst->dtype_  = src.dtype_ ;
223+   dst->data_ .c128  = src.data_ .c128 ;
224+ }
225+ 
226+ using  Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;
227+ 
81228}  //  namespace experimental
82229}  //  namespace paddle
83230
84231namespace  pten  {
85- using  Scalar = paddle::experimental::Scalar;
232+ class  DenseTensor ;
233+ using  Scalar = paddle::experimental::ScalarBase<DenseTensor>;
86234}  //  namespace pten
0 commit comments