Merged. Showing changes from 4 commits.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/cumsum_op.cu
@@ -254,7 +254,7 @@ class CumCUDAKernel : public framework::OpKernel<T> {
     dim3 transpose_grids((width + tile_size - 1) / tile_size,
                          (height + tile_size - 1) / tile_size);
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    Tensor tmp;
+    framework::Tensor tmp;
     tmp.Resize(out_dims);
     auto* tmp_data = tmp.mutable_data<T>(context.GetPlace());
     T* next_in_data = out_data;

3 changes: 1 addition & 2 deletions paddle/fluid/operators/math/tree2col.cc
@@ -19,7 +19,6 @@
 namespace paddle {
 namespace operators {
 namespace math {
-using Tensor = framework::Tensor;
 std::vector<TreeNode> Tree2ColUtil::construct_patch(
     size_t root, int max_depth, const std::vector<std::vector<int>> &tr) {
   std::stack<TreeNode, std::deque<TreeNode>> stack;
@@ -51,7 +50,7 @@ std::vector<TreeNode> Tree2ColUtil::construct_patch(
   return patch;
 }

-void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet,
+void Tree2ColUtil::construct_tree(const framework::Tensor &EdgeSet,
                                   std::vector<std::vector<int>> *tr,
                                   size_t *node_count) {
   auto edge_set_dims = EdgeSet.dims();

4 changes: 1 addition & 3 deletions paddle/fluid/operators/math/tree2col.h
@@ -21,8 +21,6 @@
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
namespace operators {
namespace math {
class TreeNode {
@@ -64,7 +62,7 @@ class Tree2ColUtil {
   static std::vector<TreeNode> construct_patch(
       size_t root, int max_depth, const std::vector<std::vector<int>> &tr);

-  static void construct_tree(const Tensor &EdgeSet,
+  static void construct_tree(const framework::Tensor &EdgeSet,
                              std::vector<std::vector<int>> *tr,
                              size_t *node_count);
 };

1 change: 1 addition & 0 deletions paddle/pten/api/all.h
@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/layout.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"

// original custom op headers
#include "paddle/pten/api/ext/dispatch.h"
Expand Down
3 changes: 2 additions & 1 deletion paddle/pten/api/include/creation.h
@@ -18,11 +18,12 @@
#include "paddle/pten/common/backend.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"

namespace paddle {
namespace experimental {

PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
PD_DLL_DECL Tensor full(const ScalarArray& shape,
const Scalar& value,
DataType dtype = DataType::FLOAT32,
Backend backend = Backend::CPU,
Expand Down
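Note on the signature change above: `shape` is now a `ScalarArray` rather than a `std::vector<int64_t>`. Since `scalar_array.h` itself is not part of this diff, its exact constructor set is an assumption here; a minimal usage sketch, assuming `ScalarArray` stays implicitly constructible from `std::vector<int64_t>` so that existing call sites keep compiling:

```cpp
#include <vector>

#include "paddle/pten/api/include/creation.h"

void full_usage() {
  namespace exp = paddle::experimental;

  // Shape passed the old way; it converts implicitly to the new
  // ScalarArray parameter (assumed constructor, see note above).
  std::vector<int64_t> shape = {2, 3};

  // The value argument is a Scalar, implicitly constructible from a float
  // literal per the scalar.h changes below; dtype, backend and layout keep
  // the defaults declared in this header.
  exp::Tensor t = exp::full(shape, 1.0f);
  (void)t;
}
```

Presumably the motivation is that, like the `Scalar` below, a `ScalarArray` can also be constructed from a Tensor whose values are only known at runtime; that constructor is not visible in this diff.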
9 changes: 5 additions & 4 deletions paddle/pten/api/lib/creation.cc
@@ -27,22 +27,23 @@ limitations under the License. */
 namespace paddle {
 namespace experimental {

-PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
+PD_DLL_DECL Tensor full(const ScalarArray& shape,
                         const Scalar& value,
                         DataType dtype,
                         Backend backend,
                         DataLayout layout) {
   // 1. Get kernel signature and kernel
   pten::KernelKey kernel_key{backend, layout, dtype};
   auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
-      "fill_constant.scalar", kernel_key);
+      "fill_constant", kernel_key);

   // 2. Get Device Context
   auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
   auto kernel_context = pten::KernelContext(dev_ctx);

   // 3. Auto data transform
-  kernel_context.EmplaceBackAttr(value);
+  kernel_context.EmplaceBackAttr(pten::ScalarArray(shape));
+  kernel_context.EmplaceBackAttr(pten::Scalar(value));

   // 4. InferShape
   auto out_meta = pten::FullInferShape(shape, dtype, layout);
@@ -87,7 +88,7 @@ PD_DLL_DECL Tensor full_like(const Tensor& x,

   // 3. Auto data transform
   auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
-  kernel_context.EmplaceBackAttr(value);
+  kernel_context.EmplaceBackAttr(pten::Scalar(value));

   // 4. InferShape
   auto out_meta = FullLikeInferShape(dense_x->meta(), dtype, layout);

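The explicit `pten::ScalarArray(shape)` and `pten::Scalar(value)` wrappers work because the API-level and kernel-level types are two instantiations of one template: `scalar.h` below defines `paddle::experimental::Scalar` as `ScalarBase<paddle::experimental::Tensor>` and `pten::Scalar` as `ScalarBase<pten::DenseTensor>`, with a converting constructor that delegates to `CopyScalar`. A self-contained toy sketch of that mechanism, using hypothetical stand-in types rather than the real Paddle classes:

```cpp
#include <cstdint>

enum class DataType { FLOAT32, INT64 };

// Toy ScalarBase: the template parameter only matters for the real class's
// tensor constructor, so cross-instantiation conversion is a plain copy.
template <typename T>
class ScalarBase {
 public:
  ScalarBase(float v) : dtype_(DataType::FLOAT32) { data_.f32 = v; }  // NOLINT
  ScalarBase(int64_t v) : dtype_(DataType::INT64) { data_.i64 = v; }  // NOLINT

  // Converting constructor: accepts any other instantiation. Like the real
  // CopyScalar, it copies the dtype tag plus the widest union member, which
  // covers every smaller member in one assignment.
  template <typename U>
  ScalarBase(const ScalarBase<U>& other) : dtype_(other.dtype_) {  // NOLINT
    data_.i64 = other.data_.i64;
  }

  // Every instantiation can see every other instantiation's private members.
  template <typename U>
  friend class ScalarBase;

 private:
  DataType dtype_;
  union {
    float f32;
    int64_t i64;
  } data_;
};

struct ApiTensor {};    // stand-in for paddle::experimental::Tensor
struct DenseTensor {};  // stand-in for pten::DenseTensor

int main() {
  ScalarBase<ApiTensor> api_scalar(1.5f);
  ScalarBase<DenseTensor> kernel_scalar(api_scalar);  // cross-type copy
  (void)kernel_scalar;
}
```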
3 changes: 3 additions & 0 deletions paddle/pten/api/lib/tensor.cc
@@ -218,13 +218,16 @@ template PD_DLL_DECL const int32_t *Tensor::data<int32_t>() const;
 template PD_DLL_DECL const uint8_t *Tensor::data<uint8_t>() const;
 template PD_DLL_DECL const int8_t *Tensor::data<int8_t>() const;
 template PD_DLL_DECL const int16_t *Tensor::data<int16_t>() const;
+template PD_DLL_DECL const uint16_t *Tensor::data<uint16_t>() const;
 template PD_DLL_DECL const bool *Tensor::data<bool>() const;
 template PD_DLL_DECL const paddle::platform::complex<float>
     *Tensor::data<paddle::platform::complex<float>>() const;
 template PD_DLL_DECL const paddle::platform::complex<double>
     *Tensor::data<paddle::platform::complex<double>>() const;
 template PD_DLL_DECL const paddle::platform::float16 *
 Tensor::data<paddle::platform::float16>() const;
+template PD_DLL_DECL const paddle::platform::bfloat16 *
+Tensor::data<paddle::platform::bfloat16>() const;

 template <typename T>
 T *Tensor::data() {

212 changes: 180 additions & 32 deletions paddle/pten/common/scalar.h
@@ -18,69 +18,217 @@ limitations under the License. */
 #include <limits>

 #include "paddle/pten/api/ext/exception.h"

+#include "paddle/pten/api/include/tensor.h"
 namespace paddle {
 namespace experimental {

-class Scalar {
+template <typename T>
[Review] @Shixiaowei02 (Contributor, Nov 23, 2021): To stay consistent with the earlier design, inference would prefer inheritance over templates here. If the schedule does not allow it, the design can be relaxed for now and adjusted later, or documented.

[Author reply]: The current design is only a temporary solution to the argument-passing problem; once the problem is solved, it can be adjusted, refined, or redesigned.
+class ScalarBase {
  public:
   // Constructor support implicit
-  Scalar(float val) : tag(Tag::HAS_F) { data_.f = val; }  // NOLINT
+  ScalarBase(double val) : dtype_(DataType::FLOAT64) {  // NOLINT
+    data_.f64 = val;
+  }
+
+  ScalarBase(float val) : dtype_(DataType::FLOAT32) {  // NOLINT
+    data_.f32 = val;
+  }
+
+  ScalarBase(float16 val) : dtype_(DataType::FLOAT16) {  // NOLINT
+    data_.f16 = val;
+  }
+
-  Scalar(double val) : tag(Tag::HAS_D) { data_.d = val; }  // NOLINT
+  ScalarBase(bfloat16 val) : dtype_(DataType::BFLOAT16) {  // NOLINT
+    data_.bf16 = val;
+  }
+
-  Scalar(int32_t val) : tag(Tag::HAS_I32) { data_.i32 = val; }  // NOLINT
+  ScalarBase(int64_t val) : dtype_(DataType::INT64) {  // NOLINT
+    data_.i64 = val;
+  }
+
-  Scalar(int64_t val) : tag(Tag::HAS_I64) { data_.i64 = val; }  // NOLINT
+  ScalarBase(int32_t val) : dtype_(DataType::INT32) {  // NOLINT
+    data_.i32 = val;
+  }
+
-  Scalar(bool val) : tag(Tag::HAS_B) { data_.b = val; }  // NOLINT
+  ScalarBase(int16_t val) : dtype_(DataType::INT16) {  // NOLINT
+    data_.i16 = val;
+  }
+
-  Scalar(const std::string& str_value) : tag(Tag::HAS_D) {  // NOLINT
+  ScalarBase(int8_t val) : dtype_(DataType::INT8) {  // NOLINT
+    data_.i8 = val;
+  }
+
+  ScalarBase(uint64_t val) : dtype_(DataType::UINT64) {  // NOLINT
+    data_.ui64 = val;
+  }
+
+  ScalarBase(uint32_t val) : dtype_(DataType::UINT32) {  // NOLINT
+    data_.ui32 = val;
+  }
+
+  ScalarBase(uint16_t val) : dtype_(DataType::UINT16) {  // NOLINT
+    data_.ui16 = val;
+  }
+
+  ScalarBase(uint8_t val) : dtype_(DataType::UINT8) {  // NOLINT
+    data_.ui8 = val;
+  }
+
+  ScalarBase(bool val) : dtype_(DataType::BOOL) {  // NOLINT
+    data_.b = val;
+  }
+
+  ScalarBase(complex64 val) : dtype_(DataType::COMPLEX64) {  // NOLINT
+    data_.c64 = val;
+  }
+
+  ScalarBase(complex128 val) : dtype_(DataType::COMPLEX128) {  // NOLINT
+    data_.c128 = val;
+  }
+
+  // The compatible method for fluid operators,
+  // and it will be removed in the future.
+  explicit ScalarBase(const std::string& str_value)
+      : dtype_(DataType::FLOAT64) {
     if (str_value == "inf") {
-      data_.d = std::numeric_limits<double>::infinity();
+      data_.f64 = std::numeric_limits<double>::infinity();
     } else if (str_value == "-inf") {
-      data_.d = -std::numeric_limits<double>::infinity();
+      data_.f64 = -std::numeric_limits<double>::infinity();
     } else if (str_value == "nan") {
-      data_.d = std::numeric_limits<double>::quiet_NaN();
+      data_.f64 = std::numeric_limits<double>::quiet_NaN();
     } else {
-      data_.d = std::stod(str_value);
+      data_.f64 = std::stod(str_value);
[Review] Contributor: Can str_value be an invalid string here?

[Author reply]: str_value here adapts the attribute of the original fill_constant op, so an invalid string cannot occur at present. This interface should be removed once the op migration is complete; a comment has been added.
     }
   }

+  // The Tensor must have exactly one element
+  ScalarBase(const T& tensor) : dtype_(tensor.dtype()) {  // NOLINT
+    PD_CHECK(
+        tensor.numel() == 1,
+        "The Scalar only supports Tensor with 1 element, but now Tensor has `",
+        tensor.numel(),
+        "` element.");
+    switch (dtype_) {
+      case DataType::FLOAT32:
+        data_.f32 = tensor.template data<float>()[0];
+        break;
+      case DataType::FLOAT64:
+        data_.f64 = tensor.template data<double>()[0];
+        break;
+      case DataType::FLOAT16:
+        data_.f16 = tensor.template data<float16>()[0];
+        break;
+      case DataType::BFLOAT16:
+        data_.bf16 = tensor.template data<bfloat16>()[0];
+        break;
+      case DataType::INT32:
+        data_.i32 = tensor.template data<int32_t>()[0];
+        break;
+      case DataType::INT64:
+        data_.i64 = tensor.template data<int64_t>()[0];
+        break;
+      case DataType::INT16:
+        data_.i16 = tensor.template data<int16_t>()[0];
+        break;
+      case DataType::INT8:
+        data_.i8 = tensor.template data<int8_t>()[0];
+        break;
+      case DataType::UINT16:
+        data_.ui16 = tensor.template data<uint16_t>()[0];
+        break;
+      case DataType::UINT8:
+        data_.ui8 = tensor.template data<uint8_t>()[0];
+        break;
+      case DataType::BOOL:
+        data_.b = tensor.template data<bool>()[0];
+        break;
+      case DataType::COMPLEX64:
+        data_.c64 = tensor.template data<complex64>()[0];
+        break;
+      case DataType::COMPLEX128:
+        data_.c128 = tensor.template data<complex128>()[0];
+        break;
+      default:
+        PD_THROW("Invalid tensor data type `", dtype_, "`.");
+    }
+  }

-  template <typename T>
-  inline T to() const {
-    switch (tag) {
-      case Tag::HAS_F:
-        return static_cast<T>(data_.f);
-      case Tag::HAS_D:
-        return static_cast<T>(data_.d);
-      case Tag::HAS_I32:
-        return static_cast<T>(data_.i32);
-      case Tag::HAS_I64:
-        return static_cast<T>(data_.i64);
-      case Tag::HAS_B:
-        return static_cast<T>(data_.b);
+  template <typename OtherT>
+  ScalarBase(const ScalarBase<OtherT>& other) {
+    CopyScalar(other, this);
+  }
+
+  template <typename RT>
[Review] Contributor: There may be type-safety issues here.

[Author reply]: [TODO] To be addressed in a later iteration of this design.
+  inline RT to() const {
+    switch (dtype_) {
+      case DataType::FLOAT32:
+        return static_cast<RT>(data_.f32);
+      case DataType::FLOAT64:
+        return static_cast<RT>(data_.f64);
+      case DataType::FLOAT16:
+        return static_cast<RT>(data_.f16);
+      case DataType::BFLOAT16:
+        return static_cast<RT>(data_.bf16);
+      case DataType::INT32:
+        return static_cast<RT>(data_.i32);
+      case DataType::INT64:
+        return static_cast<RT>(data_.i64);
+      case DataType::INT16:
+        return static_cast<RT>(data_.i16);
+      case DataType::INT8:
+        return static_cast<RT>(data_.i8);
+      case DataType::UINT16:
+        return static_cast<RT>(data_.ui16);
+      case DataType::UINT8:
+        return static_cast<RT>(data_.ui8);
+      case DataType::BOOL:
+        return static_cast<RT>(data_.b);
+      case DataType::COMPLEX64:
+        return static_cast<RT>(data_.c64);
+      case DataType::COMPLEX128:
+        return static_cast<RT>(data_.c128);
       default:
-        PD_THROW("Invalid enum scalar type tag `", static_cast<int>(tag), "`.");
+        PD_THROW("Invalid enum scalar data type `", dtype_, "`.");
     }
   }

  private:
-  enum class Tag { HAS_F, HAS_D, HAS_I32, HAS_I64, HAS_B };
-  Tag tag;
+  template <typename T1, typename T2>
+  friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);

+ private:
+  DataType dtype_;
   union data {
-    float f;
-    double d;
-    bool b;
     int8_t i8;
     int16_t i16;
     int32_t i32;
     int64_t i64;
+    bool b;
+    uint8_t ui8;
+    uint16_t ui16;
+    uint32_t ui32;
+    uint64_t ui64;
+    bfloat16 bf16;
+    float16 f16;
+    float f32;
+    double f64;
+    complex64 c64;
+    complex128 c128;
   } data_;
 };

+template <typename T1, typename T2>
+void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) {
+  dst->dtype_ = src.dtype_;
+  dst->data_.c128 = src.data_.c128;
+}
+
+using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;

 }  // namespace experimental
 }  // namespace paddle

 namespace pten {
-using Scalar = paddle::experimental::Scalar;
+class DenseTensor;
+using Scalar = paddle::experimental::ScalarBase<DenseTensor>;
 }  // namespace pten
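
With the alias in place, `Scalar` at the API layer is `ScalarBase<paddle::experimental::Tensor>`. A short usage sketch based only on the constructors and `to<RT>()` shown above (illustrative, not taken from the PR):

```cpp
#include <string>

#include "paddle/pten/common/scalar.h"

void scalar_usage() {
  using paddle::experimental::Scalar;

  Scalar f(1.5f);                // tagged FLOAT32, stored in data_.f32
  Scalar b(true);                // tagged BOOL
  Scalar s(std::string("inf"));  // fluid-compat path, stored as FLOAT64

  // to<RT>() static_casts whatever member is active to the requested type.
  double d = f.to<double>();  // 1.5
  int i = b.to<int>();        // 1
  (void)d;
  (void)i;
  (void)s;
}
```

A `Scalar` can also be built from a one-element `Tensor`, in which case the single element is read out via `tensor.data<T>()[0]`; the `PD_CHECK` above rejects tensors with more than one element.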