2 changes: 1 addition & 1 deletion paddle/fluid/operators/cumsum_op.cu
@@ -254,7 +254,7 @@ class CumCUDAKernel : public framework::OpKernel<T> {
dim3 transpose_grids((width + tile_size - 1) / tile_size,
(height + tile_size - 1) / tile_size);
auto& dev_ctx = context.template device_context<DeviceContext>();
-    Tensor tmp;
+    framework::Tensor tmp;
tmp.Resize(out_dims);
auto* tmp_data = tmp.mutable_data<T>(context.GetPlace());
T* next_in_data = out_data;
3 changes: 1 addition & 2 deletions paddle/fluid/operators/math/tree2col.cc
@@ -19,7 +19,6 @@
namespace paddle {
namespace operators {
namespace math {
-using Tensor = framework::Tensor;
std::vector<TreeNode> Tree2ColUtil::construct_patch(
size_t root, int max_depth, const std::vector<std::vector<int>> &tr) {
std::stack<TreeNode, std::deque<TreeNode>> stack;
@@ -51,7 +50,7 @@ std::vector<TreeNode> Tree2ColUtil::construct_patch(
return patch;
}

-void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet,
+void Tree2ColUtil::construct_tree(const framework::Tensor &EdgeSet,
std::vector<std::vector<int>> *tr,
size_t *node_count) {
auto edge_set_dims = EdgeSet.dims();
4 changes: 1 addition & 3 deletions paddle/fluid/operators/math/tree2col.h
@@ -21,8 +21,6 @@
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
-using Tensor = framework::Tensor;
-using DDim = framework::DDim;
namespace operators {
namespace math {
class TreeNode {
@@ -64,7 +62,7 @@ class Tree2ColUtil {
static std::vector<TreeNode> construct_patch(
size_t root, int max_depth, const std::vector<std::vector<int>> &tr);

-  static void construct_tree(const Tensor &EdgeSet,
+  static void construct_tree(const framework::Tensor &EdgeSet,
std::vector<std::vector<int>> *tr,
size_t *node_count);
};
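Both tree2col files drop the header-scope `using Tensor = framework::Tensor;` (and `using DDim`) aliases and spell out `framework::Tensor` instead. Those aliases sat directly in `namespace paddle`, so every includer of the header saw `paddle::Tensor` defined as `framework::Tensor`; removing them presumably keeps that name free as the pten refactor introduces its own public Tensor types. A standalone sketch of the leakage problem, using dummy types rather than Paddle code:

```cpp
#include <iostream>

namespace framework { struct Tensor { int id = 1; }; }

namespace paddle {
// An alias at namespace scope in a header leaks into paddle:: for
// every translation unit that includes it.
using Tensor = framework::Tensor;
}  // namespace paddle

namespace paddle { namespace experimental { struct Tensor { int id = 2; }; } }

int main() {
  paddle::Tensor t;  // silently framework::Tensor; adding
                     // `using experimental::Tensor;` to paddle:: would now
                     // conflict with the alias and break the build
  std::cout << t.id << "\n";  // prints 1
}
```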
1 change: 1 addition & 0 deletions paddle/pten/api/all.h
@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/layout.h"
#include "paddle/pten/common/scalar.h"
+#include "paddle/pten/common/scalar_array.h"

// original custom op headers
#include "paddle/pten/api/ext/dispatch.h"
3 changes: 2 additions & 1 deletion paddle/pten/api/include/creation.h
@@ -18,11 +18,12 @@
#include "paddle/pten/common/backend.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/scalar.h"
+#include "paddle/pten/common/scalar_array.h"

namespace paddle {
namespace experimental {

-PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
+PD_DLL_DECL Tensor full(const ScalarArray& shape,
const Scalar& value,
DataType dtype = DataType::FLOAT32,
Backend backend = Backend::CPU,
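The shape parameter changes from `std::vector<int64_t>` to `ScalarArray`, which (per the new `scalar_array.h` include) can also carry a shape that originates from a tensor. A hedged caller-side sketch, not part of this diff: the unchanged call sites below suggest `ScalarArray` converts implicitly from a braced integer list, so existing callers should keep compiling. The `make_ones` helper is hypothetical:

```cpp
#include "paddle/pten/api/include/creation.h"

// Hypothetical helper; assumes ScalarArray is implicitly constructible
// from an initializer list, as the unchanged call sites suggest.
paddle::experimental::Tensor make_ones() {
  namespace exp = paddle::experimental;
  // 1.0f converts implicitly through Scalar's float constructor.
  return exp::full({2, 3}, 1.0f, exp::DataType::FLOAT32);
}
```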
9 changes: 5 additions & 4 deletions paddle/pten/api/lib/creation.cc
@@ -34,22 +34,23 @@ PT_DECLARE_MODULE(CreationCUDA);
namespace paddle {
namespace experimental {

-PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
+PD_DLL_DECL Tensor full(const ScalarArray& shape,
const Scalar& value,
DataType dtype,
Backend backend,
DataLayout layout) {
// 1. Get kernel signature and kernel
pten::KernelKey kernel_key{backend, layout, dtype};
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"fill_constant.scalar", kernel_key);
"fill_constant", kernel_key);

// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);

// 3. Auto data transform
-  kernel_context.EmplaceBackAttr(value);
+  kernel_context.EmplaceBackAttr(pten::ScalarArray(shape));
+  kernel_context.EmplaceBackAttr(pten::Scalar(value));

// 4. InferShape
auto out_meta = pten::FullInferShape(shape, dtype, layout);
@@ -94,7 +95,7 @@ PD_DLL_DECL Tensor full_like(const Tensor& x,

// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
-  kernel_context.EmplaceBackAttr(value);
+  kernel_context.EmplaceBackAttr(pten::Scalar(value));

// 4. InferShape
auto out_meta = FullLikeInferShape(dense_x->meta(), dtype, layout);
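Two things change in `full`: the kernel key drops its `.scalar` suffix now that there is a single consolidated `fill_constant` kernel, and the attributes are emplaced as the pten wrapper types, `pten::ScalarArray(shape)` followed by `pten::Scalar(value)`, apparently in the positional order the kernel signature expects. `full_like` gets the same `pten::Scalar(value)` wrapping for its value attribute.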
3 changes: 3 additions & 0 deletions paddle/pten/api/lib/tensor.cc
@@ -219,13 +219,16 @@ template PD_DLL_DECL const int32_t *Tensor::data<int32_t>() const;
template PD_DLL_DECL const uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL const int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL const int16_t *Tensor::data<int16_t>() const;
+template PD_DLL_DECL const uint16_t *Tensor::data<uint16_t>() const;
template PD_DLL_DECL const bool *Tensor::data<bool>() const;
template PD_DLL_DECL const paddle::platform::complex<float>
*Tensor::data<paddle::platform::complex<float>>() const;
template PD_DLL_DECL const paddle::platform::complex<double>
*Tensor::data<paddle::platform::complex<double>>() const;
template PD_DLL_DECL const paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
+template PD_DLL_DECL const paddle::platform::bfloat16 *
+Tensor::data<paddle::platform::bfloat16>() const;

template <typename T>
T *Tensor::data() {
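The new `uint16_t` and `bfloat16` instantiations are presumably what the `ScalarBase` tensor constructor (in `scalar.h` below) needs: it calls `tensor.template data<uint16_t>()` and `tensor.template data<bfloat16>()`, and without exported explicit instantiations of `Tensor::data<T>() const` for those types, such calls would fail to link against the DLL.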
212 changes: 180 additions & 32 deletions paddle/pten/common/scalar.h
@@ -18,69 +18,217 @@ limitations under the License. */
#include <limits>

#include "paddle/pten/api/ext/exception.h"

+#include "paddle/pten/api/include/tensor.h"
namespace paddle {
namespace experimental {

-class Scalar {
+template <typename T>
[Review comment] Shixiaowei02 (Contributor), Nov 23, 2021:
To stay consistent with the earlier design, the inference side would prefer inheritance over templates here. If the schedule is tight, the design can be relaxed for now and revised later, or documented.

[Author reply] (Contributor):
The current design is only a temporary scheme for solving the argument-passing problem; it can be adjusted and refined, or redesigned, later, provided the problem stays solved.
+class ScalarBase {
public:
// Constructor support implicit
-  Scalar(float val) : tag(Tag::HAS_F) { data_.f = val; }  // NOLINT
+  ScalarBase(double val) : dtype_(DataType::FLOAT64) {  // NOLINT
+    data_.f64 = val;
+  }

+  ScalarBase(float val) : dtype_(DataType::FLOAT32) {  // NOLINT
+    data_.f32 = val;
+  }

+  ScalarBase(float16 val) : dtype_(DataType::FLOAT16) {  // NOLINT
+    data_.f16 = val;
+  }

-  Scalar(double val) : tag(Tag::HAS_D) { data_.d = val; }  // NOLINT
+  ScalarBase(bfloat16 val) : dtype_(DataType::BFLOAT16) {  // NOLINT
+    data_.bf16 = val;
+  }

-  Scalar(int32_t val) : tag(Tag::HAS_I32) { data_.i32 = val; }  // NOLINT
+  ScalarBase(int64_t val) : dtype_(DataType::INT64) {  // NOLINT
+    data_.i64 = val;
+  }

-  Scalar(int64_t val) : tag(Tag::HAS_I64) { data_.i64 = val; }  // NOLINT
+  ScalarBase(int32_t val) : dtype_(DataType::INT32) {  // NOLINT
+    data_.i32 = val;
+  }

-  Scalar(bool val) : tag(Tag::HAS_B) { data_.b = val; }  // NOLINT
+  ScalarBase(int16_t val) : dtype_(DataType::INT16) {  // NOLINT
+    data_.i16 = val;
+  }

-  Scalar(const std::string& str_value) : tag(Tag::HAS_D) {  // NOLINT
+  ScalarBase(int8_t val) : dtype_(DataType::INT8) {  // NOLINT
+    data_.i8 = val;
+  }

+  ScalarBase(uint64_t val) : dtype_(DataType::UINT64) {  // NOLINT
+    data_.ui64 = val;
+  }
+
+  ScalarBase(uint32_t val) : dtype_(DataType::UINT32) {  // NOLINT
+    data_.ui32 = val;
+  }
+
+  ScalarBase(uint16_t val) : dtype_(DataType::UINT16) {  // NOLINT
+    data_.ui16 = val;
+  }
+
+  ScalarBase(uint8_t val) : dtype_(DataType::UINT8) {  // NOLINT
+    data_.ui8 = val;
+  }
+
+  ScalarBase(bool val) : dtype_(DataType::BOOL) {  // NOLINT
+    data_.b = val;
+  }
+
+  ScalarBase(complex64 val) : dtype_(DataType::COMPLEX64) {  // NOLINT
+    data_.c64 = val;
+  }
+
+  ScalarBase(complex128 val) : dtype_(DataType::COMPLEX128) {  // NOLINT
+    data_.c128 = val;
+  }

+  // The compatible method for fluid operators,
+  // and it will be removed in the future.
+  explicit ScalarBase(const std::string& str_value)
+      : dtype_(DataType::FLOAT64) {
if (str_value == "inf") {
-      data_.d = std::numeric_limits<double>::infinity();
+      data_.f64 = std::numeric_limits<double>::infinity();
} else if (str_value == "-inf") {
-      data_.d = -std::numeric_limits<double>::infinity();
+      data_.f64 = -std::numeric_limits<double>::infinity();
} else if (str_value == "nan") {
-      data_.d = std::numeric_limits<double>::quiet_NaN();
+      data_.f64 = std::numeric_limits<double>::quiet_NaN();
} else {
-      data_.d = std::stod(str_value);
+      data_.f64 = std::stod(str_value);
[Review comment] (Contributor):
Can str_value be an invalid string here?

[Author reply] (Contributor):
str_value exists to adapt the parameters of the original fill_constant op; invalid strings cannot occur at present. This interface should be removed once the op migration is complete. A clarifying comment has been added.
}
}

+  // The Tensor must have one dim
+  ScalarBase(const T& tensor) : dtype_(tensor.dtype()) {  // NOLINT
+    PD_CHECK(
+        tensor.numel() == 1,
+        "The Scalar only supports Tensor with 1 element, but now Tensor has `",
+        tensor.numel(),
+        "` element.");
+    switch (dtype_) {
+      case DataType::FLOAT32:
+        data_.f32 = tensor.template data<float>()[0];
+        break;
+      case DataType::FLOAT64:
+        data_.f64 = tensor.template data<double>()[0];
+        break;
+      case DataType::FLOAT16:
+        data_.f16 = tensor.template data<float16>()[0];
+        break;
+      case DataType::BFLOAT16:
+        data_.bf16 = tensor.template data<bfloat16>()[0];
+        break;
+      case DataType::INT32:
+        data_.i32 = tensor.template data<int32_t>()[0];
+        break;
+      case DataType::INT64:
+        data_.i64 = tensor.template data<int64_t>()[0];
+        break;
+      case DataType::INT16:
+        data_.i16 = tensor.template data<int16_t>()[0];
+        break;
+      case DataType::INT8:
+        data_.i8 = tensor.template data<int8_t>()[0];
+        break;
+      case DataType::UINT16:
+        data_.ui16 = tensor.template data<uint16_t>()[0];
+        break;
+      case DataType::UINT8:
+        data_.ui8 = tensor.template data<uint8_t>()[0];
+        break;
+      case DataType::BOOL:
+        data_.b = tensor.template data<bool>()[0];
+        break;
+      case DataType::COMPLEX64:
+        data_.c64 = tensor.template data<complex64>()[0];
+        break;
+      case DataType::COMPLEX128:
+        data_.c128 = tensor.template data<complex128>()[0];
+        break;
+      default:
+        PD_THROW("Invalid tensor data type `", dtype_, "`.");
+    }
+  }

-  template <typename T>
-  inline T to() const {
-    switch (tag) {
-      case Tag::HAS_F:
-        return static_cast<T>(data_.f);
-      case Tag::HAS_D:
-        return static_cast<T>(data_.d);
-      case Tag::HAS_I32:
-        return static_cast<T>(data_.i32);
-      case Tag::HAS_I64:
-        return static_cast<T>(data_.i64);
-      case Tag::HAS_B:
-        return static_cast<T>(data_.b);
+  template <typename OtherT>
+  ScalarBase(const ScalarBase<OtherT>& other) {
+    CopyScalar(other, this);
+  }

+  template <typename RT>
[Review comment] (Contributor):
There may be type-safety issues here.

[Author reply] (Contributor):
[TODO] To be refined in a later optimization pass on this design.
+  inline RT to() const {
+    switch (dtype_) {
+      case DataType::FLOAT32:
+        return static_cast<RT>(data_.f32);
+      case DataType::FLOAT64:
+        return static_cast<RT>(data_.f64);
+      case DataType::FLOAT16:
+        return static_cast<RT>(data_.f16);
+      case DataType::BFLOAT16:
+        return static_cast<RT>(data_.bf16);
+      case DataType::INT32:
+        return static_cast<RT>(data_.i32);
+      case DataType::INT64:
+        return static_cast<RT>(data_.i64);
+      case DataType::INT16:
+        return static_cast<RT>(data_.i16);
+      case DataType::INT8:
+        return static_cast<RT>(data_.i8);
+      case DataType::UINT16:
+        return static_cast<RT>(data_.ui16);
+      case DataType::UINT8:
+        return static_cast<RT>(data_.ui8);
+      case DataType::BOOL:
+        return static_cast<RT>(data_.b);
+      case DataType::COMPLEX64:
+        return static_cast<RT>(data_.c64);
+      case DataType::COMPLEX128:
+        return static_cast<RT>(data_.c128);
default:
PD_THROW("Invalid enum scalar type tag `", static_cast<int>(tag), "`.");
PD_THROW("Invalid enum scalar data type `", dtype_, "`.");
}
}

private:
-  enum class Tag { HAS_F, HAS_D, HAS_I32, HAS_I64, HAS_B };
-  Tag tag;
+  template <typename T1, typename T2>
+  friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
+
+ private:
+  DataType dtype_;
union data {
-    float f;
-    double d;
+    bool b;
+    int8_t i8;
+    int16_t i16;
int32_t i32;
int64_t i64;
-    bool b;
+    uint8_t ui8;
+    uint16_t ui16;
+    uint32_t ui32;
+    uint64_t ui64;
+    bfloat16 bf16;
+    float16 f16;
+    float f32;
+    double f64;
+    complex64 c64;
+    complex128 c128;
} data_;
};

+template <typename T1, typename T2>
+void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) {
+  dst->dtype_ = src.dtype_;
+  dst->data_.c128 = src.data_.c128;
+}
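`CopyScalar` assigns only `data_.c128`, which suffices because `complex128` is the widest member of the union: copying it transfers the storage of whichever member is actually active, with `dtype_` recording how to interpret the bytes. This is what lets the converting constructor `ScalarBase(const ScalarBase<OtherT>&)` work across any pair of template arguments.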

+using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;

} // namespace experimental
} // namespace paddle

namespace pten {
-using Scalar = paddle::experimental::Scalar;
+class DenseTensor;
+using Scalar = paddle::experimental::ScalarBase<DenseTensor>;
} // namespace pten
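A minimal usage sketch of the reworked class, relying only on the constructors and `to<RT>()` shown in this diff; the `demo` function and its values are hypothetical:

```cpp
#include <cmath>
#include <cstdint>

#include "paddle/pten/common/scalar.h"

void demo() {
  using paddle::experimental::Scalar;

  Scalar a(1.5f);                     // stored as FLOAT32 in the union
  Scalar b(static_cast<int64_t>(7));  // overload resolution picks INT64
  Scalar c("inf");                    // string path: FLOAT64 +infinity

  double x = a.to<double>();          // 1.5, via the FLOAT32 switch case
  int32_t y = b.to<int32_t>();        // 7
  bool is_inf = std::isinf(c.to<double>());  // true
  (void)x; (void)y; (void)is_inf;
}
```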