
Commit a55fec0

[XPU] matmul support new shapes (#65963)
* support matmul with x_dim >= 3, y_dim <= 2 and trans_x = True
* add more tests
1 parent 321296c commit a55fec0
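
For orientation, this is the shape pattern the commit enables (it matches the new TestMatMulOp21 below); a minimal sketch, assuming a Paddle build with XPU support:

```python
import paddle

# x.ndim >= 3 with y.ndim <= 2 and transpose_x=True: before this commit,
# GetFCInfo only folded the batch into the matrix height for trans_x == False.
x = paddle.rand([10, 100, 4])  # a batch of 10 slices, each (100, 4)
y = paddle.rand([100, 10])     # one 2-D matrix, shared by every slice

# Each (100, 4) slice is multiplied as its (4, 100) transpose.
out = paddle.matmul(x, y, transpose_x=True)
print(out.shape)  # [10, 4, 10]
```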

File tree: 3 files changed, 98 insertions(+), 43 deletions(-)

paddle/phi/kernels/xpu/matmul_grad_kernel.cc

Lines changed: 16 additions & 0 deletions
@@ -64,6 +64,13 @@ void MatmulGradKernel(const Context& dev_ctx,
     c_1 = new_c_1;
   }
 
+  if (info_forward.is_y_need_broadcast) {
+    XPUType* new_c_2 = RAII_GUARD.alloc_l3_or_gm<XPUType>(
+        info_forward.bs * info_forward.k * info_forward.n);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(new_c_2);
+    c_2 = new_c_2;
+  }
+
   XpuFcInfo info_dx;
   XpuFcInfo info_dy;
   std::tuple<XpuFcInfo,
@@ -95,6 +102,15 @@ void MatmulGradKernel(const Context& dev_ctx,
   }
   if (dy) {
     MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
+    if (info_forward.is_y_need_broadcast) {
+      int r = xpu::reduce_sum<XPUType>(
+          xpu_ctx,
+          c_2,
+          reinterpret_cast<XPUType*>(dy->data<T>()),
+          {info_forward.bs, info_forward.k, info_forward.n},
+          {0});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
+    }
   }
 }
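
When y was shared across the batch in the forward pass, every batch slice contributes gradient to the same y, so dy is the sum of the per-slice gradients; the added xpu::reduce_sum over dim {0} performs exactly that collapse. A NumPy sketch of the math (names follow the kernel's bs/k/n, not the XPU API):

```python
import numpy as np

bs, m, k, n = 5, 4, 100, 10
x = np.random.rand(bs, k, m)     # forward ran with trans_x = True
y = np.random.rand(k, n)         # y broadcast across the batch
dout = np.random.rand(bs, m, n)  # upstream gradient

# What c_2 holds after the batched matmul: per-slice gradients w.r.t.
# the broadcast copies of y, i.e. dy_b = x_b @ dout_b.
dy_per_slice = np.einsum('bkm,bmn->bkn', x, dout)  # shape (bs, k, n)

# Every slice read the same y, so its gradient is the sum over the
# batch axis: reduce_sum with shape {bs, k, n} over dims {0}.
dy = dy_per_slice.sum(axis=0)  # shape (k, n)
```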

paddle/phi/kernels/xpu/xpu_api_wrapper.h

Lines changed: 47 additions & 43 deletions
@@ -79,6 +79,7 @@ struct XpuFcInfo {
   float* max_out;
   const float* bias;
   bool is_x_need_broadcast;
+  bool is_y_need_broadcast;
   const float* scale_x;
   const float* scale_y;
   int scale_x_mode;
@@ -99,6 +100,7 @@ struct XpuFcInfo {
         max_out(nullptr),
         bias(nullptr),
         is_x_need_broadcast(false),
+        is_y_need_broadcast(false),
         scale_x(nullptr),
         scale_y(nullptr),
         scale_x_mode(0),
@@ -157,41 +159,16 @@ static void GetFCInfo(const phi::DDim& x_dims,
   auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(new_y_dims, 0, trans_y);
 
   if (x_dims.size() >= 3 && y_dims.size() <= 2) {
-    if (!trans_x) {
+    if (!trans_x || mat_dim_a.batch_size_ == 1) {
       mat_dim_a.height_ *= mat_dim_a.batch_size_;
       mat_dim_a.batch_size_ = 0;
     } else {
-      mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
-      mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
+      info->is_y_need_broadcast = true;
     }
   }
 
   if (y_dims.size() >= 3 && x_dims.size() <= 2) {
-    PADDLE_ENFORCE_EQ(
-        mat_dim_b.trans_,
-        false,
-        phi::errors::InvalidArgument(
-            "xpu not support this Shape in matmul_op xdims = %s ydims = %s "
-            "x_trans = %d y_trans = %d",
-            x_dims.to_str(),
-            y_dims.to_str(),
-            mat_dim_a.trans_,
-            mat_dim_b.trans_));
-    if (mat_dim_a.width_ == mat_dim_b.batch_size_ * mat_dim_b.height_) {
-      mat_dim_b.height_ *= mat_dim_b.batch_size_;
-      mat_dim_b.batch_size_ = 0;
-    } else {
-      info->is_x_need_broadcast = true;
-    }
-  }
-
-  if (mat_dim_a.width_ == mat_dim_b.height_) {
-    if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
-      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
-    }
-    if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
-      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
-    }
+    info->is_x_need_broadcast = (mat_dim_b.batch_size_ > 1);
   }
 
   PADDLE_ENFORCE_EQ(mat_dim_a.width_,
@@ -204,6 +181,13 @@
                         mat_dim_a.trans_,
                         mat_dim_b.trans_));
 
+  if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
+    mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
+  }
+  if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
+    mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
+  }
+
   info->m = mat_dim_a.height_;
   info->n = mat_dim_b.width_;
   info->k = mat_dim_a.width_;
@@ -572,6 +556,7 @@ static void MatMulXPUFunction(
   float* max_y = fcinfo.max_y;
   float* max_out = fcinfo.max_out;
   bool is_x_need_broadcast = fcinfo.is_x_need_broadcast;
+  bool is_y_need_broadcast = fcinfo.is_y_need_broadcast;
   const float* bias = fcinfo.bias;
   const float* scale_x = fcinfo.scale_x;
   const float* scale_y = fcinfo.scale_y;
@@ -615,22 +600,35 @@
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
     x_data = x_broadcast_data;
   }
+  const XPUType* y_data = reinterpret_cast<const XPUType*>(y);
+  if (is_y_need_broadcast) {
+    XPUType* y_broadcast_data = nullptr;
+    xpu::ctx_guard RAII_GUARD(xpu_ctx);
+    y_broadcast_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(batch_size * k * n);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(y_broadcast_data);
+    std::vector<int> y_shape = {1, k, n};
+    std::vector<int> new_y_shape = {batch_size, k, n};
+    int r = xpu::broadcast<XPUType>(
+        xpu_ctx, y_data, y_broadcast_data, y_shape, new_y_shape);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
+    y_data = y_broadcast_data;
+  }
   // batch matmul
-  xblas_fc_batch_api(xpu_ctx,                              // Context* ctx,
-                     batch_size,                           // int batch_size,
-                     trans_x,                              // bool x_trans,
-                     trans_y,                              // bool w_trans,
-                     m,                                    // int m,
-                     n,                                    // int n,
-                     k,                                    // int k,
-                     alpha,                                // float alpha,
-                     x_data,                               // const TX* x,
-                     ldx,                                  // int stride_a,
-                     reinterpret_cast<const XPUType*>(y),  // const TW* w,
-                     ldy,                                  // int stride_b,
-                     0.0,                                  // float beta,
-                     reinterpret_cast<XPUType*>(out),      // TY* y,
-                     ldout,                                // int stride_c,
+  xblas_fc_batch_api(xpu_ctx,                          // Context* ctx,
+                     batch_size,                       // int batch_size,
+                     trans_x,                          // bool x_trans,
+                     trans_y,                          // bool w_trans,
+                     m,                                // int m,
+                     n,                                // int n,
+                     k,                                // int k,
+                     alpha,                            // float alpha,
+                     x_data,                           // const TX* x,
+                     ldx,                              // int stride_a,
+                     y_data,                           // const TW* w,
+                     ldy,                              // int stride_b,
+                     0.0,                              // float beta,
+                     reinterpret_cast<XPUType*>(out),  // TY* y,
+                     ldout,                            // int stride_c,
                      max_x,                            // const float* x_maxptr,
                      max_y);                           // const float* w_maxptr
 }
@@ -708,6 +706,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                         max_dout,
                         nullptr);
     dx_a = y, dx_b = dout_new;
+    dx_shape.is_x_need_broadcast = dout_shape.is_y_need_broadcast;
     // dy = T(dout) * T(x)
     dy_shape.InitFcInfo(dout_shape.bs,
                         dout_shape.n,
@@ -719,6 +718,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                         nullptr,
                         nullptr);
     dy_a = dout_new, dy_b = x;
+    dy_shape.is_y_need_broadcast = dout_shape.is_x_need_broadcast;
   } else if (trans_x) {
     // dx = y * T(dout)
     dx_shape.InitFcInfo(dout_shape.bs,
@@ -731,6 +731,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                         max_dout,
                         nullptr);
     dx_a = y, dx_b = dout_new;
+    dx_shape.is_x_need_broadcast = dout_shape.is_y_need_broadcast;
     // dy = x * dout
     dy_shape.InitFcInfo(dout_shape.bs,
                         dout_shape.k,
@@ -755,6 +756,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                         nullptr,
                         nullptr);
     dx_a = dout_new, dx_b = y;
+    dx_shape.is_y_need_broadcast = dout_shape.is_y_need_broadcast;
     // dy = T(dout) * x
     dy_shape.InitFcInfo(dout_shape.bs,
                         dout_shape.n,
@@ -766,6 +768,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                         nullptr,
                         nullptr);
     dy_a = dout_new, dy_b = x;
+    dy_shape.is_y_need_broadcast = dout_shape.is_x_need_broadcast;
   } else {
     // dx = dout * T(y)
     dx_shape.InitFcInfo(dout_shape.bs,
@@ -778,6 +781,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                         nullptr,
                         nullptr);
     dx_a = dout_new, dx_b = y;
+    dx_shape.is_y_need_broadcast = dout_shape.is_y_need_broadcast;
     // dy = T(x) * dout
     dy_shape.InitFcInfo(dout_shape.bs,
                         dout_shape.k,
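
The forward-path half of the change works the other way around: rather than folding the batch into y's matrix descriptor (the else branch removed from GetFCInfo above, which cannot work when each x slice must be transposed), the code now sets is_y_need_broadcast, materializes y once per batch slice with xpu::broadcast, and runs an ordinary batched matmul. A NumPy sketch of the shape handling (illustrative, not the XPU API):

```python
import numpy as np

bs, m, k, n = 10, 4, 100, 10
x = np.random.rand(bs, k, m)  # trans_x=True: each slice multiplies as (m, k)
y = np.random.rand(k, n)      # 2-D y, shared by every slice

# Equivalent of xpu::broadcast from shape {1, k, n} to {batch_size, k, n}.
y_b = np.broadcast_to(y, (bs, k, n))

# Batched matmul, as xblas_fc_batch_api then performs per slice.
out = np.einsum('bkm,bkn->bmn', x, y_b)  # shape (bs, m, n)

# Same result as transposing each x slice explicitly.
assert np.allclose(out, np.matmul(x.transpose(0, 2, 1), y))
```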

test/xpu/test_matmul_v2_op_xpu.py

Lines changed: 35 additions & 0 deletions
@@ -316,6 +316,41 @@ def config(self):
         self.trans_x = True
         self.trans_y = False
 
+class TestMatMulOp21(TestMatMulV2Op):
+    """
+    case 21 : (x.ndim >= 3) && (y.ndim <= 2),
+    trans_x is True
+    """
+
+    def config(self):
+        self.x_shape = (10, 100, 4)
+        self.y_shape = (100, 10)
+        self.trans_x = True
+        self.trans_y = False
+
+class TestMatMulOp22(TestMatMulV2Op):
+    """
+    case 22 : (x.ndim <= 2) && (y.ndim >= 3)
+    """
+
+    def config(self):
+        self.x_shape = (10, 100)
+        self.y_shape = (5, 100, 4)
+        self.trans_x = False
+        self.trans_y = False
+
+class TestMatMulOp23(TestMatMulV2Op):
+    """
+    case 23 : (x.ndim <= 2) && (y.ndim >= 3),
+    trans_y is True
+    """
+
+    def config(self):
+        self.x_shape = (10, 100)
+        self.y_shape = (5, 4, 100)
+        self.trans_x = False
+        self.trans_y = True
+
 @check_run_big_shape_test()
 class TestMatMulOpLargeShape1(TestMatMulV2Op):
     """
