14 changes: 5 additions & 9 deletions paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -19,17 +19,13 @@
 namespace paddle {
 namespace operators {
 
-template <ElementwiseType ET,
-          typename InT,
-          typename OutT,
-          typename Functor,
-          int NumOuts = 1>
+template <typename OutT, typename Functor, int NumOuts = 1>
 void LaunchElementwiseCudaKernel(
     const KPDevice &ctx,
     const std::vector<const phi::DenseTensor *> &ins,
     std::vector<phi::DenseTensor *> *outs,
-    int axis,
-    Functor func) {
+    Functor func,
+    int axis = -1) {
   std::vector<const phi::DenseTensor *> pt_inputs;
   std::vector<phi::DenseTensor *> pt_outputs;
   // TODO(YuanRisheng) *_tmp for cache DenseTensor, because the temporary
@@ -53,8 +49,8 @@ void LaunchElementwiseCudaKernel(
   for (int i = 0; i < pt_outputs_tmp.size(); i++) {
     pt_outputs.push_back(pt_outputs_tmp[i].get());
   }
-  phi::funcs::BroadcastKernel<ET, InT, OutT, Functor, NumOuts>(
-      ctx, pt_inputs, &pt_outputs, axis, func);
+  phi::funcs::BroadcastKernel<OutT, Functor, NumOuts>(
+      ctx, pt_inputs, &pt_outputs, func, axis);
 }
 
 }  // namespace operators
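For orientation, a minimal sketch of a call site under the new signature (tensor names and `T` are hypothetical; the input type and arity are now deduced from `ins`, and `axis` defaults to `-1`):

```cpp
// Hypothetical call site: binary broadcast add, axis left at its default.
std::vector<const phi::DenseTensor *> ins = {&x, &y};  // arity = ins.size()
std::vector<phi::DenseTensor *> outs = {&z};
paddle::operators::LaunchElementwiseCudaKernel<T>(
    ctx, ins, &outs, phi::funcs::AddFunctor<T>());     // axis = -1 implied
```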
11 changes: 5 additions & 6 deletions paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -188,7 +188,7 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
   z->mutable_data<OutType>(ctx.GetPlace());
   const auto &dev_ctx = ctx.template device_context<DeviceContext>();
   phi::funcs::ElementwiseCompute<Functor, T, OutType>(
-      dev_ctx, *x, *y, axis, func, z);
+      dev_ctx, *x, *y, func, z, axis);
 }
 
 // FusedElemwiseAndAct
@@ -1596,7 +1596,7 @@ static inline std::vector<int> GetReduceDim(const framework::DDim &in,
 
 #if defined(__NVCC__) || defined(__HIPCC__)
 
-template <ElementwiseType ET, typename T, typename Functor>
+template <typename T, typename Functor>
 void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
                      const platform::Place &place,
                      int axis,
@@ -1605,20 +1605,19 @@ void GetGradXAndYOut(const phi::GPUContext &dev_ctx,
                      phi::DenseTensor *dx,
                      phi::DenseTensor *dy,
                      Functor func) {
-  phi::GetGradXAndYOut<ET, T, Functor>(
+  phi::GetGradXAndYOut<T, Functor>(
       dev_ctx, place, axis, ins, *dout, dx, dy, func);
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
+template <typename T, typename Functor>
 void GetGradXOrYOut(const phi::GPUContext &dev_ctx,
                     const platform::Place &place,
                     int axis,
                     std::vector<const phi::DenseTensor *> ins,
                     const phi::DenseTensor *dout,
                     phi::DenseTensor *dxy,
                     Functor func) {
-  phi::GetGradXOrYOut<ET, T, Functor>(
-      dev_ctx, place, axis, ins, *dout, dxy, func);
+  phi::GetGradXOrYOut<T, Functor>(dev_ctx, place, axis, ins, *dout, dxy, func);
 }
 
 #endif
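These fluid helpers stay as thin pass-throughs; only the arity enum disappears from the template list and from the forwarded phi call. A before/after sketch of one instantiation (the grad functor name below is hypothetical, for illustration only):

```cpp
// Before: GetGradXAndYOut<phi::ElementwiseType::kBinary, T, Functor>(...);
// After:  the arity is inferred from ins.size() inside phi.
GetGradXAndYOut<T>(
    dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor<T>());
```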
2 changes: 0 additions & 2 deletions paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -23,8 +23,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using ElementwiseType = phi::ElementwiseType;
-
 template <typename OutT, typename Functor, int NumOuts = 1>
 void LaunchSameDimsElementwiseCudaKernel(
     const KPDevice &ctx,
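With the alias gone, nothing in this header names an arity anymore. A hypothetical same-dims call site for reference (tensor names assumed; `MultiplyFunctor` is the stock phi binary functor):

```cpp
// Arity follows from ins.size(); no ElementwiseType enum is involved.
std::vector<const phi::DenseTensor *> ins = {&x, &y};
std::vector<phi::DenseTensor *> outs = {&z};
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
    ctx, ins, &outs, phi::funcs::MultiplyFunctor<T>());
```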
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/attn_gemm.h
@@ -109,8 +109,8 @@ class AttnMatMul {
       // bias_out = output + bias
       std::vector<const phi::DenseTensor*> ins = {output, bias};
       std::vector<phi::DenseTensor*> outs = {bias_out};
-      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
+      phi::funcs::BroadcastKernel<T>(
+          dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
     }
   }
 
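The call above is a plain broadcast bias-add. A reference sketch of its semantics, assuming `output` is `[M, N]` and `bias` is `[N]` (a CPU analogue, not the CUDA path):

```cpp
// bias_out[i][j] = output[i][j] + bias[j]; the bias row broadcasts over M.
void BiasAddRef(const float *output, const float *bias, float *bias_out,
                int M, int N) {
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < N; ++j)
      bias_out[i * N + j] = output[i * N + j] + bias[j];
}
```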
8 changes: 4 additions & 4 deletions paddle/fluid/operators/fused/attn_gemm_int8.h
@@ -85,8 +85,8 @@ class AttnMatmulINT8 {
       // bias_out = output + bias
       std::vector<const phi::DenseTensor*> ins = {output, bias};
       std::vector<phi::DenseTensor*> outs = {bias_out};
-      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
+      phi::funcs::BroadcastKernel<T>(
+          dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
       PADDLE_ENFORCE_EQ(cudaGetLastError(),
                         cudaSuccess,
                         platform::errors::Fatal(
@@ -139,8 +139,8 @@ class AttnMatmulINT8 {
       // bias_out = output + bias
       std::vector<const phi::DenseTensor*> ins = {output, bias};
       std::vector<phi::DenseTensor*> outs = {bias_out};
-      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
+      phi::funcs::BroadcastKernel<T>(
+          dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
       PADDLE_ENFORCE_EQ(cudaGetLastError(),
                         cudaSuccess,
                         platform::errors::Fatal(
22 changes: 10 additions & 12 deletions paddle/fluid/operators/fused/fmha_ref.h
@@ -255,12 +255,11 @@ class FMHARef {
       ins.emplace_back(src_mask_tensor);
       outs.emplace_back(src_mask_out_tensor);
       int elewise_add_axis = -1;
-      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-          dev_ctx_,
-          ins,
-          &outs,
-          elewise_add_axis,
-          phi::funcs::AddFunctor<T>());
+      phi::funcs::BroadcastKernel<T>(dev_ctx_,
+                                     ins,
+                                     &outs,
+                                     phi::funcs::AddFunctor<T>(),
+                                     elewise_add_axis);
 
       phi::SoftmaxForwardCUDAKernelDriver<T>(
           dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
@@ -432,12 +431,11 @@ class FMHARef {
       ins.emplace_back(src_mask_tensor);
       outs.emplace_back(src_mask_out_tensor);
       int elewise_add_axis = -1;
-      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-          dev_ctx_,
-          ins,
-          &outs,
-          elewise_add_axis,
-          phi::funcs::AddFunctor<T>());
+      phi::funcs::BroadcastKernel<T>(dev_ctx_,
+                                     ins,
+                                     &outs,
+                                     phi::funcs::AddFunctor<T>(),
+                                     elewise_add_axis);
 
       phi::SoftmaxForwardCUDAKernelDriver<T>(
           dev_ctx_, *src_mask_out_tensor, softmax_axis, softmax_out_tensor);
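`elewise_add_axis` is still passed explicitly here, now as the trailing argument. Assuming the usual FMHA shapes (scores `[B, H, S, S]`, `src_mask` `[B, 1, 1, S]`), the broadcast collapses the mask's size-1 axes; an index-arithmetic sketch:

```cpp
// Hypothetical reference: h and q map onto the mask's size-1 axes.
float MaskedScore(const float *scores, const float *mask,
                  int B, int H, int S, int b, int h, int q, int k) {
  float s = scores[((b * H + h) * S + q) * S + k];
  return s + mask[b * S + k];
}
```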
8 changes: 4 additions & 4 deletions paddle/fluid/operators/fused/fused_gate_attention.h
@@ -689,13 +689,13 @@ class FMHAGateRef {
       std::vector<const phi::DenseTensor*> ins = {
           qk_out, src_mask, nonbatched_bias};
       std::vector<phi::DenseTensor*> outs = {qk_out};
-      phi::funcs::BroadcastKernel<phi::ElementwiseType::kTernary, T, T>(
-          dev_ctx_, ins, &outs, -1, TernaryAddFunctor<T>());
+      phi::funcs::BroadcastKernel<T>(
+          dev_ctx_, ins, &outs, TernaryAddFunctor<T>());
     } else {
       std::vector<const phi::DenseTensor*> ins = {qk_out, src_mask};
       std::vector<phi::DenseTensor*> outs = {qk_out};
-      phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
+      phi::funcs::BroadcastKernel<T>(
+          dev_ctx_, ins, &outs, phi::funcs::AddFunctor<T>());
     }
     phi::SoftmaxForwardCUDAKernelDriver<T>(dev_ctx_, *qk_out, -1, softmax_out);
   }
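The kTernary/kBinary distinction now lives entirely in `ins.size()` and the functor's arity. A minimal sketch of a three-input functor shaped like the `TernaryAddFunctor` used here (the real definition lives elsewhere in this file):

```cpp
template <typename T>
struct TernaryAddFunctor {
  inline HOSTDEVICE T operator()(const T a, const T b, const T c) const {
    return a + b + c;  // qk_out + src_mask + nonbatched_bias
  }
};
```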
3 changes: 1 addition & 2 deletions paddle/fluid/operators/fused_token_prune_op.cu
@@ -141,8 +141,7 @@ class FusedTokenPruneOpCUDAKernel : public framework::OpKernel<T> {
       ins.emplace_back(attn);
       ins.emplace_back(mask);
       outs.emplace_back(&attn_tmp);
-      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-          dev_ctx, ins, &outs, -1, AttnMaskFunctor<T>());
+      LaunchElementwiseCudaKernel<T>(dev_ctx, ins, &outs, AttnMaskFunctor<T>());
 
       // 2. Reduce sum
       const std::vector<int64_t> reduce_dims{1, 2};
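Any binary functor fits the new wrapper. A hypothetical stand-in with the same shape as `AttnMaskFunctor` (illustration only; the real functor is defined in this file):

```cpp
template <typename T>
struct MaskedAttnFunctor {  // hypothetical, not the actual AttnMaskFunctor
  inline HOSTDEVICE T operator()(const T attn, const T mask) const {
    return mask > static_cast<T>(0) ? attn : static_cast<T>(0);
  }
};
```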
11 changes: 5 additions & 6 deletions paddle/fluid/operators/reduce_ops/reduce_op.h
@@ -836,12 +836,11 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
     }
 
     using MPType = typename kps::details::MPTypeTrait<T>::Type;
-    phi::ReduceGrad<T, TransformOp<T, MPType>>(
-        dev_ctx,
-        pt_d_out.get(),
-        pt_d_x.get(),
-        pt_out_dtype,
-        TransformOp<T, MPType>(reduce_num));
+    phi::ReduceGrad<TransformOp<T, MPType>>(dev_ctx,
+                                            pt_d_out.get(),
+                                            pt_d_x.get(),
+                                            pt_out_dtype,
+                                            TransformOp<T, MPType>(reduce_num));
   }
 };
 
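Same pattern as the elementwise changes: the element type is already baked into `TransformOp<T, MPType>`, so the leading `T` was redundant. A hypothetical instantiation, assuming `kps::DivideFunctor` as the transform (as in mean-style reduce grads):

```cpp
using MPType = typename kps::details::MPTypeTrait<float>::Type;
phi::ReduceGrad<kps::DivideFunctor<float, MPType>>(
    dev_ctx,
    d_out,          // phi::DenseTensor*
    d_x,            // phi::DenseTensor*
    d_x->dtype(),
    kps::DivideFunctor<float, MPType>(reduce_num));
```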
18 changes: 9 additions & 9 deletions paddle/phi/kernels/cpu/bitwise_kernel.cc
@@ -24,15 +24,15 @@ limitations under the License. */
 
 namespace phi {
 
-#define DEFINE_BITWISE_KERNEL(op_type)                                    \
-  template <typename T, typename Context>                                 \
-  void Bitwise##op_type##Kernel(const Context& dev_ctx,                   \
-                                const DenseTensor& x,                     \
-                                const DenseTensor& y,                     \
-                                DenseTensor* out) {                       \
-    funcs::Bitwise##op_type##Functor<T> func;                             \
-    funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T, T>( \
-        dev_ctx, x, y, -1, func, out);                                    \
+#define DEFINE_BITWISE_KERNEL(op_type)                                 \
+  template <typename T, typename Context>                              \
+  void Bitwise##op_type##Kernel(const Context& dev_ctx,                \
+                                const DenseTensor& x,                  \
+                                const DenseTensor& y,                  \
+                                DenseTensor* out) {                    \
+    funcs::Bitwise##op_type##Functor<T> func;                          \
+    funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T>( \
+        dev_ctx, x, y, func, out);                                     \
   }
 
 DEFINE_BITWISE_KERNEL(And)
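For readability, `DEFINE_BITWISE_KERNEL(And)` now expands to (whitespace adjusted):

```cpp
template <typename T, typename Context>
void BitwiseAndKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
                      DenseTensor* out) {
  funcs::BitwiseAndFunctor<T> func;
  funcs::ElementwiseCompute<funcs::BitwiseAndFunctor<T>, T>(
      dev_ctx, x, y, func, out);  // axis defaults to -1
}
```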
6 changes: 3 additions & 3 deletions paddle/phi/kernels/cpu/compare_kernel.cc
@@ -33,10 +33,10 @@ inline void CompareKernelImpl(const Context& ctx,
   ctx.template Alloc<bool>(out);
   if (x.dims().size() >= y.dims().size()) {
     funcs::ElementwiseCompute<Functor, T, bool>(
-        ctx, x, y, axis, Functor(), out);
+        ctx, x, y, Functor(), out, axis);
   } else {
     funcs::ElementwiseCompute<InverseFunctor, T, bool>(
-        ctx, x, y, axis, InverseFunctor(), out);
+        ctx, x, y, InverseFunctor(), out, axis);
   }
 }
 
@@ -59,7 +59,7 @@ inline void CompareAllKernelImpl(const Context& ctx,
     tmp_data[0] = Functor()(x.data<T>()[0], y.data<T>()[0]);
   } else {
     funcs::ElementwiseCompute<Functor, T, bool>(
-        ctx, x, y, 0, Functor(), &tmp);
+        ctx, x, y, Functor(), &tmp, 0);
   }
   auto tmp_flat = EigenVector<bool>::Flatten(tmp);
   auto out_es = EigenScalar<bool>::From(*out);
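Context for the `Functor`/`InverseFunctor` branch (unchanged by this PR): `ElementwiseCompute` broadcasts the lower-rank operand against the higher-rank one, so when `y` outranks `x` the operands arrive swapped and the inverse functor restores the comparison's direction. A sketch with hypothetical functor names:

```cpp
template <typename T>
struct LessThanFunctor {
  HOSTDEVICE bool operator()(const T a, const T b) const { return a < b; }
};
template <typename T>
struct InverseLessThanFunctor {  // receives (y_i, x_i)
  HOSTDEVICE bool operator()(const T a, const T b) const { return a > b; }
};
```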
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/dirichlet_kernel.cc
@@ -91,8 +91,8 @@ struct DirichletSampler<CPUContext, T> {
                         true,
                         false);
 
-    funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T, T>(
-        dev_ctx, gamma_samples, gamma_sum, -1, funcs::DivideFunctor<T>(), out);
+    funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
+        dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor<T>(), out);
   }
 };
 
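This divide is the final normalization step of Dirichlet sampling: given gamma draws, the sample is x_k = g_k / sum_j g_j. A reference sketch, assuming `gamma_samples` is `[N, K]` and `gamma_sum` is `[N, 1]`:

```cpp
// x[i][k] = g[i][k] / sum_j g[i][j]
void DirichletNormalizeRef(const double *g, const double *sum, double *x,
                           int N, int K) {
  for (int i = 0; i < N; ++i)
    for (int k = 0; k < K; ++k)
      x[i * K + k] = g[i * K + k] / sum[i];
}
```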
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/elementwise_divide_kernel.cc
@@ -38,10 +38,10 @@ void DivideRawKernel(const Context& dev_ctx,
     auto y_dims = y.dims();
     if (x_dims.size() >= y_dims.size()) {
       funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
-          dev_ctx, x, y, axis, funcs::DivideFunctor<T>(), out);
+          dev_ctx, x, y, funcs::DivideFunctor<T>(), out, axis);
     } else {
       funcs::ElementwiseCompute<funcs::InverseDivideFunctor<T>, T>(
-          dev_ctx, x, y, axis, funcs::InverseDivideFunctor<T>(), out);
+          dev_ctx, x, y, funcs::InverseDivideFunctor<T>(), out, axis);
     }
   }
 }
18 changes: 9 additions & 9 deletions paddle/phi/kernels/cpu/elementwise_kernel.cc
@@ -30,7 +30,7 @@ void MaximumRawKernel(const Context& dev_ctx,
   // allocate memory for out
   dev_ctx.template Alloc<T>(out);
   funcs::ElementwiseCompute<funcs::MaximumFunctor<T>, T>(
-      dev_ctx, x, y, axis, funcs::MaximumFunctor<T>(), out);
+      dev_ctx, x, y, funcs::MaximumFunctor<T>(), out, axis);
 }
 
 template <typename T, typename Context>
@@ -42,7 +42,7 @@ void MinimumRawKernel(const Context& dev_ctx,
   // allocate memory for out
   dev_ctx.template Alloc<T>(out);
   funcs::ElementwiseCompute<funcs::MinimumFunctor<T>, T>(
-      dev_ctx, x, y, axis, funcs::MinimumFunctor<T>(), out);
+      dev_ctx, x, y, funcs::MinimumFunctor<T>(), out, axis);
 }
 
 template <typename T, typename Context>
@@ -57,10 +57,10 @@ void RemainderRawKernel(const Context& dev_ctx,
   auto y_dims = y.dims();
   if (x_dims.size() >= y_dims.size()) {
     funcs::ElementwiseCompute<funcs::RemainderFunctor<T>, T>(
-        dev_ctx, x, y, axis, funcs::RemainderFunctor<T>(), out);
+        dev_ctx, x, y, funcs::RemainderFunctor<T>(), out, axis);
   } else {
     funcs::ElementwiseCompute<funcs::InverseRemainderFunctor<T>, T>(
-        dev_ctx, x, y, axis, funcs::InverseRemainderFunctor<T>(), out);
+        dev_ctx, x, y, funcs::InverseRemainderFunctor<T>(), out, axis);
   }
 }
 
@@ -76,10 +76,10 @@ void FloorDivideRawKernel(const Context& dev_ctx,
   auto y_dims = y.dims();
   if (x_dims.size() >= y_dims.size()) {
     funcs::ElementwiseCompute<funcs::FloorDivideFunctor<T>, T>(
-        dev_ctx, x, y, axis, funcs::FloorDivideFunctor<T>(), out);
+        dev_ctx, x, y, funcs::FloorDivideFunctor<T>(), out, axis);
   } else {
     funcs::ElementwiseCompute<funcs::InverseFloorDivideFunctor<T>, T>(
-        dev_ctx, x, y, axis, funcs::InverseFloorDivideFunctor<T>(), out);
+        dev_ctx, x, y, funcs::InverseFloorDivideFunctor<T>(), out, axis);
   }
 }
 
@@ -95,10 +95,10 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
   auto y_dims = y.dims();
   if (x_dims.size() >= y_dims.size()) {
     funcs::ElementwiseCompute<funcs::ElementwisePowFunctor<T>, T>(
-        dev_ctx, x, y, axis, funcs::ElementwisePowFunctor<T>(), out);
+        dev_ctx, x, y, funcs::ElementwisePowFunctor<T>(), out, axis);
   } else {
     funcs::ElementwiseCompute<funcs::ElementwiseInversePowFunctor<T>, T>(
-        dev_ctx, x, y, axis, funcs::ElementwiseInversePowFunctor<T>(), out);
+        dev_ctx, x, y, funcs::ElementwiseInversePowFunctor<T>(), out, axis);
   }
 }
 
@@ -110,7 +110,7 @@ void HeavisideKernel(const Context& dev_ctx,
   // allocate memory for out
   dev_ctx.template Alloc<T>(out);
   funcs::ElementwiseCompute<funcs::ElementwiseHeavisideFunctor<T>, T>(
-      dev_ctx, x, y, -1, funcs::ElementwiseHeavisideFunctor<T>(), out);
+      dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor<T>(), out);
 }
 
 }  // namespace phi
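`HeavisideKernel` shows the second flavor of the migration: call sites that passed a literal `-1` can simply drop it now that `axis` is a trailing parameter with a `-1` default. The two forms are equivalent:

```cpp
funcs::ElementwiseCompute<funcs::ElementwiseHeavisideFunctor<T>, T>(
    dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor<T>(), out);      // default
funcs::ElementwiseCompute<funcs::ElementwiseHeavisideFunctor<T>, T>(
    dev_ctx, x, y, funcs::ElementwiseHeavisideFunctor<T>(), out, -1);  // explicit
```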