@MingMingShangTian MingMingShangTian commented May 27, 2021

PR types

Others

PR changes

OPs

Describe

This PR replaces the complex types used by the reduce_sum Op, complex64/128, with the template types complex<float/double>.

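In the template types, conversion from complex to a real type is explicit, and the original IdentityFunctor does not change the data type. As a result, compilation fails with type-conversion errors such as paddle::platform::complex<float> to float and paddle::platform::complex<float> to int64.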

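(Note: the original complex64 and complex128 types can be converted to real types implicitly, so compilation succeeded even though IdentityFunctor did not change the data type. That implicit conversion, however, can silently drop the imaginary part.)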
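The difference between the two conversion styles can be sketched as follows. `ImplicitComplex` and `ExplicitComplex` are hypothetical stand-ins written for this illustration, not the actual Paddle types, and the functor omits `HOSTDEVICE` so the sketch compiles as plain host C++.

```cpp
#include <type_traits>

// Hypothetical stand-in for the old complex64: converts to a real type implicitly.
struct ImplicitComplex {
  float real;
  float imag;
  operator float() const { return real; }  // the imaginary part is silently dropped
};

// Hypothetical stand-in for paddle::platform::complex<T>: explicit conversion only.
template <typename T>
struct ExplicitComplex {
  T real;
  T imag;
  template <typename U>
  explicit operator U() const { return static_cast<U>(real); }
};

// One-parameter identity functor, as it was before this PR (host-only here).
template <typename T>
struct IdentityFunctor {
  T operator()(const T& x) const { return x; }
};

// The implicit type converts without a cast; the explicit type requires one.
static_assert(std::is_convertible<ImplicitComplex, float>::value,
              "implicit conversion to float is allowed");
static_assert(!std::is_convertible<ExplicitComplex<float>, float>::value,
              "implicit conversion to float is rejected");
static_assert(std::is_constructible<float, ExplicitComplex<float>>::value,
              "an explicit cast to float is still allowed");

// Compiles only because of the implicit conversion applied at the return statement.
inline float ImplicitToReal(const ImplicitComplex& z) {
  return IdentityFunctor<ImplicitComplex>()(z);
}
```

Writing the same `ImplicitToReal` helper for `ExplicitComplex<float>` would not compile, which mirrors the kind of error described above.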

The error is reported at the following location:

framework::VisitDataTypeSmall(
    static_cast<framework::proto::VarType::Type>(out_dtype),
    TensorReduceFunctor<T, cub::Sum, IdentityFunctor<T>>(
        *input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(),
        IdentityFunctor<T>(), stream));

To solve this, the PR makes the data type conversion in IdentityFunctor explicit. The main changes are in the following two places.

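  1. The number of template parameters of IdentityFunctor increases from one to two, so that conversion between different data types is supported explicitly.
    Before: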
template <typename T>
struct IdentityFunctor {
    HOSTDEVICE inline T operator()(const T& x) const { return x; }
};

After:

template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  HOSTDEVICE inline Ty operator()(const Tx& x) const {
    return static_cast<Ty>(x);
  }
};
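  2. The specialization and instantiation of IdentityFunctor are deferred to the apply() function in TensorReduceFunctor, because the two template type parameters that IdentityFunctor needs are not known until apply().
    Before: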
  void apply() const {
    const Ty& init_cast = static_cast<Ty>(init);
    TensorReduce<Tx, Ty, ReduceOp, TransformOp>(
        x, y, origin_reduce_dims, init_cast, reducer, transformer, stream);
  }

After:

  void apply() const {
    const Ty& init_cast = static_cast<Ty>(init);
    TensorReduce<Tx, Ty, ReduceOp, TransformOp<Tx, Ty>>(
        x, y, origin_reduce_dims, init_cast, reducer, TransformOp<Tx, Ty>(),
        stream);
  }
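
The two changes can be sketched together in a small host-only example. It assumes a hypothetical `Complex<T>` stand-in for `paddle::platform::complex<T>` and a hypothetical `ReduceSumFunctor` that sums a `std::vector` on the CPU; neither is the Paddle implementation, and `HOSTDEVICE`, streams, and tensors are left out. The point is only that the transform is passed as a template template parameter and instantiated as `TransformOp<Tx, Ty>` once `Ty` is known.

```cpp
#include <cstdint>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for paddle::platform::complex<T>: explicit conversion
// to arithmetic types only, keeping the real part.
template <typename T>
struct Complex {
  T real;
  T imag;
  template <typename U, typename = typename std::enable_if<
                            std::is_arithmetic<U>::value>::type>
  explicit operator U() const {
    return static_cast<U>(real);
  }
};

// Two-parameter identity functor: converts Tx to Ty with an explicit cast.
template <typename Tx, typename Ty = Tx>
struct IdentityFunctor {
  Ty operator()(const Tx& x) const { return static_cast<Ty>(x); }
};

// Hypothetical CPU analogue of TensorReduceFunctor. TransformOp is a template
// template parameter, so it stays uninstantiated until apply<Ty>() runs.
template <typename Tx, template <typename, typename> class TransformOp>
struct ReduceSumFunctor {
  const std::vector<Tx>& x;  // input values; must outlive this functor
  double init;               // initial value, cast to Ty inside apply()

  template <typename Ty>
  Ty apply() const {
    TransformOp<Tx, Ty> transformer;  // both type parameters are known here
    Ty acc = static_cast<Ty>(init);
    for (const Tx& v : x) {
      acc = acc + transformer(v);
    }
    return acc;
  }
};
```

With inputs `{1.5, 2.0}` and `{2.5, -1.0}` stored as `Complex<float>`, calling `apply<float>()` on a `ReduceSumFunctor<Complex<float>, IdentityFunctor>` sums the real parts as floats, while `apply<std::int64_t>()` casts each element to an integer before summing. Every conversion goes through the `static_cast` in `IdentityFunctor`, so none happens implicitly.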

@paddle-bot-old
Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@chenwhql chenwhql left a comment

LGTM

@MingMingShangTian MingMingShangTian merged commit 5756d3e into PaddlePaddle:develop May 28, 2021
@MingMingShangTian MingMingShangTian deleted the template_reduce_sum branch May 28, 2021 04:43