18 changes: 9 additions & 9 deletions paddle/pten/api/lib/kernel_declare.h
@@ -20,18 +20,18 @@ limitations under the License. */
// the kernel declaration statements are automatically generated according to
// the file name of the kernel, and this header file will be removed

PT_DECLARE_KERNEL(full_like, CPU);
PT_DECLARE_KERNEL(dot, CPU);
PT_DECLARE_KERNEL(flatten, CPU);
PT_DECLARE_KERNEL(sign, CPU);
PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA);
PT_DECLARE_KERNEL(dot, CUDA);
PT_DECLARE_KERNEL(flatten, CUDA);
PT_DECLARE_KERNEL(sign, CUDA);
PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CUDA, ALL_LAYOUT);
#endif

#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(flatten, XPU);
PT_DECLARE_KERNEL(flatten, XPU, ALL_LAYOUT);
#endif
6 changes: 3 additions & 3 deletions paddle/pten/api/lib/utils.cc
@@ -25,14 +25,14 @@ limitations under the License. */
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h"

PT_DECLARE_KERNEL(copy, CPU);
PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(copy, CUDA);
PT_DECLARE_KERNEL(copy, CUDA, ALL_LAYOUT);
#endif

#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(copy, XPU);
PT_DECLARE_KERNEL(copy, XPU, ALL_LAYOUT);
#endif

namespace paddle {
37 changes: 36 additions & 1 deletion paddle/pten/common/backend.h
@@ -37,7 +37,6 @@ namespace experimental {
* in the future
*/
enum class Backend : uint8_t {
// kernel backend cannot be undefined
UNDEFINED = 0,

// basic kernel backend
@@ -54,6 +53,42 @@

// end of backend types
NUM_BACKENDS,

/**
* [ Why do we need ALL in the basic kernel key members? ]
*
* For Tensor, ALL represents an illegal Backend, but for Kernel, some
* kernels may be device-independent by nature, such as reshape; and some
* kernels are also device-independent when they are implemented on top of
* primitive APIs.
*
* In this case, we need a more concise registration method: instead of
* registering the kernel for each device with almost identical code, one
* registration should cover all situations, so we provide the ALL field
* and register the kernel in a single statement.
*
* Of course, we have also considered solving this problem with differently
* named macros, for example, by defining
*
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND
*
* Following this design pattern, dtype and layout have the same
* requirement, which would force us to define a whole series of macros
*
* PT_REGISTER_KERNEL_FOR_ALL_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT_AND_DTYPE
*
* This makes the system of registration macros more complicated, which we
* do not consider a simple design, so we still adopt the design of
* providing the ALL field.
*
* Note: ALL_BACKEND is only used for kernel registration and selection
*/
ALL_BACKEND = UNDEFINED,
};

inline std::ostream& operator<<(std::ostream& os, Backend backend) {
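The note above says ALL_BACKEND (like ALL_LAYOUT and ALL_DTYPE) exists so that a single registration can be selected for any concrete value of that key field. Below is a minimal, self-contained C++ sketch of that wildcard idea; it is not pten's actual registry or KernelKey code, and the KernelKey struct and Matches function are illustrative names only.

#include <cstdint>
#include <iostream>

// Simplified stand-ins mirroring the "ALL_* aliases UNDEFINED" pattern
// introduced in this PR for Backend, DataLayout and DataType.
enum class Backend : uint8_t { UNDEFINED = 0, CPU, CUDA, ALL_BACKEND = UNDEFINED };
enum class DataLayout : uint8_t { UNDEFINED = 0, NHWC, NCHW, ALL_LAYOUT = UNDEFINED };

// Hypothetical kernel key; just enough fields to show the matching rule.
struct KernelKey {
  Backend backend;
  DataLayout layout;
};

// A registered field equal to the ALL_* sentinel matches any requested value,
// so one registration with ALL_LAYOUT covers NHWC, NCHW, etc.
bool Matches(const KernelKey& registered, const KernelKey& requested) {
  const bool backend_ok = registered.backend == Backend::ALL_BACKEND ||
                          registered.backend == requested.backend;
  const bool layout_ok = registered.layout == DataLayout::ALL_LAYOUT ||
                         registered.layout == requested.layout;
  return backend_ok && layout_ok;
}

int main() {
  KernelKey registered{Backend::CPU, DataLayout::ALL_LAYOUT};
  std::cout << Matches(registered, {Backend::CPU, DataLayout::NHWC}) << "\n";   // 1
  std::cout << Matches(registered, {Backend::CPU, DataLayout::NCHW}) << "\n";   // 1
  std::cout << Matches(registered, {Backend::CUDA, DataLayout::NCHW}) << "\n";  // 0
  return 0;
}

In the hunks below, the ANY layout argument in existing registrations is replaced by ALL_LAYOUT, which carries this wildcard meaning explicitly; ANY itself is kept only as a compatibility alias of UNDEFINED.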
4 changes: 3 additions & 1 deletion paddle/pten/common/data_type.h
@@ -45,7 +45,9 @@ enum class DataType {
FLOAT64,
COMPLEX64,
COMPLEX128,
NUM_DATA_TYPES
NUM_DATA_TYPES,
// See Note [ Why do we need ALL in the basic kernel key members? ]
ALL_DTYPE = UNDEFINED,
};

inline size_t SizeOf(DataType data_type) {
8 changes: 4 additions & 4 deletions paddle/pten/common/layout.h
@@ -20,21 +20,21 @@ namespace experimental {

enum class DataLayout {
UNDEFINED = 0,
ANY,
// TODO(chenweihang): keep ANY for compatibility, remove it later
ANY = UNDEFINED,
NHWC,
NCHW,
MKLDNN,
NUM_DATA_LAYOUTS,
// See Note [ Why do we need ALL in the basic kernel key members? ]
ALL_LAYOUT = UNDEFINED,
};

inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
switch (layout) {
case DataLayout::UNDEFINED:
os << "Undefined";
break;
case DataLayout::ANY:
os << "Any";
break;
case DataLayout::NHWC:
os << "NHWC";
break;
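One consequence of the hunk above: since ANY and ALL_LAYOUT are now both aliases of UNDEFINED rather than distinct enumerators, DataLayout::ANY compares equal to DataLayout::UNDEFINED, and operator<< can no longer keep separate case labels for the two (duplicate case values do not compile), which is why the ANY branch is removed. A tiny standalone C++ illustration of the aliasing, using the enumerator values from this file; only the main function is added for demonstration:

#include <iostream>

enum class DataLayout {
  UNDEFINED = 0,
  // kept only for compatibility; same value as UNDEFINED
  ANY = UNDEFINED,
  NHWC,
  NCHW,
  MKLDNN,
  NUM_DATA_LAYOUTS,
  // alias used for kernel registration, also same value as UNDEFINED
  ALL_LAYOUT = UNDEFINED,
};

static_assert(DataLayout::ANY == DataLayout::UNDEFINED,
              "ANY is just another name for UNDEFINED");
static_assert(DataLayout::ALL_LAYOUT == DataLayout::UNDEFINED,
              "ALL_LAYOUT is just another name for UNDEFINED");

int main() {
  switch (DataLayout::ANY) {
    case DataLayout::UNDEFINED:  // this label also covers ANY and ALL_LAYOUT
      std::cout << "Undefined\n";
      break;
    // case DataLayout::ANY:     // would not compile: duplicate case value
    default:
      break;
  }
  return 0;
}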
891 changes: 444 additions & 447 deletions paddle/pten/core/kernel_registry.h

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions paddle/pten/kernels/cpu/creation.cc
@@ -63,7 +63,7 @@ void FillConstant(const CPUContext& dev_ctx,

PT_REGISTER_KERNEL(full_like,
CPU,
ANY,
ALL_LAYOUT,
pten::FillAnyLike,
float,
double,
@@ -74,7 +74,7 @@ PT_REGISTER_KERNEL(full_like,

PT_REGISTER_KERNEL(full,
CPU,
ANY,
ALL_LAYOUT,
pten::FillConstant,
float,
double,
12 changes: 9 additions & 3 deletions paddle/pten/kernels/cpu/linalg.cc
@@ -75,7 +75,7 @@ using complex128 = ::paddle::platform::complex<double>;

PT_REGISTER_KERNEL(dot,
CPU,
ANY,
ALL_LAYOUT,
pten::Dot,
float,
double,
@@ -84,5 +84,11 @@ PT_REGISTER_KERNEL(dot,
complex64,
complex128) {}

PT_REGISTER_KERNEL(
matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
PT_REGISTER_KERNEL(matmul,
CPU,
ALL_LAYOUT,
pten::Matmul,
float,
double,
complex64,
complex128) {}
15 changes: 7 additions & 8 deletions paddle/pten/kernels/cpu/manipulation.cc
@@ -85,7 +85,7 @@ void Cast(const CPUContext& dev_ctx,

PT_REGISTER_KERNEL(flatten,
CPU,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
double,
@@ -95,7 +95,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape,
CPU,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
double,
@@ -106,7 +106,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,

PT_REGISTER_KERNEL(cast,
CPU,
ANY,
ALL_LAYOUT,
pten::Cast,
float,
double,
@@ -122,8 +122,7 @@ PT_REGISTER_KERNEL(cast,
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}

PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CPU,
ANY,
pten::ReshapeWithXShape) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_with_xshape, CPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {}
16 changes: 8 additions & 8 deletions paddle/pten/kernels/cpu/math.cc
@@ -111,11 +111,11 @@ using complex128 = ::paddle::platform::complex<double>;

// NOTE(chenweihang): using bfloat16 here would cause a redefinition conflict with the xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CPU,
ANY,
ALL_LAYOUT,
pten::Scale,
float,
double,
@@ -127,7 +127,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {}
PT_REGISTER_KERNEL(add,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseAdd,
float,
double,
@@ -137,7 +137,7 @@ PT_REGISTER_KERNEL(add,
complex128) {}
PT_REGISTER_KERNEL(subtract,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseSub,
float,
double,
@@ -147,7 +147,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {}
PT_REGISTER_KERNEL(divide,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseDiv,
float,
double,
@@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {}
PT_REGISTER_KERNEL(multiply,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseMul,
float,
double,
@@ -168,7 +168,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {}
PT_REGISTER_KERNEL(sum,
CPU,
ANY,
ALL_LAYOUT,
pten::Sum,
bool,
float,
2 changes: 1 addition & 1 deletion paddle/pten/kernels/cpu/utils.cc
@@ -57,4 +57,4 @@ void Copy(const CPUContext& dev_ctx,

} // namespace pten

PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
4 changes: 2 additions & 2 deletions paddle/pten/kernels/cuda/creation.cu
@@ -64,7 +64,7 @@ void FillConstant(const CUDAContext& dev_ctx,

PT_REGISTER_KERNEL(full_like,
CUDA,
ANY,
ALL_LAYOUT,
pten::FillAnyLike,
float,
double,
@@ -75,7 +75,7 @@ PT_REGISTER_KERNEL(full_like,

PT_REGISTER_KERNEL(full,
CUDA,
ANY,
ALL_LAYOUT,
pten::FillConstant,
float,
double,
4 changes: 2 additions & 2 deletions paddle/pten/kernels/cuda/linalg.cu
@@ -60,7 +60,7 @@ using complex128 = ::paddle::platform::complex<double>;

PT_REGISTER_KERNEL(dot,
CUDA,
ANY,
ALL_LAYOUT,
pten::Dot,
float,
double,
@@ -71,7 +71,7 @@ PT_REGISTER_KERNEL(dot,

PT_REGISTER_KERNEL(matmul,
CUDA,
ANY,
ALL_LAYOUT,
pten::Matmul,
float,
double,
14 changes: 6 additions & 8 deletions paddle/pten/kernels/cuda/manipulation.cu
@@ -86,7 +86,7 @@ using float16 = paddle::platform::float16;

PT_REGISTER_KERNEL(flatten,
CUDA,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
float16,
@@ -97,7 +97,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape,
CUDA,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
double,
@@ -109,7 +109,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL(cast, \
CUDA, \
ANY, \
ALL_LAYOUT, \
pten::Cast, \
float, \
double, \
@@ -132,8 +132,6 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif

PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CUDA,
ANY,
pten::ReshapeWithXShape) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, CUDA, ANY, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_with_xshape, CUDA, ANY, pten::ReshapeWithXShape, ALL_DTYPE) {}