Merged
Changes from 2 commits
5 changes: 4 additions & 1 deletion paddle/fluid/operators/kron_op.h
@@ -301,7 +301,10 @@ template <typename T>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+  template <typename T2>
Contributor: Why use T2? Could another letter be considered? And what is T still used for?

+  HOSTDEVICE inline T2 operator()(const T2& x) const {
+    return x;
+  }
 };
 
 template <typename DeviceContext, typename T>
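On the question above: T still has a job, it is the type the call sites instantiate (IdentityFunctor<Tx>), while the member template T2 is deduced independently at each call, so the same functor object can be invoked with the mixed-precision compute type even when that differs from Tx. A minimal sketch of the mechanism in plain C++ (my own example, no Paddle types):

// Sketch only (not the PR's code): T is fixed where the functor is
// instantiated; T2 is deduced independently at each call.
template <typename T>
struct IdentityFunctor {
  explicit IdentityFunctor() {}

  template <typename T2>
  T2 operator()(const T2& x) const {
    return x;
  }
};

int main() {
  IdentityFunctor<int> f;  // T = int, fixed by the declaration
  float a = f(1.5f);       // T2 deduced as float
  double b = f(2.5);       // T2 deduced as double
  return (a == 1.5f && b == 2.5) ? 0 : 1;
}

Any identifier would work in place of T2; the name here merely signals that it may or may not coincide with T.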
5 changes: 4 additions & 1 deletion paddle/fluid/operators/matmul_v2_op.h
@@ -38,7 +38,10 @@ template <typename T>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+  template <typename T2>
Contributor: Same as above.

+  HOSTDEVICE inline T2 operator()(const T2& x) const {
+    return x;
+  }
 };
 
 template <typename DeviceContext, typename T>
6 changes: 5 additions & 1 deletion paddle/fluid/operators/pool_op.h
@@ -31,7 +31,11 @@ namespace operators {
 template <typename T>
 struct DivideFunctor {
   HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
+
+  template <typename T2>
+  HOSTDEVICE inline T2 operator()(const T2& x) const {
+    return x * static_cast<T2>(n_inv);
+  }
 
  private:
   T n_inv;
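The static_cast in the new operator is what keeps the arithmetic consistent: n_inv is stored as the class type T, and when the operator is instantiated with a different compute type T2 (say float for a float16 pool), the stored scalar must be converted before the multiply so the arithmetic and the return value stay in T2. A standalone sketch of the same shape, with ordinary types standing in for Paddle's (illustrative only):

#include <cstdio>

// Sketch only: mirrors DivideFunctor's shape with ordinary types.
template <typename T>
struct DivideFunctor {
  explicit DivideFunctor(int n) : n_inv(static_cast<T>(1.0 / n)) {}

  template <typename T2>
  T2 operator()(const T2& x) const {
    return x * static_cast<T2>(n_inv);  // keep the multiply entirely in T2
  }

 private:
  T n_inv;  // stored in the class type, which may be narrower than T2
};

int main() {
  DivideFunctor<float> quarter(4);
  std::printf("%f\n", quarter(10.0));  // T2 = double, prints 2.500000
}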
158 changes: 114 additions & 44 deletions paddle/fluid/operators/reduce_ops/cub_reduce.h
@@ -31,6 +31,7 @@ namespace cub = hipcub;

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"

namespace paddle {
namespace operators {
@@ -66,39 +67,66 @@ struct Array {
   T data_[ElementCount];
 };
 
+// reduce the 1d array to one element
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
+__global__ void ReduceKernel1D(const Tx* x, Ty* y, ReduceOp reducer,
+                               TransformOp transformer, MPType init,
+                               int reduce_num) {
+  int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  typedef cub::BlockReduce<MPType, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  MPType local_data = init;
+  for (int i = thread_id; i < reduce_num; i += gridDim.x * blockDim.x) {
+    local_data = static_cast<MPType>(
+        reducer(local_data, static_cast<MPType>(transformer(x[i]))));
+  }
+  __syncthreads();
+
+  local_data = BlockReduce(temp_storage).Reduce(local_data, reducer);
+
+  if (threadIdx.x == 0) {
+    y[blockIdx.x] = static_cast<Ty>(local_data);
+  }
+}
+
 // reduce the last axis of 2d array
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim>
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
 __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
-                               TransformOp transformer, Ty init,
+                               TransformOp transformer, MPType init,
                                int reduce_num) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
+  __shared__
+      typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
-  Ty reduce_var = init;
+  MPType reduce_var = init;
   for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
     reduce_var =
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
+        reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x + idx_y])));
   __syncthreads();
 
-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
+  reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
+                   .Reduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
+    y[blockIdx.x] = static_cast<Ty>(reduce_var);
   }
 }
 
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim, int Rank, int ReduceRank>
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim, int Rank, int ReduceRank>
 __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
-                             TransformOp transformer, Ty init, int reduce_num,
-                             Array<int, Rank> x_strides,
+                             TransformOp transformer, MPType init,
+                             int reduce_num, Array<int, Rank> x_strides,
                              Array<int, ReduceRank> reduce_dim,
                              Array<int, ReduceRank> reduce_strides,
                              Array<int, Rank - ReduceRank> left_dim,
                              Array<int, Rank - ReduceRank> left_strides) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
+  __shared__
+      typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
   Array<int, Rank> sub_index;
   int left_idx = blockIdx.x;
   for (int i = 0; i < Rank - ReduceRank; ++i) {
@@ -114,7 +142,7 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,

   int idx_x = 0;
   for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-  Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
+  MPType reduce_var = static_cast<MPType>(transformer(x[idx_x]));
 
   for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
     int reduce_idx = i;
@@ -125,16 +153,16 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,

     int idx_x = 0;
     for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-    reduce_var = static_cast<Ty>(
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
+    reduce_var = static_cast<MPType>(
+        reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x]))));
   }
   __syncthreads();
 
-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
+  reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
+                   .Reduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
+    y[blockIdx.x] = static_cast<Ty>(reduce_var);
   }
 }

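ReduceKernel1D above is the building block for the float16 path added further down: every partial sum lives in the higher-precision MPType and only the final store narrows to Ty. A self-contained CUDA sketch of that two-pass idea, using hand-rolled shared-memory trees instead of cub::BlockReduce (my simplified kernels, not the PR's):

#include <cstdio>
#include <cuda_fp16.h>

constexpr int kBlockDim = 256;

// Pass 1: each block folds a grid-stride slice of the __half input into
// one float partial sum. Accumulation stays in float throughout.
__global__ void ReducePass1(const __half* x, float* partial, int n) {
  __shared__ float smem[kBlockDim];
  float acc = 0.f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    acc += __half2float(x[i]);
  }
  smem[threadIdx.x] = acc;
  __syncthreads();
  for (int s = kBlockDim / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) partial[blockIdx.x] = smem[0];
}

// Pass 2: one block folds the float partials; only the last store
// narrows back to __half, mirroring static_cast<Ty>(reduce_var).
__global__ void ReducePass2(const float* partial, __half* y, int n) {
  __shared__ float smem[kBlockDim];
  float acc = 0.f;
  for (int i = threadIdx.x; i < n; i += kBlockDim) acc += partial[i];
  smem[threadIdx.x] = acc;
  __syncthreads();
  for (int s = kBlockDim / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  if (threadIdx.x == 0) *y = __float2half(smem[0]);
}

int main() {
  const int n = 1 << 20;
  // Same sizing heuristic as the PR: roughly BlockDim * 10 elements per block.
  const int blocks = (n + kBlockDim * 10 - 1) / (kBlockDim * 10);
  __half *x, *y;
  float* partial;
  cudaMallocManaged(&x, n * sizeof(__half));
  cudaMallocManaged(&y, sizeof(__half));
  cudaMallocManaged(&partial, blocks * sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = __float2half(0.001f);
  ReducePass1<<<blocks, kBlockDim>>>(x, partial, n);
  ReducePass2<<<1, kBlockDim>>>(partial, y, blocks);
  cudaDeviceSynchronize();
  // Float accumulation keeps the ~0.001 addends from being rounded away
  // once the running sum grows large, as they would be in fp16.
  printf("sum = %f\n", __half2float(*y));
  return 0;
}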
@@ -192,6 +220,53 @@ static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
   }
 }
 
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
+typename std::enable_if<!std::is_same<Tx, paddle::platform::float16>::value,
+                        void>::type
+LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
+                      const platform::Place& place, const ReduceOp& reducer,
+                      const TransformOp& transformer, const Ty& init,
+                      int reduce_num, gpuStream_t stream) {
+  cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
+                                                                  transformer);
+  size_t temp_storage_bytes = 0;
+  cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, init, stream);
+  framework::Tensor tmp;
+  auto* temp_storage = tmp.mutable_data<uint8_t>(
+      framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
+  cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, init, stream);
+}
+
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
+typename std::enable_if<std::is_same<Tx, paddle::platform::float16>::value,
+                        void>::type
+LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
+                      const platform::Place& place, const ReduceOp& reducer,
+                      const TransformOp& transformer, const MPType& init,
+                      int reduce_num, gpuStream_t stream) {
+  int element_per_block = BlockDim * 10;
+  int block_per_grid = (reduce_num + element_per_block - 1) / element_per_block;
+
+  framework::Tensor tmp;
+  auto* temp_storage = tmp.mutable_data<MPType>(
+      framework::make_ddim(
+          {static_cast<int64_t>(block_per_grid * sizeof(MPType))}),
+      place);
+
+  // each block reduce number to interim result
+  ReduceKernel1D<Tx, MPType, MPType, ReduceOp, TransformOp,
+                 BlockDim><<<block_per_grid, BlockDim, 0, stream>>>(
+      x_data, temp_storage, reducer, transformer, init, reduce_num);
+  // reduce all number to final result
+  ReduceKernel1D<MPType, MPType, Ty, ReduceOp, TransformOp,
+                 BlockDim><<<1, BlockDim, 0, stream>>>(
+      temp_storage, y_data, reducer, transformer, init, block_per_grid);
+}
+
 template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
           typename TransformOp>
 static void TensorReduceImpl(
@@ -201,45 +276,40 @@ static void TensorReduceImpl(
     const std::vector<int>& reduce_dim, const std::vector<int>& reduce_strides,
     const std::vector<int>& left_dim, const std::vector<int>& left_strides,
     gpuStream_t stream) {
+  using MPType = typename details::MPTypeTrait<Tx>::Type;
+  MPType init_mp = static_cast<MPType>(init);
+
 #define CUB_RANK_CASE(i, ...)             \
   case i: {                               \
     constexpr auto kRank = i;             \
     switch (reduce_rank) { __VA_ARGS__; } \
   } break
 
-#define CUB_REDUCE_RANK_CASE(i, ...)                               \
-  case i: {                                                        \
-    constexpr auto kReduceRank = i;                                \
-    ReduceKernel<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank,   \
-                 kReduceRank><<<left_num, BlockDim, 0, stream>>>(  \
-        x_data, y_data, reducer, transformer, init, reduce_num,    \
-        Array<int, kRank>::From(x_strides),                        \
-        Array<int, kReduceRank>::From(reduce_dim),                 \
-        Array<int, kReduceRank>::From(reduce_strides),             \
-        Array<int, kRank - kReduceRank>::From(left_dim),           \
-        Array<int, kRank - kReduceRank>::From(left_strides));      \
+#define CUB_REDUCE_RANK_CASE(i, ...)                                     \
+  case i: {                                                              \
+    constexpr auto kReduceRank = i;                                      \
+    ReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim, kRank, \
+                 kReduceRank><<<left_num, BlockDim, 0, stream>>>(        \
+        x_data, y_data, reducer, transformer, init_mp, reduce_num,       \
+        Array<int, kRank>::From(x_strides),                              \
+        Array<int, kReduceRank>::From(reduce_dim),                       \
+        Array<int, kReduceRank>::From(reduce_strides),                   \
+        Array<int, kRank - kReduceRank>::From(left_dim),                 \
+        Array<int, kRank - kReduceRank>::From(left_strides));            \
   } break

   int rank = x_strides.size();
   int reduce_rank = reduce_strides.size();
   if (rank == reduce_rank) {
-    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
-        x_data, transformer);
-    size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                              reduce_num, reducer, init, stream);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        place);
-    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
-                              reduce_num, reducer, init, stream);
+    LaunchCubReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim>(
+        x_data, y_data, place, reducer, transformer, init_mp, reduce_num,
+        stream);
     return;
   }
   if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
-    ReduceKernel2D<Tx, Ty, ReduceOp, TransformOp,
+    ReduceKernel2D<Tx, MPType, Ty, ReduceOp, TransformOp,
                    BlockDim><<<left_num, BlockDim, 0, stream>>>(
-        x_data, y_data, reducer, transformer, init, reduce_num);
+        x_data, y_data, reducer, transformer, init_mp, reduce_num);
     return;
   }
   /*
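The two LaunchCubReduceKernel overloads act as a compile-time switch: std::enable_if leaves exactly one of them viable for a given Tx, so float16 input takes the two-pass kernel while every other type goes straight to cub::DeviceReduce. The same pattern in isolation, with a stand-in float16 struct and hypothetical names (sketch only, standard C++):

#include <cstdio>
#include <type_traits>

struct float16 { unsigned short bits; };  // stand-in for platform::float16

// Viable for every Tx except float16: hand the work to the library path.
template <typename Tx>
typename std::enable_if<!std::is_same<Tx, float16>::value, void>::type
LaunchReduce(const Tx* /*x*/, int n) {
  std::printf("cub::DeviceReduce path, n = %d\n", n);
}

// Viable only when Tx is float16: run the two-pass float-accumulator path.
template <typename Tx>
typename std::enable_if<std::is_same<Tx, float16>::value, void>::type
LaunchReduce(const Tx* /*x*/, int n) {
  std::printf("two-pass float accumulation path, n = %d\n", n);
}

int main() {
  float xs[4] = {0};
  float16 hs[4] = {};
  LaunchReduce(xs, 4);  // resolves to the generic overload
  LaunchReduce(hs, 4);  // resolves to the float16 overload at compile time
}

Tag dispatch or if constexpr would read more simply, but enable_if presumably keeps the code compatible with the older C++ standards the supported CUDA toolchains require.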
10 changes: 5 additions & 5 deletions paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -109,12 +109,12 @@ REGISTER_OPERATOR(reduce_sum_grad, ops::ReduceGradOp,
                   ops::ReduceSumGradNoNeedBufferVarInferer);
 
 REGISTER_OP_CPU_KERNEL(
-    reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, bool,
+    reduce_sum, ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
                                   ops::SumFunctor>,
-    ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
-                      ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
                       ops::SumFunctor>,
+    ops::ReduceKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::float16, ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
                       ops::SumFunctor>,
@@ -130,9 +130,9 @@ using CPUReduceSumGradKernel =
     ops::ReduceSumGradKernel<paddle::platform::CPUDeviceContext, T,
                              ops::SumGradFunctor, true>;
 
-REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<bool>,
-                       CPUReduceSumGradKernel<float>,
+REGISTER_OP_CPU_KERNEL(reduce_sum_grad, CPUReduceSumGradKernel<float>,
                        CPUReduceSumGradKernel<double>,
+                       CPUReduceSumGradKernel<paddle::platform::float16>,
                        CPUReduceSumGradKernel<int>,
                        CPUReduceSumGradKernel<int64_t>,
                        CPUReduceSumGradKernel<paddle::platform::complex64>,
12 changes: 8 additions & 4 deletions paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
@@ -22,7 +22,10 @@ template <typename T>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+  template <typename T2>
+  HOSTDEVICE inline T2 operator()(const T2& x) const {
+    return x;
+  }
 };
 
 template <typename T>
@@ -70,9 +73,10 @@ class ReduceSumKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<bool>,
-                        ops::ReduceSumKernel<float>,
-                        ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
+REGISTER_OP_CUDA_KERNEL(reduce_sum, ops::ReduceSumKernel<float>,
+                        ops::ReduceSumKernel<double>,
+                        ops::ReduceSumKernel<paddle::platform::float16>,
+                        ops::ReduceSumKernel<int>,
                         ops::ReduceSumKernel<int64_t>,
                         ops::ReduceSumKernel<paddle::platform::complex64>,
                         ops::ReduceSumKernel<paddle::platform::complex128>);
4 changes: 2 additions & 2 deletions paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu
@@ -20,9 +20,9 @@ using CUDAReduceSumGradKernel =
     ops::ReduceGradKernel<paddle::platform::CUDADeviceContext, T,
                           ops::SumGradFunctor, true>;
 
-REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
-                        CUDAReduceSumGradKernel<float>,
+REGISTER_OP_CUDA_KERNEL(reduce_sum_grad, CUDAReduceSumGradKernel<float>,
                         CUDAReduceSumGradKernel<double>,
+                        CUDAReduceSumGradKernel<paddle::platform::float16>,
                         CUDAReduceSumGradKernel<int>,
                         CUDAReduceSumGradKernel<int64_t>,
                         CUDAReduceSumGradKernel<paddle::platform::complex64>,
5 changes: 4 additions & 1 deletion paddle/fluid/operators/trace_op.cu
@@ -24,7 +24,10 @@ template <typename T>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+  template <typename T2>
+  HOSTDEVICE inline T2 operator()(const T2& x) const {
+    return x;
+  }
 };
 
 template <typename DeviceContext, typename T>
3 changes: 2 additions & 1 deletion python/paddle/fluid/layers/nn.py
@@ -4406,7 +4406,8 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
         if dim == None or dim == [] or len(dim) == len(input.shape) else False
     }
     check_variable_and_dtype(
-        input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_sum')
+        input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        'reduce_sum')
     helper = LayerHelper('reduce_sum', **locals())
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     helper.append_op(
15 changes: 15 additions & 0 deletions python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -37,6 +37,21 @@ def test_check_grad(self):
         self.check_grad(['X'], 'Out')
 
 
+class TestSumOp_fp16(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float16")}
+        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}
+
+    def test_check_output(self):
+        self.check_output(atol=5e-2)
+
+    # Because of the low precision of fp16, max_relative_error
+    # is set to 0.15 here.
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.15)
+
+
 class TestSumOp5D(OpTest):
     def setUp(self):
         self.op_type = "reduce_sum"