Skip to content

Commit 0f24de8

Browse files
authored
【PTen】Add Scalar and ScalarArray in pten (#37409)
* add scalar and scalar_array * remove DenseTensor include from Scalar and ScalarArray * remove inner header from scalar_array * refactor the method of fill_constant and add some comments
1 parent 1659079 commit 0f24de8

File tree

17 files changed

+503
-47
lines changed

17 files changed

+503
-47
lines changed

paddle/fluid/operators/cumsum_op.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -254,7 +254,7 @@ class CumCUDAKernel : public framework::OpKernel<T> {
254254
dim3 transpose_grids((width + tile_size - 1) / tile_size,
255255
(height + tile_size - 1) / tile_size);
256256
auto& dev_ctx = context.template device_context<DeviceContext>();
257-
Tensor tmp;
257+
framework::Tensor tmp;
258258
tmp.Resize(out_dims);
259259
auto* tmp_data = tmp.mutable_data<T>(context.GetPlace());
260260
T* next_in_data = out_data;

paddle/fluid/operators/math/tree2col.cc

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
namespace paddle {
2020
namespace operators {
2121
namespace math {
22-
using Tensor = framework::Tensor;
2322
std::vector<TreeNode> Tree2ColUtil::construct_patch(
2423
size_t root, int max_depth, const std::vector<std::vector<int>> &tr) {
2524
std::stack<TreeNode, std::deque<TreeNode>> stack;
@@ -51,7 +50,7 @@ std::vector<TreeNode> Tree2ColUtil::construct_patch(
5150
return patch;
5251
}
5352

54-
void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet,
53+
void Tree2ColUtil::construct_tree(const framework::Tensor &EdgeSet,
5554
std::vector<std::vector<int>> *tr,
5655
size_t *node_count) {
5756
auto edge_set_dims = EdgeSet.dims();

paddle/fluid/operators/math/tree2col.h

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,8 +21,6 @@
2121
#include "paddle/fluid/operators/math/math_function.h"
2222

2323
namespace paddle {
24-
using Tensor = framework::Tensor;
25-
using DDim = framework::DDim;
2624
namespace operators {
2725
namespace math {
2826
class TreeNode {
@@ -64,7 +62,7 @@ class Tree2ColUtil {
6462
static std::vector<TreeNode> construct_patch(
6563
size_t root, int max_depth, const std::vector<std::vector<int>> &tr);
6664

67-
static void construct_tree(const Tensor &EdgeSet,
65+
static void construct_tree(const framework::Tensor &EdgeSet,
6866
std::vector<std::vector<int>> *tr,
6967
size_t *node_count);
7068
};

paddle/pten/api/all.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ limitations under the License. */
3737
#include "paddle/pten/common/data_type.h"
3838
#include "paddle/pten/common/layout.h"
3939
#include "paddle/pten/common/scalar.h"
40+
#include "paddle/pten/common/scalar_array.h"
4041

4142
// original custom op headers
4243
#include "paddle/pten/api/ext/dispatch.h"

paddle/pten/api/include/creation.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,11 +18,12 @@
1818
#include "paddle/pten/common/backend.h"
1919
#include "paddle/pten/common/data_type.h"
2020
#include "paddle/pten/common/scalar.h"
21+
#include "paddle/pten/common/scalar_array.h"
2122

2223
namespace paddle {
2324
namespace experimental {
2425

25-
PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
26+
PD_DLL_DECL Tensor full(const ScalarArray& shape,
2627
const Scalar& value,
2728
DataType dtype = DataType::FLOAT32,
2829
Backend backend = Backend::CPU,

paddle/pten/api/lib/creation.cc

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -34,22 +34,23 @@ PT_DECLARE_MODULE(CreationCUDA);
3434
namespace paddle {
3535
namespace experimental {
3636

37-
PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
37+
PD_DLL_DECL Tensor full(const ScalarArray& shape,
3838
const Scalar& value,
3939
DataType dtype,
4040
Backend backend,
4141
DataLayout layout) {
4242
// 1. Get kernel signature and kernel
4343
pten::KernelKey kernel_key{backend, layout, dtype};
4444
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
45-
"fill_constant.scalar", kernel_key);
45+
"fill_constant", kernel_key);
4646

4747
// 2. Get Device Context
4848
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
4949
auto kernel_context = pten::KernelContext(dev_ctx);
5050

5151
// 3. Auto data transform
52-
kernel_context.EmplaceBackAttr(value);
52+
kernel_context.EmplaceBackAttr(pten::ScalarArray(shape));
53+
kernel_context.EmplaceBackAttr(pten::Scalar(value));
5354

5455
// 4. InferShape
5556
auto out_meta = pten::FullInferShape(shape, dtype, layout);
@@ -94,7 +95,7 @@ PD_DLL_DECL Tensor full_like(const Tensor& x,
9495

9596
// 3. Auto data transform
9697
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
97-
kernel_context.EmplaceBackAttr(value);
98+
kernel_context.EmplaceBackAttr(pten::Scalar(value));
9899

99100
// 4. InferShape
100101
auto out_meta = FullLikeInferShape(dense_x->meta(), dtype, layout);

paddle/pten/api/lib/tensor.cc

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -219,13 +219,16 @@ template PD_DLL_DECL const int32_t *Tensor::data<int32_t>() const;
219219
template PD_DLL_DECL const uint8_t *Tensor::data<uint8_t>() const;
220220
template PD_DLL_DECL const int8_t *Tensor::data<int8_t>() const;
221221
template PD_DLL_DECL const int16_t *Tensor::data<int16_t>() const;
222+
template PD_DLL_DECL const uint16_t *Tensor::data<uint16_t>() const;
222223
template PD_DLL_DECL const bool *Tensor::data<bool>() const;
223224
template PD_DLL_DECL const paddle::platform::complex<float>
224225
*Tensor::data<paddle::platform::complex<float>>() const;
225226
template PD_DLL_DECL const paddle::platform::complex<double>
226227
*Tensor::data<paddle::platform::complex<double>>() const;
227228
template PD_DLL_DECL const paddle::platform::float16 *
228229
Tensor::data<paddle::platform::float16>() const;
230+
template PD_DLL_DECL const paddle::platform::bfloat16 *
231+
Tensor::data<paddle::platform::bfloat16>() const;
229232

230233
template <typename T>
231234
T *Tensor::data() {

paddle/pten/common/scalar.h

Lines changed: 180 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -18,69 +18,217 @@ limitations under the License. */
1818
#include <limits>
1919

2020
#include "paddle/pten/api/ext/exception.h"
21-
21+
#include "paddle/pten/api/include/tensor.h"
2222
namespace paddle {
2323
namespace experimental {
2424

25-
class Scalar {
25+
template <typename T>
26+
class ScalarBase {
2627
public:
2728
// Constructor support implicit
28-
Scalar(float val) : tag(Tag::HAS_F) { data_.f = val; } // NOLINT
29+
ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT
30+
data_.f64 = val;
31+
}
32+
33+
ScalarBase(float val) : dtype_(DataType::FLOAT32) { // NOLINT
34+
data_.f32 = val;
35+
}
36+
37+
ScalarBase(float16 val) : dtype_(DataType::FLOAT16) { // NOLINT
38+
data_.f16 = val;
39+
}
2940

30-
Scalar(double val) : tag(Tag::HAS_D) { data_.d = val; } // NOLINT
41+
ScalarBase(bfloat16 val) : dtype_(DataType::BFLOAT16) { // NOLINT
42+
data_.bf16 = val;
43+
}
3144

32-
Scalar(int32_t val) : tag(Tag::HAS_I32) { data_.i32 = val; } // NOLINT
45+
ScalarBase(int64_t val) : dtype_(DataType::INT64) { // NOLINT
46+
data_.i64 = val;
47+
}
3348

34-
Scalar(int64_t val) : tag(Tag::HAS_I64) { data_.i64 = val; } // NOLINT
49+
ScalarBase(int32_t val) : dtype_(DataType::INT32) { // NOLINT
50+
data_.i32 = val;
51+
}
3552

36-
Scalar(bool val) : tag(Tag::HAS_B) { data_.b = val; } // NOLINT
53+
ScalarBase(int16_t val) : dtype_(DataType::INT16) { // NOLINT
54+
data_.i16 = val;
55+
}
3756

38-
Scalar(const std::string& str_value) : tag(Tag::HAS_D) { // NOLINT
57+
ScalarBase(int8_t val) : dtype_(DataType::INT8) { // NOLINT
58+
data_.i8 = val;
59+
}
60+
61+
ScalarBase(uint64_t val) : dtype_(DataType::UINT64) { // NOLINT
62+
data_.ui64 = val;
63+
}
64+
65+
ScalarBase(uint32_t val) : dtype_(DataType::UINT32) { // NOLINT
66+
data_.ui32 = val;
67+
}
68+
69+
ScalarBase(uint16_t val) : dtype_(DataType::UINT16) { // NOLINT
70+
data_.ui16 = val;
71+
}
72+
73+
ScalarBase(uint8_t val) : dtype_(DataType::UINT8) { // NOLINT
74+
data_.ui8 = val;
75+
}
76+
77+
ScalarBase(bool val) : dtype_(DataType::BOOL) { // NOLINT
78+
data_.b = val;
79+
}
80+
81+
ScalarBase(complex64 val) : dtype_(DataType::COMPLEX64) { // NOLINT
82+
data_.c64 = val;
83+
}
84+
85+
ScalarBase(complex128 val) : dtype_(DataType::COMPLEX128) { // NOLINT
86+
data_.c128 = val;
87+
}
88+
89+
// The compatible method for fluid operators,
90+
// and it will be removed in the future.
91+
explicit ScalarBase(const std::string& str_value)
92+
: dtype_(DataType::FLOAT64) {
3993
if (str_value == "inf") {
40-
data_.d = std::numeric_limits<double>::infinity();
94+
data_.f64 = std::numeric_limits<double>::infinity();
4195
} else if (str_value == "-inf") {
42-
data_.d = -std::numeric_limits<double>::infinity();
96+
data_.f64 = -std::numeric_limits<double>::infinity();
4397
} else if (str_value == "nan") {
44-
data_.d = std::numeric_limits<double>::quiet_NaN();
98+
data_.f64 = std::numeric_limits<double>::quiet_NaN();
4599
} else {
46-
data_.d = std::stod(str_value);
100+
data_.f64 = std::stod(str_value);
101+
}
102+
}
103+
104+
// The Tensor must have one dim
105+
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT
106+
PD_CHECK(
107+
tensor.numel() == 1,
108+
"The Scalar only supports Tensor with 1 element, but now Tensor has `",
109+
tensor.numel(),
110+
"` element.");
111+
switch (dtype_) {
112+
case DataType::FLOAT32:
113+
data_.f32 = tensor.template data<float>()[0];
114+
break;
115+
case DataType::FLOAT64:
116+
data_.f64 = tensor.template data<double>()[0];
117+
break;
118+
case DataType::FLOAT16:
119+
data_.f16 = tensor.template data<float16>()[0];
120+
break;
121+
case DataType::BFLOAT16:
122+
data_.bf16 = tensor.template data<bfloat16>()[0];
123+
break;
124+
case DataType::INT32:
125+
data_.i32 = tensor.template data<int32_t>()[0];
126+
break;
127+
case DataType::INT64:
128+
data_.i64 = tensor.template data<int64_t>()[0];
129+
break;
130+
case DataType::INT16:
131+
data_.i16 = tensor.template data<int16_t>()[0];
132+
break;
133+
case DataType::INT8:
134+
data_.i8 = tensor.template data<int8_t>()[0];
135+
break;
136+
case DataType::UINT16:
137+
data_.ui16 = tensor.template data<uint16_t>()[0];
138+
break;
139+
case DataType::UINT8:
140+
data_.ui8 = tensor.template data<uint8_t>()[0];
141+
break;
142+
case DataType::BOOL:
143+
data_.b = tensor.template data<bool>()[0];
144+
break;
145+
case DataType::COMPLEX64:
146+
data_.c64 = tensor.template data<complex64>()[0];
147+
break;
148+
case DataType::COMPLEX128:
149+
data_.c128 = tensor.template data<complex128>()[0];
150+
break;
151+
default:
152+
PD_THROW("Invalid tensor data type `", dtype_, "`.");
47153
}
48154
}
49155

50-
template <typename T>
51-
inline T to() const {
52-
switch (tag) {
53-
case Tag::HAS_F:
54-
return static_cast<T>(data_.f);
55-
case Tag::HAS_D:
56-
return static_cast<T>(data_.d);
57-
case Tag::HAS_I32:
58-
return static_cast<T>(data_.i32);
59-
case Tag::HAS_I64:
60-
return static_cast<T>(data_.i64);
61-
case Tag::HAS_B:
62-
return static_cast<T>(data_.b);
156+
template <typename OtherT>
157+
ScalarBase(const ScalarBase<OtherT>& other) {
158+
CopyScalar(other, this);
159+
}
160+
161+
template <typename RT>
162+
inline RT to() const {
163+
switch (dtype_) {
164+
case DataType::FLOAT32:
165+
return static_cast<RT>(data_.f32);
166+
case DataType::FLOAT64:
167+
return static_cast<RT>(data_.f64);
168+
case DataType::FLOAT16:
169+
return static_cast<RT>(data_.f16);
170+
case DataType::BFLOAT16:
171+
return static_cast<RT>(data_.bf16);
172+
case DataType::INT32:
173+
return static_cast<RT>(data_.i32);
174+
case DataType::INT64:
175+
return static_cast<RT>(data_.i64);
176+
case DataType::INT16:
177+
return static_cast<RT>(data_.i16);
178+
case DataType::INT8:
179+
return static_cast<RT>(data_.i8);
180+
case DataType::UINT16:
181+
return static_cast<RT>(data_.ui16);
182+
case DataType::UINT8:
183+
return static_cast<RT>(data_.ui8);
184+
case DataType::BOOL:
185+
return static_cast<RT>(data_.b);
186+
case DataType::COMPLEX64:
187+
return static_cast<RT>(data_.c64);
188+
case DataType::COMPLEX128:
189+
return static_cast<RT>(data_.c128);
63190
default:
64-
PD_THROW("Invalid enum scalar type tag `", static_cast<int>(tag), "`.");
191+
PD_THROW("Invalid enum scalar data type `", dtype_, "`.");
65192
}
66193
}
67194

68195
private:
69-
enum class Tag { HAS_F, HAS_D, HAS_I32, HAS_I64, HAS_B };
70-
Tag tag;
196+
template <typename T1, typename T2>
197+
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
71198

199+
private:
200+
DataType dtype_;
72201
union data {
73-
float f;
74-
double d;
202+
bool b;
203+
int8_t i8;
204+
int16_t i16;
75205
int32_t i32;
76206
int64_t i64;
77-
bool b;
207+
uint8_t ui8;
208+
uint16_t ui16;
209+
uint32_t ui32;
210+
uint64_t ui64;
211+
bfloat16 bf16;
212+
float16 f16;
213+
float f32;
214+
double f64;
215+
complex64 c64;
216+
complex128 c128;
78217
} data_;
79218
};
80219

220+
template <typename T1, typename T2>
221+
void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) {
222+
dst->dtype_ = src.dtype_;
223+
dst->data_.c128 = src.data_.c128;
224+
}
225+
226+
using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;
227+
81228
} // namespace experimental
82229
} // namespace paddle
83230

84231
namespace pten {
85-
using Scalar = paddle::experimental::Scalar;
232+
class DenseTensor;
233+
using Scalar = paddle::experimental::ScalarBase<DenseTensor>;
86234
} // namespace pten

0 commit comments

Comments (0)