
Commit fef307a

[Paddle Tensor Standardization, Phase 2] pow support complex (PaddlePaddle#71230)
1 parent a79a9aa commit fef307a
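
This commit adds complex<float> and complex<double> support to Paddle's pow operator. Four of the ten changed files appear below: they introduce complex specializations of the CPU and CUDA pow functors and extend the pow, pow_grad, pow_double_grad, and pow_triple_grad kernel registrations. For a real exponent n, the new backward functors implement the conjugate-gradient rule

$$ \frac{\partial L}{\partial x} \;=\; \frac{\partial L}{\partial \mathrm{out}} \cdot \overline{n\,x^{\,n-1}}, \qquad n \in \mathbb{R}, $$

which matches the element-wise expressions in the functors below.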

File tree

10 files changed: +543 −78 lines changed

paddle/phi/kernels/cpu/activation_grad_kernel.cc

Lines changed: 41 additions & 3 deletions

@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/impl/activation_grad_impl.h"
 
 namespace phi {
@@ -248,6 +249,35 @@ void HardSwishGradKernel(const Context& dev_ctx,
       dev_ctx, &x, nullptr, &dout, dx, functor);
 }
 
+template <typename T, typename Context>
+void PowGradKernel(const Context& dev_ctx,
+                   const DenseTensor& x,
+                   const DenseTensor& dout,
+                   const Scalar& factor,
+                   DenseTensor* dx) {
+  if (factor.to<float>() == 0) {
+    std::vector<int64_t> vec_dims = common::vectorize(dx->dims());
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(vec_dims), static_cast<T>(0), dx);
+    return;
+  }
+  PADDLE_ENFORCE_NOT_NULL(
+      dx,
+      errors::InvalidArgument("The output DenseTensor dx can not be nullptr"));
+  dev_ctx.template Alloc<T>(dx);
+  auto dout_flatten = EigenVector<T>::Flatten(
+      GET_DATA_SAFELY(&dout, "Input", "Out@GRAD", "PowGrad"));
+  auto dx_flatten = EigenVector<T>::Flatten(
+      GET_DATA_SAFELY(dx, "Output", "X@GRAD", "PowGrad"));
+  auto x_flatten =
+      EigenVector<T>::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad"));
+  auto* place = dev_ctx.eigen_device();
+  phi::funcs::PowGradFunctor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = factor.to<float>();
+  functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
+}
+
 }  // namespace phi
 
 PD_REGISTER_KERNEL(
@@ -462,20 +492,28 @@ PD_REGISTER_KERNEL(pow_grad,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
+
 PD_REGISTER_KERNEL(pow_double_grad,
                    CPU,
                    ALL_LAYOUT,
                    phi::PowDoubleGradKernel,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
+
 PD_REGISTER_KERNEL(pow_triple_grad,
                    CPU,
                    ALL_LAYOUT,
                    phi::PowTripleGradKernel,
                    float,
                    double,
                    int,
-                   int64_t) {}
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
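
PowGradKernel above hands the runtime exponent to the functor through its GetAttrs() name/pointer pair rather than a constructor argument. A minimal standalone sketch of that pattern, with a hypothetical ToyPowFunctor rather than Paddle's types:

#include <cmath>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

// Stand-in for BaseActivationFunctor's attribute plumbing: the functor
// exposes a (name, pointer) pair and the caller writes through the pointer.
struct ToyPowFunctor {
  float factor;
  std::vector<std::pair<std::string, float*>> GetAttrs() {
    return {{"factor", &factor}};
  }
  float operator()(float x) const { return std::pow(x, factor); }
};

int main() {
  ToyPowFunctor functor;
  auto attrs = functor.GetAttrs();
  *(attrs[0].second) = 3.0f;  // the kernel writes the scalar attribute here
  std::printf("%g\n", functor(2.0f));  // prints 8
  return 0;
}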

paddle/phi/kernels/cpu/activation_kernel.cc

Lines changed: 56 additions & 14 deletions

@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
 #include "paddle/phi/kernels/funcs/activation_functor.h"
 #include "paddle/phi/kernels/impl/activation_impl.h"
 
@@ -173,6 +174,35 @@ void RoundKernel(const Context& dev_ctx,
       dev_ctx, x, out, functor);
 }
 
+template <typename T, typename Context>
+void PowKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const Scalar& factor,
+               DenseTensor* out) {
+  PADDLE_ENFORCE_NOT_NULL(
+      out, errors::InvalidArgument("Output Out should not be nullptr"));
+  dev_ctx.template Alloc<T>(out);
+  if (factor.to<float>() == 0) {
+    std::vector<int64_t> vec_dims = common::vectorize(out->dims());
+    phi::Full<T, Context>(
+        dev_ctx, phi::IntArray(vec_dims), static_cast<T>(1), out);
+    return;
+  }
+  if (factor.to<float>() == 1) {
+    phi::Copy<Context>(dev_ctx, x, dev_ctx.GetPlace(), false, out);
+    return;
+  }
+  auto x_flatten = phi::EigenVector<T>::Flatten(
+      GET_DATA_SAFELY(&x, "Input", "X", "Activation"));
+  auto out_flatten = phi::EigenVector<T>::Flatten(
+      GET_DATA_SAFELY(out, "Output", "Out", "Activation"));
+  auto* place = dev_ctx.eigen_device();
+  phi::funcs::PowFunctor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = factor.to<float>();
+  functor(*place, x_flatten, out_flatten);
+}
+
 }  // namespace phi
 PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
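The two early returns in PowKernel encode the usual conventions: factor == 0 fills the output with ones (x^0 == 1 for every element, including x == 0, matching the C/IEEE definition of pow), and factor == 1 degenerates to a copy. A quick host-side check of the first convention, independent of Paddle:

#include <cassert>
#include <cmath>

int main() {
  // C standard / IEEE-754 pow: pow(x, 0) returns 1 for every x, including 0.
  assert(std::pow(0.0, 0.0) == 1.0);
  assert(std::pow(-3.5, 0.0) == 1.0);
  return 0;
}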

@@ -215,6 +245,18 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sqrt, SqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
 PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)
+PD_REGISTER_ACTIVATION_KERNEL(logit, LogitKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softsign, SoftsignKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
+PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
+PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
+PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
+PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
+PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
+PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
 
 PD_REGISTER_KERNEL(exp,
                    CPU,
@@ -240,7 +282,6 @@ PD_REGISTER_KERNEL(expm1,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 
-PD_REGISTER_KERNEL(logit, CPU, ALL_LAYOUT, phi::LogitKernel, float, double) {}
 PD_REGISTER_KERNEL(square,
                    CPU,
                    ALL_LAYOUT,
@@ -251,12 +292,6 @@ PD_REGISTER_KERNEL(square,
                    int64_t,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
-PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softsign, SoftsignKernel)
-PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sigmoid, SigmoidKernel)
-PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
-PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
-PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
 
 PD_REGISTER_KERNEL(log,
                    CPU,
@@ -270,6 +305,7 @@ PD_REGISTER_KERNEL(log,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
 PD_REGISTER_KERNEL(log2,
                    CPU,
                    ALL_LAYOUT,
@@ -282,6 +318,7 @@ PD_REGISTER_KERNEL(log2,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
 PD_REGISTER_KERNEL(log10,
                    CPU,
                    ALL_LAYOUT,
@@ -294,6 +331,7 @@ PD_REGISTER_KERNEL(log10,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
+
 PD_REGISTER_KERNEL(log1p,
                    CPU,
                    ALL_LAYOUT,
@@ -307,10 +345,6 @@ PD_REGISTER_KERNEL(log1p,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 
-PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel)
-PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
-PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
-PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
 PD_REGISTER_KERNEL(negative,
                    CPU,
                    ALL_LAYOUT,
@@ -322,6 +356,14 @@ PD_REGISTER_KERNEL(negative,
                    int64_t,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
-PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
-PD_REGISTER_KERNEL(
-    pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {}
+
+PD_REGISTER_KERNEL(pow,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::PowKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}

paddle/phi/kernels/funcs/activation_functor.h

Lines changed: 113 additions & 0 deletions

@@ -31,6 +31,7 @@
 
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/complex.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
@@ -2935,6 +2936,19 @@ struct PowFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct PowFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  float factor;
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"factor", &factor}};
+  }
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    out.device(d) = x.pow(static_cast<ComplexType<T>>(factor));  // NOLINT
+  }
+};
+
 template <typename T>
 struct PowGradFunctor : public BaseActivationFunctor<T> {
   float factor;
@@ -2954,6 +2968,27 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct PowGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  float factor;
+  typename BaseActivationFunctor<ComplexType<T>>::AttrPair GetAttrs() {
+    return {{"factor", &factor}};
+  }
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    dx.device(d) =
+        dout * static_cast<ComplexType<T>>(factor) *
+        x.pow(static_cast<ComplexType<T>>(factor - 1)).unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 // floor(x) = flooring(x)
 template <typename T>
 struct FloorFunctor : public BaseActivationFunctor<T> {
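
A standalone numerical sanity check of this backward rule, using std::complex on the host instead of Paddle's ComplexType and Eigen expressions (a sketch under that substitution, not the functor itself):

#include <complex>
#include <cstdio>

int main() {
  const std::complex<double> x(1.0, 2.0);
  const std::complex<double> dout(0.5, -0.25);
  const double factor = 3.0;
  // dx = dout * factor * conj(x^(factor - 1)), mirroring the functor above.
  const std::complex<double> dx =
      dout * factor * std::conj(std::pow(x, factor - 1.0));
  std::printf("dx = %g%+gi\n", dx.real(), dx.imag());
  return 0;
}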
@@ -5194,6 +5229,84 @@ struct CudaCeilFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T, typename MPType>
+__device__ __forceinline__
+    typename std::enable_if<std::is_integral<T>::value, T>::type
+    compute_pow(const T a, const T b) {
+  // TODO(wujionghao): A potential speed improvement is supporting different
+  // types in C++.
+  // On CUDAPlace, pow(3, 1) calls pow(float, float), and
+  // it will return a float number like 2.99... , which floor to 2
+  // when cast to int by default and it is wrong.
+  // Use llrint to cast it to the nearest integer, which is 3.
+  return llrint(pow(static_cast<double>(a), static_cast<double>(b)));
+}
+
+template <typename T, typename MPType>
+__device__ __forceinline__
+    typename std::enable_if<!std::is_integral<T>::value, T>::type
+    compute_pow(const T a, const T b) {
+  MPType a_val = static_cast<MPType>(a);
+  MPType b_val = static_cast<MPType>(b);
+  return static_cast<T>(pow(a_val, b_val));
+}
+
+template <typename T>
+struct CudaPowFunctor : public BaseActivationFunctor<T> {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  float factor;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"factor", &factor}};
+  }
+  __device__ __forceinline__ T operator()(const T x) const {
+    return compute_pow<T, MT>(x, static_cast<T>(factor));
+  }
+};
+
+template <typename T>
+struct CudaPowGradFunctor : public BaseActivationFunctor<T> {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+  float factor;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"factor", &factor}};
+  }
+  // dx = dout * n * pow(x, n - 1)
+  __device__ __forceinline__ T operator()(const T dout, const T x) const {
+    return dout * static_cast<T>(factor) *
+           compute_pow<T, MT>(x, static_cast<T>(factor - 1));
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
+template <typename T>
+struct CudaPowFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  float factor;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"factor", &factor}};
+  }
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> x) const {
+    return pow(x, static_cast<ComplexType<T>>(factor));
+  }
+};
+
+template <typename T>
+struct CudaPowGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  float factor;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"factor", &factor}};
+  }
+  // dx = dout * n * pow(x, n - 1)
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> dout, const ComplexType<T> x) const {
+    return dout * conj(static_cast<ComplexType<T>>(factor) *
+                       pow(x, static_cast<ComplexType<T>>(factor - 1)));
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaFloorFunctor : public BaseActivationFunctor<T> {
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
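
Note that the CPU functor conjugates only the power term (dout * n * conj(x^(n-1))) while CudaPowGradFunctor conjugates the whole product (dout * conj(n * x^(n-1))). The two agree because conjugation distributes over multiplication and the exponent is real:

$$ \overline{n\,x^{\,n-1}} \;=\; \bar{n}\,\overline{x^{\,n-1}} \;=\; n\,\overline{x^{\,n-1}}, \qquad n \in \mathbb{R}. $$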

paddle/phi/kernels/funcs/elementwise_functor.h

Lines changed: 2 additions & 5 deletions

@@ -959,21 +959,18 @@ inline HOSTDEVICE typename std::enable_if<std::is_integral<T>::value, T>::type
 compute_pow(const T a, const T b) {
   // TODO(wujionghao): A potential speed improvement is supporting different
   // types in C++.
-  // On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
+  // On CUDAPlace, pow(3, 1) calls pow(float, float), and
   // it will return a float number like 2.99... , which floor to 2
   // when cast to int by default and it is wrong.
   // Use llrint to cast it to the nearest integer, which is 3.
-  return std::llrint(std::pow(static_cast<double>(a), static_cast<double>(b)));
+  return llrint(pow(static_cast<double>(a), static_cast<double>(b)));
 }
 template <typename T, typename MPType>
 inline HOSTDEVICE typename std::enable_if<!std::is_integral<T>::value, T>::type
 compute_pow(const T a, const T b) {
   MPType a_val = static_cast<MPType>(a);
   MPType b_val = static_cast<MPType>(b);
-#ifdef PADDLE_WITH_XPU_KP
   return static_cast<T>(pow(a_val, b_val));
-#endif
-  return static_cast<T>(std::pow(a_val, b_val));
 }
 #else
 template <typename T, typename MPType>
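
The comment above is easy to reproduce on the host. Whether pow(3, 1) actually lands below 3 depends on the libm, so this sketch simulates the near-miss value; it only shows why truncation is fragile and round-to-nearest is not:

#include <cmath>
#include <cstdio>

int main() {
  // Simulated pow(float, float) result landing just below the true integer.
  const double p = 2.9999998;
  const int truncated = static_cast<int>(p);  // truncates toward zero: 2 (wrong)
  const long long rounded = std::llrint(p);   // rounds to nearest: 3 (correct)
  std::printf("truncated=%d rounded=%lld\n", truncated, rounded);
  return 0;
}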
