Skip to content

Commit 5fa44c3

Browse files
modify Ops to complex template (#33041)
* modify the conj, real, and imag Ops to use the complex template * replace complex64/complex128 with the complex template in the dot Op * replace complex64/complex128 with the complex template in the Abs Op * add support for complex64 and complex128
1 parent 86ea8dc commit 5fa44c3

File tree

11 files changed

+103
-95
lines changed

11 files changed

+103
-95
lines changed

paddle/fluid/operators/abs_op.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,19 +164,19 @@ REGISTER_OP_CPU_KERNEL(
164164
ops::AbsKernel<paddle::platform::CPUDeviceContext, int>,
165165
ops::AbsKernel<paddle::platform::CPUDeviceContext, int64_t>,
166166
ops::AbsKernel<paddle::platform::CPUDeviceContext,
167-
paddle::platform::complex64>,
167+
paddle::platform::complex<float>>,
168168
ops::AbsKernel<paddle::platform::CPUDeviceContext,
169-
paddle::platform::complex128>);
169+
paddle::platform::complex<double>>);
170170

171171
REGISTER_OP_CPU_KERNEL(
172172
abs_grad, ops::AbsGradKernel<paddle::platform::CPUDeviceContext, float>,
173173
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, double>,
174174
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int>,
175175
ops::AbsGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
176176
ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
177-
paddle::platform::complex64>,
177+
paddle::platform::complex<float>>,
178178
ops::AbsGradKernel<paddle::platform::CPUDeviceContext,
179-
paddle::platform::complex128>);
179+
paddle::platform::complex<double>>);
180180

181181
REGISTER_OP_CPU_KERNEL(
182182
abs_grad_grad,
@@ -187,6 +187,6 @@ REGISTER_OP_CPU_KERNEL(
187187
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
188188
paddle::platform::float16>,
189189
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
190-
paddle::platform::complex64>,
190+
paddle::platform::complex<float>>,
191191
ops::AbsDoubleGradKernel<paddle::platform::CPUDeviceContext,
192-
paddle::platform::complex128>);
192+
paddle::platform::complex<double>>);

paddle/fluid/operators/abs_op.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,23 +70,23 @@ REGISTER_OP_CUDA_KERNEL(
7070
ops::AbsKernel<plat::CUDADeviceContext, int>,
7171
ops::AbsKernel<plat::CUDADeviceContext, int64_t>,
7272
ops::AbsKernel<plat::CUDADeviceContext, plat::float16>,
73-
ops::AbsKernel<plat::CUDADeviceContext, plat::complex64>,
74-
ops::AbsKernel<plat::CUDADeviceContext, plat::complex128>);
73+
ops::AbsKernel<plat::CUDADeviceContext, plat::complex<float>>,
74+
ops::AbsKernel<plat::CUDADeviceContext, plat::complex<double>>);
7575

7676
REGISTER_OP_CUDA_KERNEL(
7777
abs_grad, ops::AbsGradKernel<plat::CUDADeviceContext, float>,
7878
ops::AbsGradKernel<plat::CUDADeviceContext, double>,
7979
ops::AbsGradKernel<plat::CUDADeviceContext, int>,
8080
ops::AbsGradKernel<plat::CUDADeviceContext, int64_t>,
8181
ops::AbsGradKernel<plat::CUDADeviceContext, plat::float16>,
82-
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex64>,
83-
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex128>);
82+
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
83+
ops::AbsGradKernel<plat::CUDADeviceContext, plat::complex<double>>);
8484

8585
REGISTER_OP_CUDA_KERNEL(
8686
abs_grad_grad, ops::AbsDoubleGradKernel<plat::CUDADeviceContext, float>,
8787
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, double>,
8888
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int>,
8989
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
9090
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
91-
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex64>,
92-
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex128>);
91+
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<float>>,
92+
ops::AbsDoubleGradKernel<plat::CUDADeviceContext, plat::complex<double>>);

paddle/fluid/operators/conj_op.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ REGISTER_OPERATOR(conj, ops::ConjOp, ops::ConjOpMaker,
7878

7979
REGISTER_OP_CPU_KERNEL(
8080
conj, ops::ConjKernel<paddle::platform::CPUDeviceContext,
81-
paddle::platform::complex64>,
81+
paddle::platform::complex<float>>,
8282
ops::ConjKernel<paddle::platform::CPUDeviceContext,
83-
paddle::platform::complex128>,
83+
paddle::platform::complex<double>>,
8484
ops::ConjKernel<paddle::platform::CPUDeviceContext, float>,
8585
ops::ConjKernel<paddle::platform::CPUDeviceContext, double>,
8686
ops::ConjKernel<paddle::platform::CPUDeviceContext, int>,

paddle/fluid/operators/conj_op.cu

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/operators/conj_op.h"
16-
#include "paddle/fluid/platform/complex128.h"
17-
#include "paddle/fluid/platform/complex64.h"
16+
#include "paddle/fluid/platform/complex.h"
1817

1918
namespace ops = paddle::operators;
2019
REGISTER_OP_CUDA_KERNEL(
2120
conj, ops::ConjKernel<paddle::platform::CUDADeviceContext,
22-
paddle::platform::complex64>,
21+
paddle::platform::complex<float>>,
2322
ops::ConjKernel<paddle::platform::CUDADeviceContext,
24-
paddle::platform::complex128>,
23+
paddle::platform::complex<double>>,
2524
ops::ConjKernel<paddle::platform::CUDADeviceContext, float>,
2625
ops::ConjKernel<paddle::platform::CUDADeviceContext, double>,
2726
ops::ConjKernel<paddle::platform::CUDADeviceContext, int>,

paddle/fluid/operators/dot_op.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class DotOp : public framework::OperatorWithKernel {
3333
"Output(Out) of DotOp should not be null."));
3434

3535
auto x_dims = ctx->GetInputDim("X");
36-
auto x_rank = (size_t)x_dims.size();
36+
auto x_rank = static_cast<size_t>(x_dims.size());
3737
PADDLE_ENFORCE_EQ(true, 1 == x_rank || 2 == x_rank,
3838
platform::errors::PreconditionNotMet(
3939
"ShapeError: The dimensions of input tensor X (%s) "
@@ -154,15 +154,15 @@ REGISTER_OP_CPU_KERNEL(
154154
ops::DotKernel<paddle::platform::CPUDeviceContext, int>,
155155
ops::DotKernel<paddle::platform::CPUDeviceContext, int64_t>,
156156
ops::DotKernel<paddle::platform::CPUDeviceContext,
157-
paddle::platform::complex64>,
157+
paddle::platform::complex<float>>,
158158
ops::DotKernel<paddle::platform::CPUDeviceContext,
159-
paddle::platform::complex128>);
159+
paddle::platform::complex<double>>);
160160
REGISTER_OP_CPU_KERNEL(
161161
dot_grad, ops::DotGradKernel<paddle::platform::CPUDeviceContext, float>,
162162
ops::DotGradKernel<paddle::platform::CPUDeviceContext, double>,
163163
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int>,
164164
ops::DotGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
165165
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
166-
paddle::platform::complex64>,
166+
paddle::platform::complex<float>>,
167167
ops::DotGradKernel<paddle::platform::CPUDeviceContext,
168-
paddle::platform::complex128>);
168+
paddle::platform::complex<double>>);

paddle/fluid/operators/dot_op.cu

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ REGISTER_OP_CUDA_KERNEL(
2222
ops::DotKernel<plat::CUDADeviceContext, double>,
2323
ops::DotKernel<plat::CUDADeviceContext, int>,
2424
ops::DotKernel<plat::CUDADeviceContext, int64_t>,
25-
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
26-
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
27-
REGISTER_OP_CUDA_KERNEL(
28-
dot_grad, ops::DotGradKernel<plat::CUDADeviceContext, float>,
29-
ops::DotGradKernel<plat::CUDADeviceContext, double>,
30-
ops::DotGradKernel<plat::CUDADeviceContext, int>,
31-
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
32-
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex64>,
33-
ops::DotGradKernel<plat::CUDADeviceContext, paddle::platform::complex128>);
25+
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<float>>,
26+
ops::DotKernel<plat::CUDADeviceContext, paddle::platform::complex<double>>);
27+
REGISTER_OP_CUDA_KERNEL(dot_grad,
28+
ops::DotGradKernel<plat::CUDADeviceContext, float>,
29+
ops::DotGradKernel<plat::CUDADeviceContext, double>,
30+
ops::DotGradKernel<plat::CUDADeviceContext, int>,
31+
ops::DotGradKernel<plat::CUDADeviceContext, int64_t>,
32+
ops::DotGradKernel<plat::CUDADeviceContext,
33+
paddle::platform::complex<float>>,
34+
ops::DotGradKernel<plat::CUDADeviceContext,
35+
paddle::platform::complex<double>>);

paddle/fluid/operators/imag_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,11 @@ REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker,
9696
REGISTER_OPERATOR(imag_grad, ops::ImagGradOp);
9797

9898
REGISTER_OP_CPU_KERNEL(imag, ops::ImagKernel<paddle::platform::CPUDeviceContext,
99-
paddle::platform::complex64>,
99+
paddle::platform::complex<float>>,
100100
ops::ImagKernel<paddle::platform::CPUDeviceContext,
101-
paddle::platform::complex128>);
101+
paddle::platform::complex<double>>);
102102
REGISTER_OP_CPU_KERNEL(imag_grad,
103103
ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
104-
paddle::platform::complex64>,
104+
paddle::platform::complex<float>>,
105105
ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
106-
paddle::platform::complex128>);
106+
paddle::platform::complex<double>>);

paddle/fluid/operators/imag_op.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ namespace ops = paddle::operators;
1818

1919
REGISTER_OP_CUDA_KERNEL(imag,
2020
ops::ImagKernel<paddle::platform::CUDADeviceContext,
21-
paddle::platform::complex64>,
21+
paddle::platform::complex<float>>,
2222
ops::ImagKernel<paddle::platform::CUDADeviceContext,
23-
paddle::platform::complex128>);
23+
paddle::platform::complex<double>>);
2424
REGISTER_OP_CUDA_KERNEL(imag_grad,
2525
ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
26-
paddle::platform::complex64>,
26+
paddle::platform::complex<float>>,
2727
ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
28-
paddle::platform::complex128>);
28+
paddle::platform::complex<double>>);

paddle/fluid/operators/math/complex_functors.h

Lines changed: 54 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ limitations under the License. */
1616

1717
#include <type_traits>
1818

19-
#include "paddle/fluid/platform/complex128.h"
20-
#include "paddle/fluid/platform/complex64.h"
19+
#include "paddle/fluid/platform/complex.h"
2120
#include "paddle/fluid/platform/hostdevice.h"
2221

2322
namespace paddle {
@@ -66,7 +65,10 @@ using select_t = typename select<Head, Tail...>::type;
6665
template <typename T>
6766
using Real =
6867
select_t<cond<std::is_same<T, platform::complex64>::value, float>,
69-
cond<std::is_same<T, platform::complex128>::value, double>, T>;
68+
cond<std::is_same<T, platform::complex128>::value, double>,
69+
cond<std::is_same<T, platform::complex<float>>::value, float>,
70+
cond<std::is_same<T, platform::complex<double>>::value, double>,
71+
T>;
7072

7173
template <typename T, typename RealT>
7274
using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
@@ -76,14 +78,18 @@ template <typename T, typename RealT>
7678
using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
7779

7880
template <typename T>
79-
using EnableComplex =
80-
typename std::enable_if<std::is_same<T, platform::complex64>::value ||
81-
std::is_same<T, platform::complex128>::value>::type;
81+
using EnableComplex = typename std::enable_if<
82+
std::is_same<T, platform::complex64>::value ||
83+
std::is_same<T, platform::complex128>::value ||
84+
std::is_same<T, platform::complex<float>>::value ||
85+
std::is_same<T, platform::complex<double>>::value>::type;
8286

8387
template <typename T>
8488
using DisableComplex = typename std::enable_if<
8589
!std::is_same<T, platform::complex64>::value &&
86-
!std::is_same<T, platform::complex128>::value>::type;
90+
!std::is_same<T, platform::complex128>::value &&
91+
!std::is_same<T, platform::complex<float>>::value &&
92+
!std::is_same<T, platform::complex<double>>::value>::type;
8793

8894
template <typename T, typename Enable = void>
8995
struct RealFunctor;
@@ -173,44 +179,45 @@ struct AbsGradFunctor {
173179
};
174180

175181
template <>
176-
struct AbsGradFunctor<paddle::platform::complex64> {
177-
AbsGradFunctor(const float* dout, const paddle::platform::complex64* x,
178-
paddle::platform::complex64* output, int64_t numel)
182+
struct AbsGradFunctor<paddle::platform::complex<float>> {
183+
AbsGradFunctor(const float* dout, const paddle::platform::complex<float>* x,
184+
paddle::platform::complex<float>* output, int64_t numel)
179185
: dout_(dout), x_(x), output_(output), numel_(numel) {}
180186

181187
HOSTDEVICE void operator()(int64_t idx) const {
182-
if (x_[idx] == paddle::platform::complex64(0)) {
183-
output_[idx] = paddle::platform::complex64(0);
188+
if (x_[idx] == paddle::platform::complex<float>(0)) {
189+
output_[idx] = paddle::platform::complex<float>(0);
184190
} else {
185-
output_[idx] = paddle::platform::complex64(dout_[idx]) *
186-
(x_[idx] / paddle::platform::complex64(abs(x_[idx])));
191+
output_[idx] = paddle::platform::complex<float>(dout_[idx]) *
192+
(x_[idx] / paddle::platform::complex<float>(abs(x_[idx])));
187193
}
188194
}
189195

190196
const float* dout_;
191-
const paddle::platform::complex64* x_;
192-
paddle::platform::complex64* output_;
197+
const paddle::platform::complex<float>* x_;
198+
paddle::platform::complex<float>* output_;
193199
int64_t numel_;
194200
};
195201

196202
template <>
197-
struct AbsGradFunctor<paddle::platform::complex128> {
198-
AbsGradFunctor(const double* dout, const paddle::platform::complex128* x,
199-
paddle::platform::complex128* output, int64_t numel)
203+
struct AbsGradFunctor<paddle::platform::complex<double>> {
204+
AbsGradFunctor(const double* dout, const paddle::platform::complex<double>* x,
205+
paddle::platform::complex<double>* output, int64_t numel)
200206
: dout_(dout), x_(x), output_(output), numel_(numel) {}
201207

202208
HOSTDEVICE void operator()(int64_t idx) const {
203-
if (x_[idx] == paddle::platform::complex128(0)) {
204-
output_[idx] = paddle::platform::complex128(0);
209+
if (x_[idx] == paddle::platform::complex<double>(0)) {
210+
output_[idx] = paddle::platform::complex<double>(0);
205211
} else {
206-
output_[idx] = paddle::platform::complex128(dout_[idx]) *
207-
(x_[idx] / paddle::platform::complex128(abs(x_[idx])));
212+
output_[idx] =
213+
paddle::platform::complex<double>(dout_[idx]) *
214+
(x_[idx] / paddle::platform::complex<double>(abs(x_[idx])));
208215
}
209216
}
210217

211218
const double* dout_;
212-
const paddle::platform::complex128* x_;
213-
paddle::platform::complex128* output_;
219+
const paddle::platform::complex<double>* x_;
220+
paddle::platform::complex<double>* output_;
214221
int64_t numel_;
215222
};
216223

@@ -234,46 +241,46 @@ struct AbsGradGradFunctor {
234241
};
235242

236243
template <>
237-
struct AbsGradGradFunctor<paddle::platform::complex128> {
238-
AbsGradGradFunctor(const paddle::platform::complex128* ddx,
239-
const paddle::platform::complex128* x,
240-
paddle::platform::complex128* output, int64_t numel)
244+
struct AbsGradGradFunctor<paddle::platform::complex<double>> {
245+
AbsGradGradFunctor(const paddle::platform::complex<double>* ddx,
246+
const paddle::platform::complex<double>* x,
247+
paddle::platform::complex<double>* output, int64_t numel)
241248
: ddx_(ddx), x_(x), output_(output), numel_(numel) {}
242249

243250
HOSTDEVICE void operator()(int64_t idx) const {
244-
if (x_[idx] == paddle::platform::complex128(0)) {
245-
output_[idx] = paddle::platform::complex128(0);
251+
if (x_[idx] == paddle::platform::complex<double>(0)) {
252+
output_[idx] = paddle::platform::complex<double>(0);
246253
} else {
247-
output_[idx] = paddle::platform::complex128(ddx_[idx]) * x_[idx] /
248-
paddle::platform::complex128(abs(x_[idx]));
254+
output_[idx] = paddle::platform::complex<double>(ddx_[idx]) * x_[idx] /
255+
paddle::platform::complex<double>(abs(x_[idx]));
249256
}
250257
}
251258

252-
const paddle::platform::complex128* ddx_;
253-
const paddle::platform::complex128* x_;
254-
paddle::platform::complex128* output_;
259+
const paddle::platform::complex<double>* ddx_;
260+
const paddle::platform::complex<double>* x_;
261+
paddle::platform::complex<double>* output_;
255262
int64_t numel_;
256263
};
257264

258265
template <>
259-
struct AbsGradGradFunctor<paddle::platform::complex64> {
260-
AbsGradGradFunctor(const paddle::platform::complex64* ddx,
261-
const paddle::platform::complex64* x,
262-
paddle::platform::complex64* output, int64_t numel)
266+
struct AbsGradGradFunctor<paddle::platform::complex<float>> {
267+
AbsGradGradFunctor(const paddle::platform::complex<float>* ddx,
268+
const paddle::platform::complex<float>* x,
269+
paddle::platform::complex<float>* output, int64_t numel)
263270
: ddx_(ddx), x_(x), output_(output), numel_(numel) {}
264271

265272
HOSTDEVICE void operator()(int64_t idx) const {
266-
if (x_[idx] == paddle::platform::complex64(0)) {
267-
output_[idx] = paddle::platform::complex64(0);
273+
if (x_[idx] == paddle::platform::complex<float>(0)) {
274+
output_[idx] = paddle::platform::complex<float>(0);
268275
} else {
269-
output_[idx] = paddle::platform::complex64(ddx_[idx]) * x_[idx] /
270-
paddle::platform::complex64(abs(x_[idx]));
276+
output_[idx] = paddle::platform::complex<float>(ddx_[idx]) * x_[idx] /
277+
paddle::platform::complex<float>(abs(x_[idx]));
271278
}
272279
}
273280

274-
const paddle::platform::complex64* ddx_;
275-
const paddle::platform::complex64* x_;
276-
paddle::platform::complex64* output_;
281+
const paddle::platform::complex<float>* ddx_;
282+
const paddle::platform::complex<float>* x_;
283+
paddle::platform::complex<float>* output_;
277284
int64_t numel_;
278285
};
279286
template <typename T, typename Enable = void>

paddle/fluid/operators/real_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,11 @@ REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker,
9595
REGISTER_OPERATOR(real_grad, ops::RealGradOp);
9696

9797
REGISTER_OP_CPU_KERNEL(real, ops::RealKernel<paddle::platform::CPUDeviceContext,
98-
paddle::platform::complex64>,
98+
paddle::platform::complex<float>>,
9999
ops::RealKernel<paddle::platform::CPUDeviceContext,
100-
paddle::platform::complex128>);
100+
paddle::platform::complex<double>>);
101101
REGISTER_OP_CPU_KERNEL(real_grad,
102102
ops::RealGradKernel<paddle::platform::CPUDeviceContext,
103-
paddle::platform::complex64>,
103+
paddle::platform::complex<float>>,
104104
ops::RealGradKernel<paddle::platform::CPUDeviceContext,
105-
paddle::platform::complex128>);
105+
paddle::platform::complex<double>>);

0 commit comments

Comments (0)