2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -307,7 +307,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(reciprocal_grad,
ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
SoftplusGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/cpu/activation_kernel.cc
@@ -212,7 +212,7 @@ PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)

34 changes: 34 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -759,6 +759,24 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct SqrtGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
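// dx = dout * conj(0.5 / out)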
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * (static_cast<ComplexType<T>>(0.5) / out).unaryExpr(Conj<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};

// rsqrt(x) = x^(-1/2)
template <typename T>
struct RsqrtFunctor : public BaseActivationFunctor<T> {
@@ -4050,6 +4068,22 @@ struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
}
};

template <typename T>
struct CudaSqrtGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
ComplexType<T> one_half = static_cast<ComplexType<T>>(0.5f);

// dx = dout * conj(0.5 / out)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> out) const {
return dout * conj(one_half / out);
}

static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};

template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
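For reference, the backward rule the two new functors implement can be mirrored in a few lines of NumPy. This is an illustrative sketch only, not Paddle code: with out = sqrt(x) the derivative is 0.5 / out, and, as with the other complex activation functors in this header, the upstream gradient is multiplied by the conjugate of that derivative.

```python
# Illustrative NumPy mirror of SqrtGradFunctor<ComplexType<T>>; input values are arbitrary.
import numpy as np

rng = np.random.default_rng(0)
x = (rng.uniform(-1, 1, 4) + 1j * rng.uniform(-1, 1, 4)).astype(np.complex64)

out = np.sqrt(x)                # forward: out = sqrt(x)
dout = np.ones_like(out)        # upstream gradient
dx = dout * np.conj(0.5 / out)  # backward: dx = dout * conj(0.5 / out)
print(dx)
```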
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -386,7 +386,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_grad,
SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(softplus_double_grad,
SoftplusDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel)
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
@@ -261,7 +261,7 @@ PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(stanh, StanhKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(softplus, SoftplusKernel)

2 changes: 2 additions & 0 deletions python/paddle/tensor/ops.py
@@ -1114,6 +1114,8 @@ def sqrt(x: Tensor, name: str | None = None) -> Tensor:
'int16',
'int32',
'int64',
'complex64',
'complex128',
],
'sqrt',
)
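With 'complex64' and 'complex128' added to the dtype check, paddle.sqrt becomes callable on complex tensors from the Python API. A hedged usage sketch follows; the input values are arbitrary and a Paddle build containing this change is assumed, while gradient behavior is exercised by the updated OpTest cases below.

```python
import paddle

# Assumes a build of Paddle that contains this patch.
x = paddle.to_tensor([1 + 1j, -2 + 0.5j, 0.25 - 3j], dtype='complex64')
y = paddle.sqrt(x)  # elementwise complex square root
print(y)
```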
37 changes: 29 additions & 8 deletions test/legacy_test/test_activation_op.py
@@ -1667,6 +1667,11 @@ def setUp(self):

np.random.seed(1023)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x = (
np.random.uniform(-1, 1, self.shape)
+ 1j * np.random.uniform(-1, 1, self.shape)
).astype(self.dtype)
out = np.sqrt(x)

self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
@@ -1679,14 +1684,20 @@ def if_enable_cinn(self):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(
['X'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
check_pir_onednn=self.check_pir_onednn,
)
if self.dtype not in [np.complex64, np.complex128]:
self.check_grad(
['X'],
'Out',
check_prim=True,
check_pir=True,
check_prim_pir=True,
check_pir_onednn=self.check_pir_onednn,
)
else:
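# complex dtypes fall back to the plain gradient check (no prim / prim_pir coverage)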
self.check_grad(
['X'],
'Out',
)

def test_check_output(self):
self.check_output(
@@ -1746,6 +1757,16 @@ def init_shape(self):
self.shape = []


class TestSqrt_Complex64(TestSqrt):
def init_dtype(self):
self.dtype = np.complex64


class TestSqrt_Complex128(TestSqrt):
def init_dtype(self):
self.dtype = np.complex128


@unittest.skipIf(
not core.is_compiled_with_cuda() or core.is_compiled_with_rocm(),
"core is not compiled with CUDA",