@@ -16,8 +16,7 @@ limitations under the License. */
1616
1717#include < type_traits>
1818
19- #include " paddle/fluid/platform/complex128.h"
20- #include " paddle/fluid/platform/complex64.h"
19+ #include " paddle/fluid/platform/complex.h"
2120#include " paddle/fluid/platform/hostdevice.h"
2221
2322namespace paddle {
@@ -66,7 +65,10 @@ using select_t = typename select<Head, Tail...>::type;
6665template <typename T>
6766using Real =
6867 select_t <cond<std::is_same<T, platform::complex64>::value, float >,
69- cond<std::is_same<T, platform::complex128>::value, double >, T>;
68+ cond<std::is_same<T, platform::complex128>::value, double >,
69+ cond<std::is_same<T, platform::complex <float >>::value, float >,
70+ cond<std::is_same<T, platform::complex <double >>::value, double >,
71+ T>;
7072
7173template <typename T, typename RealT>
7274using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
@@ -76,14 +78,18 @@ template <typename T, typename RealT>
7678using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
7779
7880template <typename T>
79- using EnableComplex =
80- typename std::enable_if<std::is_same<T, platform::complex64>::value ||
81- std::is_same<T, platform::complex128>::value>::type;
81+ using EnableComplex = typename std::enable_if<
82+ std::is_same<T, platform::complex64>::value ||
83+ std::is_same<T, platform::complex128>::value ||
84+ std::is_same<T, platform::complex <float >>::value ||
85+ std::is_same<T, platform::complex <double >>::value>::type;
8286
8387template <typename T>
8488using DisableComplex = typename std::enable_if<
8589 !std::is_same<T, platform::complex64>::value &&
86- !std::is_same<T, platform::complex128>::value>::type;
90+ !std::is_same<T, platform::complex128>::value &&
91+ !std::is_same<T, platform::complex <float >>::value &&
92+ !std::is_same<T, platform::complex <double >>::value>::type;
8793
8894template <typename T, typename Enable = void >
8995struct RealFunctor ;
@@ -173,44 +179,45 @@ struct AbsGradFunctor {
173179};
174180
175181template <>
176- struct AbsGradFunctor <paddle::platform::complex64 > {
177- AbsGradFunctor (const float * dout, const paddle::platform::complex64 * x,
178- paddle::platform::complex64 * output, int64_t numel)
182+ struct AbsGradFunctor <paddle::platform::complex < float > > {
183+ AbsGradFunctor (const float * dout, const paddle::platform::complex < float > * x,
184+ paddle::platform::complex < float > * output, int64_t numel)
179185 : dout_(dout), x_(x), output_(output), numel_(numel) {}
180186
181187 HOSTDEVICE void operator ()(int64_t idx) const {
182- if (x_[idx] == paddle::platform::complex64 (0 )) {
183- output_[idx] = paddle::platform::complex64 (0 );
188+ if (x_[idx] == paddle::platform::complex < float > (0 )) {
189+ output_[idx] = paddle::platform::complex < float > (0 );
184190 } else {
185- output_[idx] = paddle::platform::complex64 (dout_[idx]) *
186- (x_[idx] / paddle::platform::complex64 (abs (x_[idx])));
191+ output_[idx] = paddle::platform::complex < float > (dout_[idx]) *
192+ (x_[idx] / paddle::platform::complex < float > (abs (x_[idx])));
187193 }
188194 }
189195
190196 const float * dout_;
191- const paddle::platform::complex64 * x_;
192- paddle::platform::complex64 * output_;
197+ const paddle::platform::complex < float > * x_;
198+ paddle::platform::complex < float > * output_;
193199 int64_t numel_;
194200};
195201
196202template <>
197- struct AbsGradFunctor <paddle::platform::complex128 > {
198- AbsGradFunctor (const double * dout, const paddle::platform::complex128 * x,
199- paddle::platform::complex128 * output, int64_t numel)
203+ struct AbsGradFunctor <paddle::platform::complex < double > > {
204+ AbsGradFunctor (const double * dout, const paddle::platform::complex < double > * x,
205+ paddle::platform::complex < double > * output, int64_t numel)
200206 : dout_(dout), x_(x), output_(output), numel_(numel) {}
201207
202208 HOSTDEVICE void operator ()(int64_t idx) const {
203- if (x_[idx] == paddle::platform::complex128 (0 )) {
204- output_[idx] = paddle::platform::complex128 (0 );
209+ if (x_[idx] == paddle::platform::complex < double > (0 )) {
210+ output_[idx] = paddle::platform::complex < double > (0 );
205211 } else {
206- output_[idx] = paddle::platform::complex128 (dout_[idx]) *
207- (x_[idx] / paddle::platform::complex128 (abs (x_[idx])));
212+ output_[idx] =
213+ paddle::platform::complex <double >(dout_[idx]) *
214+ (x_[idx] / paddle::platform::complex <double >(abs (x_[idx])));
208215 }
209216 }
210217
211218 const double * dout_;
212- const paddle::platform::complex128 * x_;
213- paddle::platform::complex128 * output_;
219+ const paddle::platform::complex < double > * x_;
220+ paddle::platform::complex < double > * output_;
214221 int64_t numel_;
215222};
216223
@@ -234,46 +241,46 @@ struct AbsGradGradFunctor {
234241};
235242
236243template <>
237- struct AbsGradGradFunctor <paddle::platform::complex128 > {
238- AbsGradGradFunctor (const paddle::platform::complex128 * ddx,
239- const paddle::platform::complex128 * x,
240- paddle::platform::complex128 * output, int64_t numel)
244+ struct AbsGradGradFunctor <paddle::platform::complex < double > > {
245+ AbsGradGradFunctor (const paddle::platform::complex < double > * ddx,
246+ const paddle::platform::complex < double > * x,
247+ paddle::platform::complex < double > * output, int64_t numel)
241248 : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
242249
243250 HOSTDEVICE void operator ()(int64_t idx) const {
244- if (x_[idx] == paddle::platform::complex128 (0 )) {
245- output_[idx] = paddle::platform::complex128 (0 );
251+ if (x_[idx] == paddle::platform::complex < double > (0 )) {
252+ output_[idx] = paddle::platform::complex < double > (0 );
246253 } else {
247- output_[idx] = paddle::platform::complex128 (ddx_[idx]) * x_[idx] /
248- paddle::platform::complex128 (abs (x_[idx]));
254+ output_[idx] = paddle::platform::complex < double > (ddx_[idx]) * x_[idx] /
255+ paddle::platform::complex < double > (abs (x_[idx]));
249256 }
250257 }
251258
252- const paddle::platform::complex128 * ddx_;
253- const paddle::platform::complex128 * x_;
254- paddle::platform::complex128 * output_;
259+ const paddle::platform::complex < double > * ddx_;
260+ const paddle::platform::complex < double > * x_;
261+ paddle::platform::complex < double > * output_;
255262 int64_t numel_;
256263};
257264
258265template <>
259- struct AbsGradGradFunctor <paddle::platform::complex64 > {
260- AbsGradGradFunctor (const paddle::platform::complex64 * ddx,
261- const paddle::platform::complex64 * x,
262- paddle::platform::complex64 * output, int64_t numel)
266+ struct AbsGradGradFunctor <paddle::platform::complex < float > > {
267+ AbsGradGradFunctor (const paddle::platform::complex < float > * ddx,
268+ const paddle::platform::complex < float > * x,
269+ paddle::platform::complex < float > * output, int64_t numel)
263270 : ddx_(ddx), x_(x), output_(output), numel_(numel) {}
264271
265272 HOSTDEVICE void operator ()(int64_t idx) const {
266- if (x_[idx] == paddle::platform::complex64 (0 )) {
267- output_[idx] = paddle::platform::complex64 (0 );
273+ if (x_[idx] == paddle::platform::complex < float > (0 )) {
274+ output_[idx] = paddle::platform::complex < float > (0 );
268275 } else {
269- output_[idx] = paddle::platform::complex64 (ddx_[idx]) * x_[idx] /
270- paddle::platform::complex64 (abs (x_[idx]));
276+ output_[idx] = paddle::platform::complex < float > (ddx_[idx]) * x_[idx] /
277+ paddle::platform::complex < float > (abs (x_[idx]));
271278 }
272279 }
273280
274- const paddle::platform::complex64 * ddx_;
275- const paddle::platform::complex64 * x_;
276- paddle::platform::complex64 * output_;
281+ const paddle::platform::complex < float > * ddx_;
282+ const paddle::platform::complex < float > * x_;
283+ paddle::platform::complex < float > * output_;
277284 int64_t numel_;
278285};
279286template <typename T, typename Enable = void >
0 commit comments