
Commit aba4dbe

[XPU] reduce_xxx and broadcast_xxx use int64_t shape (PaddlePaddle#71361)
Parent: b0df879


53 files changed: +266 additions, −246 deletions
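The commit message only names the signature change, but the likely motivation (my inference, not stated in the commit) is integer overflow: shape arithmetic carried in 32-bit ints breaks once a dimension or element count exceeds 2^31 − 1, which int64_t shape vectors avoid. A minimal standalone C++ sketch of that failure mode, independent of the Paddle/XDNN code:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> shape = {3000000000LL, 4};  // first extent > INT32_MAX
  int64_t numel = 1;
  for (int64_t d : shape) numel *= d;  // 12,000,000,000: safe in 64 bits
  std::cout << "numel as int64_t: " << numel << "\n";
  // The same count narrowed to 32 bits wraps (modulo semantics since C++20,
  // implementation-defined before that) and comes out wrong.
  std::cout << "numel as int32_t: " << static_cast<int32_t>(numel) << "\n";
  return 0;
}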

paddle/phi/kernels/funcs/fused_gemm_epilogue_xpu.h

Lines changed: 5 additions & 2 deletions
@@ -143,8 +143,11 @@ void ComputeFusedGemmEpilogueBackwardXPU(const phi::XPUContext& dev_ctx,
     XPUType* dbias_ptr;
     auto* dbias_tmp_ptr = dev_ctx.template Alloc<T>(dbias);
     dbias_ptr = reinterpret_cast<XPUType*>(dbias_tmp_ptr);
-    r = xpu::reduce_sum(
-        xpu_ctx, dout_fc_ptr, dbias_ptr, {info_forward.m, info_forward.n}, {0});
+    r = xpu::reduce_sum(xpu_ctx,
+                        dout_fc_ptr,
+                        dbias_ptr,
+                        {(int64_t)info_forward.m, (int64_t)info_forward.n},
+                        {0LL});
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
   }
 }
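For readers outside the Paddle tree, here is a self-contained sketch of the updated call shape. reduce_sum_stub is a hypothetical stand-in, not the real xpu::reduce_sum (which also takes an xpu::Context* and XPU device pointers); it only mirrors how the call site now widens each extent and writes the reduce axis as 0LL:

#include <cstdint>
#include <vector>

// Hypothetical stand-in for the post-change xpu::reduce_sum signature;
// the real XDNN call returns an error code checked by
// PADDLE_ENFORCE_XDNN_SUCCESS.
int reduce_sum_stub(const float* /*x*/, float* /*y*/,
                    const std::vector<int64_t>& xshape,
                    const std::vector<int64_t>& reduce_dims) {
  return reduce_dims.size() <= xshape.size() ? 0 : -1;
}

int main() {
  int m = 4096, n = 1024;  // 32-bit extents, standing in for info_forward.m/n
  std::vector<float> x(static_cast<size_t>(m) * n, 1.0f), y(n, 0.0f);
  // Widen each extent explicitly when building the shape list, and write
  // the axis literal as 0LL, mirroring the commit's call sites.
  return reduce_sum_stub(x.data(), y.data(),
                         {(int64_t)m, (int64_t)n}, {0LL});
}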

paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc

Lines changed: 7 additions & 4 deletions
@@ -190,8 +190,11 @@ void FFNGrad(const phi::XPUContext& dev_ctx,
                 dropout_param2,
                 bsz_seq * d_model);
   // linear_grad2
-  r = xpu::reduce_sum(
-      xpu_ctx, d_dropout2_out_ptr, d_linear2_bias_ptr, {bsz_seq, d_model}, {0});
+  r = xpu::reduce_sum(xpu_ctx,
+                      d_dropout2_out_ptr,
+                      d_linear2_bias_ptr,
+                      {(int64_t)bsz_seq, (int64_t)d_model},
+                      {0LL});
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");

   phi::XpuFcInfo linear2_fc_info;
@@ -285,8 +288,8 @@ void FFNGrad(const phi::XPUContext& dev_ctx,
   r = xpu::reduce_sum(xpu_ctx,
                       d_act_out_ptr,
                       d_linear1_bias_ptr,
-                      {bsz_seq, dim_feedforward},
-                      {0});
+                      {(int64_t)bsz_seq, (int64_t)dim_feedforward},
+                      {0LL});
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");

   phi::XpuFcInfo linear1_fc_info;

paddle/phi/kernels/legacy/xpu/compare_kernel.cc

Lines changed: 17 additions & 16 deletions
@@ -23,24 +23,25 @@
 namespace phi {

 template <typename T, typename XPUType, typename Context>
-void XPUCompareRawKernelImpl(const Context& dev_ctx,
-                             const DenseTensor& x,
-                             const DenseTensor& y,
-                             DenseTensor* out,
-                             std::function<int(xpu::Context*,
-                                               const XPUType*,
-                                               const XPUType*,
-                                               bool*,
-                                               const std::vector<int>&,
-                                               const std::vector<int>&)> func) {
-  auto x_shape = common::vectorize<int>(x.dims());
-  auto y_shape = common::vectorize<int>(y.dims());
+void XPUCompareRawKernelImpl(
+    const Context& dev_ctx,
+    const DenseTensor& x,
+    const DenseTensor& y,
+    DenseTensor* out,
+    std::function<int(xpu::Context*,
+                      const XPUType*,
+                      const XPUType*,
+                      bool*,
+                      const std::vector<int64_t>&,
+                      const std::vector<int64_t>&)> func) {
+  auto x_shape = common::vectorize<int64_t>(x.dims());
+  auto y_shape = common::vectorize<int64_t>(y.dims());

   if (x.dims().size() == 0) {
-    x_shape = std::vector<int>({1});
+    x_shape = std::vector<int64_t>({1});
   }
   if (y.dims().size() == 0) {
-    y_shape = std::vector<int>({1});
+    y_shape = std::vector<int64_t>({1});
   }

   auto x_data = reinterpret_cast<const XPUType*>(x.data<T>());
@@ -64,8 +65,8 @@ void XPUCompareRawKernelImpl(const Context& dev_ctx,
                 const XPUType* x,                        \
                 const XPUType* y,                        \
                 bool* z,                                 \
-                const std::vector<int>& xshape,          \
-                const std::vector<int>& yshape) {        \
+                const std::vector<int64_t>& xshape,      \
+                const std::vector<int64_t>& yshape) {    \
       return functor(ctx, x, y, z, xshape, yshape);      \
     };                                                   \
     XPUCompareRawKernelImpl<T, XPUType, Context>(dev_ctx, x, y, out, f); \
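The compare kernels route every primitive through a std::function parameter, so the int64_t change must land in three places at once: the std::function signature, each per-op lambda, and the vectorized shapes. A standalone sketch of that plumbing with stand-in types (equal_stub, compare_impl, and Shape are all hypothetical, not the phi/XDNN API):

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

// Stand-in primitive in the role of an xpu::broadcast_* comparison.
int equal_stub(const float* /*x*/, const float* /*y*/, bool* z,
               const Shape& xshape, const Shape& yshape) {
  *z = (xshape == yshape);  // toy semantics, enough to show the plumbing
  return 0;                 // 0 == success, like an XDNN error code
}

// Shared impl in the role of XPUCompareRawKernelImpl: it only sees the
// std::function, so every wrapped primitive must use int64_t shapes.
int compare_impl(const float* x, const float* y, bool* out,
                 const Shape& xs, const Shape& ys,
                 std::function<int(const float*, const float*, bool*,
                                   const Shape&, const Shape&)> func) {
  return func(x, y, out, xs, ys);
}

int main() {
  float a = 1.0f, b = 2.0f;
  bool z = false;
  // Per-op lambda, like the `f` each DEFINE_XPU_COMPARE_KERNEL expands to.
  auto f = [](const float* x, const float* y, bool* z,
              const Shape& xs, const Shape& ys) {
    return equal_stub(x, y, z, xs, ys);
  };
  int r = compare_impl(&a, &b, &z, {1}, {1}, f);
  std::cout << r << " " << std::boolalpha << z << "\n";  // 0 true
  return 0;
}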

paddle/phi/kernels/legacy/xpu/elementwise_add_kernel.cc

Lines changed: 2 additions & 2 deletions
@@ -40,8 +40,8 @@ void AddRawKernel(const Context& dev_ctx,
               const XPUType* x,
               const XPUType* y,
               XPUType* z,
-              const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
+              const std::vector<int64_t>& xshape,
+              const std::vector<int64_t>& yshape) {
     return xpu::broadcast_add<XPUType>(ctx, x, y, z, xshape, yshape);
   };

paddle/phi/kernels/legacy/xpu/elementwise_divide_kernel.cc

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ void DivideRawKernel(const Context& dev_ctx,
               const XPUType* x,
               const XPUType* y,
               XPUType* z,
-              const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
+              const std::vector<int64_t>& xshape,
+              const std::vector<int64_t>& yshape) {
     return xpu::broadcast_div<XPUType>(ctx, x, y, z, xshape, yshape);
   };

paddle/phi/kernels/legacy/xpu/elementwise_kernel.cc

Lines changed: 10 additions & 10 deletions
@@ -30,8 +30,8 @@ void MaximumRawKernel(const Context& dev_ctx,
               const XPUType* x,
               const XPUType* y,
               XPUType* z,
-              const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
+              const std::vector<int64_t>& xshape,
+              const std::vector<int64_t>& yshape) {
     return xpu::broadcast_max<XPUType>(ctx, x, y, z, xshape, yshape);
   };

@@ -49,8 +49,8 @@ void MinimumRawKernel(const Context& dev_ctx,
               const XPUType* x,
               const XPUType* y,
               XPUType* z,
-              const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
+              const std::vector<int64_t>& xshape,
+              const std::vector<int64_t>& yshape) {
     return xpu::broadcast_min<XPUType>(ctx, x, y, z, xshape, yshape);
   };

@@ -68,8 +68,8 @@ void RemainderRawKernel(const Context& dev_ctx,
               const XPUType* x,
               const XPUType* y,
               XPUType* z,
-              const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
+              const std::vector<int64_t>& xshape,
+              const std::vector<int64_t>& yshape) {
     return xpu::broadcast_mod<XPUType>(ctx, x, y, z, xshape, yshape);
   };

@@ -87,8 +87,8 @@ void FloorDivideRawKernel(const Context& dev_ctx,
               const XPUType* x,
               const XPUType* y,
               XPUType* z,
-              const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
+              const std::vector<int64_t>& xshape,
+              const std::vector<int64_t>& yshape) {
     return xpu::broadcast_floordiv<XPUType>(ctx, x, y, z, xshape, yshape);
   };

@@ -106,8 +106,8 @@ void ElementwisePowRawKernel(const Context& dev_ctx,
               const XPUType* x,
               const XPUType* y,
               XPUType* z,
-              const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
+              const std::vector<int64_t>& xshape,
+              const std::vector<int64_t>& yshape) {
     return xpu::broadcast_pow<XPUType>(ctx, x, y, z, xshape, yshape);
   };

paddle/phi/kernels/legacy/xpu/elementwise_multiply_kernel.cc

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ void MultiplyRawKernel(const Context& dev_ctx,
               const XPUType* x,
               const XPUType* y,
               XPUType* z,
-              const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
+              const std::vector<int64_t>& xshape,
+              const std::vector<int64_t>& yshape) {
     return xpu::broadcast_mul<XPUType>(ctx, x, y, z, xshape, yshape);
   };

paddle/phi/kernels/legacy/xpu/elementwise_subtract_kernel.cc

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@ void SubtractRawKernel(const Context& dev_ctx,
               const XPUType* x,
               const XPUType* y,
               XPUType* z,
-              const std::vector<int>& xshape,
-              const std::vector<int>& yshape) {
+              const std::vector<int64_t>& xshape,
+              const std::vector<int64_t>& yshape) {
     return xpu::broadcast_sub<XPUType>(ctx, x, y, z, xshape, yshape);
   };

paddle/phi/kernels/legacy/xpu/reduce_max_kernel.cc

Lines changed: 2 additions & 2 deletions
@@ -33,8 +33,8 @@ void MaxRawKernel(const Context& dev_ctx,
   auto f = [](xpu::Context* ctx,
               const T* x,
               T* y,
-              const std::vector<int>& xdims,
-              const std::vector<int>& reduce_dims) {
+              const std::vector<int64_t>& xdims,
+              const std::vector<int64_t>& reduce_dims) {
     return xpu::reduce_max<XPUType>(ctx,
                                     reinterpret_cast<const XPUType*>(x),
                                     reinterpret_cast<XPUType*>(y),

paddle/phi/kernels/xpu/activation_grad_kernel.cc

Lines changed: 2 additions & 2 deletions
@@ -180,11 +180,11 @@ struct XPULogGradFunctor : public funcs::BaseActivationFunctor<T> {
         dev_ctx.x_context(), tmp, x->numel(), static_cast<T>(1.0));
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");

-    auto x_dims = common::vectorize<int>(x->dims());
+    auto x_dims = common::vectorize<int64_t>(x->dims());

     // use [1] to replace [], because xpu not support []
     if (x_dims.size() == 0) {
-      x_dims = std::vector<int>({1});
+      x_dims = std::vector<int64_t>({1});
     }
     // dx.device(d) = dout * (static_cast<T>(1) / x);
     r = xpu::broadcast_div(dev_ctx.x_context(),
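Both this functor and compare_kernel.cc repeat the same rank-0 convention: XPU primitives reject an empty shape [], so scalar tensors are given shape {1} after widening. A hypothetical helper (to_xpu_shape is my name, not Paddle's) capturing that convention:

#include <cassert>
#include <cstdint>
#include <vector>

// Rank-0 (scalar) dims become {1}; everything else passes through unchanged.
std::vector<int64_t> to_xpu_shape(const std::vector<int64_t>& dims) {
  return dims.empty() ? std::vector<int64_t>{1} : dims;
}

int main() {
  assert((to_xpu_shape({}) == std::vector<int64_t>{1}));        // scalar -> {1}
  assert((to_xpu_shape({2, 3}) == std::vector<int64_t>{2, 3})); // unchanged
  return 0;
}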
