@@ -79,6 +79,7 @@ struct XpuFcInfo {
7979 float * max_out;
8080 const float * bias;
8181 bool is_x_need_broadcast;
82+ bool is_y_need_broadcast;
8283 const float * scale_x;
8384 const float * scale_y;
8485 int scale_x_mode;
@@ -99,6 +100,7 @@ struct XpuFcInfo {
99100 max_out(nullptr ),
100101 bias(nullptr ),
101102 is_x_need_broadcast(false ),
103+ is_y_need_broadcast(false ),
102104 scale_x(nullptr ),
103105 scale_y(nullptr ),
104106 scale_x_mode(0 ),
@@ -157,41 +159,16 @@ static void GetFCInfo(const phi::DDim& x_dims,
157159 auto mat_dim_b = phi::funcs::CreateMatrixDescriptor (new_y_dims, 0 , trans_y);
158160
159161 if (x_dims.size () >= 3 && y_dims.size () <= 2 ) {
160- if (!trans_x) {
162+ if (!trans_x || mat_dim_a. batch_size_ == 1 ) {
161163 mat_dim_a.height_ *= mat_dim_a.batch_size_ ;
162164 mat_dim_a.batch_size_ = 0 ;
163165 } else {
164- mat_dim_b.batch_size_ = mat_dim_a.batch_size_ ;
165- mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_ ;
166+ info->is_y_need_broadcast = true ;
166167 }
167168 }
168169
169170 if (y_dims.size () >= 3 && x_dims.size () <= 2 ) {
170- PADDLE_ENFORCE_EQ (
171- mat_dim_b.trans_ ,
172- false ,
173- phi::errors::InvalidArgument (
174- " xpu not support this Shape in matmul_op xdims = %s ydims = %s "
175- " x_trans = %d y_trans = %d" ,
176- x_dims.to_str (),
177- y_dims.to_str (),
178- mat_dim_a.trans_ ,
179- mat_dim_b.trans_ ));
180- if (mat_dim_a.width_ == mat_dim_b.batch_size_ * mat_dim_b.height_ ) {
181- mat_dim_b.height_ *= mat_dim_b.batch_size_ ;
182- mat_dim_b.batch_size_ = 0 ;
183- } else {
184- info->is_x_need_broadcast = true ;
185- }
186- }
187-
188- if (mat_dim_a.width_ == mat_dim_b.height_ ) {
189- if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1 ) {
190- mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0 ;
191- }
192- if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0 ) {
193- mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0 ;
194- }
171+ info->is_x_need_broadcast = (mat_dim_b.batch_size_ > 1 );
195172 }
196173
197174 PADDLE_ENFORCE_EQ (mat_dim_a.width_ ,
@@ -204,6 +181,13 @@ static void GetFCInfo(const phi::DDim& x_dims,
204181 mat_dim_a.trans_ ,
205182 mat_dim_b.trans_ ));
206183
184+ if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1 ) {
185+ mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0 ;
186+ }
187+ if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0 ) {
188+ mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0 ;
189+ }
190+
207191 info->m = mat_dim_a.height_ ;
208192 info->n = mat_dim_b.width_ ;
209193 info->k = mat_dim_a.width_ ;
@@ -572,6 +556,7 @@ static void MatMulXPUFunction(
572556 float * max_y = fcinfo.max_y ;
573557 float * max_out = fcinfo.max_out ;
574558 bool is_x_need_broadcast = fcinfo.is_x_need_broadcast ;
559+ bool is_y_need_broadcast = fcinfo.is_y_need_broadcast ;
575560 const float * bias = fcinfo.bias ;
576561 const float * scale_x = fcinfo.scale_x ;
577562 const float * scale_y = fcinfo.scale_y ;
@@ -615,22 +600,35 @@ static void MatMulXPUFunction(
615600 PADDLE_ENFORCE_XDNN_SUCCESS (r, " broadcast" );
616601 x_data = x_broadcast_data;
617602 }
603+ const XPUType* y_data = reinterpret_cast <const XPUType*>(y);
604+ if (is_y_need_broadcast) {
605+ XPUType* y_broadcast_data = nullptr ;
606+ xpu::ctx_guard RAII_GUARD (xpu_ctx);
607+ y_broadcast_data = RAII_GUARD.alloc_l3_or_gm <XPUType>(batch_size * k * n);
608+ PADDLE_ENFORCE_XDNN_NOT_NULL (y_broadcast_data);
609+ std::vector<int > y_shape = {1 , k, n};
610+ std::vector<int > new_y_shape = {batch_size, k, n};
611+ int r = xpu::broadcast<XPUType>(
612+ xpu_ctx, y_data, y_broadcast_data, y_shape, new_y_shape);
613+ PADDLE_ENFORCE_XDNN_SUCCESS (r, " broadcast" );
614+ y_data = y_broadcast_data;
615+ }
618616 // batch matmul
619- xblas_fc_batch_api (xpu_ctx, // Context* ctx,
620- batch_size, // int batch_size,
621- trans_x, // bool x_trans,
622- trans_y, // bool w_trans,
623- m, // int m,
624- n, // int n,
625- k, // int k,
626- alpha, // float alpha,
627- x_data, // const TX* x,
628- ldx, // int stride_a,
629- reinterpret_cast < const XPUType*>(y), // const TW* w,
630- ldy, // int stride_b,
631- 0.0 , // float beta,
632- reinterpret_cast <XPUType*>(out), // TY* y,
633- ldout, // int stride_c,
617+ xblas_fc_batch_api (xpu_ctx, // Context* ctx,
618+ batch_size, // int batch_size,
619+ trans_x, // bool x_trans,
620+ trans_y, // bool w_trans,
621+ m, // int m,
622+ n, // int n,
623+ k, // int k,
624+ alpha, // float alpha,
625+ x_data, // const TX* x,
626+ ldx, // int stride_a,
627+ y_data, // const TW* w,
628+ ldy, // int stride_b,
629+ 0.0 , // float beta,
630+ reinterpret_cast <XPUType*>(out), // TY* y,
631+ ldout, // int stride_c,
634632 max_x, // const float* x_maxptr,
635633 max_y); // const float* w_maxptr
636634 }
@@ -708,6 +706,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
708706 max_dout,
709707 nullptr );
710708 dx_a = y, dx_b = dout_new;
709+ dx_shape.is_x_need_broadcast = dout_shape.is_y_need_broadcast ;
711710 // dy = T(dout) * T(x)
712711 dy_shape.InitFcInfo (dout_shape.bs ,
713712 dout_shape.n ,
@@ -719,6 +718,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
719718 nullptr ,
720719 nullptr );
721720 dy_a = dout_new, dy_b = x;
721+ dy_shape.is_y_need_broadcast = dout_shape.is_x_need_broadcast ;
722722 } else if (trans_x) {
723723 // dx = y * T(dout)
724724 dx_shape.InitFcInfo (dout_shape.bs ,
@@ -731,6 +731,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
731731 max_dout,
732732 nullptr );
733733 dx_a = y, dx_b = dout_new;
734+ dx_shape.is_x_need_broadcast = dout_shape.is_y_need_broadcast ;
734735 // dy = x * dout
735736 dy_shape.InitFcInfo (dout_shape.bs ,
736737 dout_shape.k ,
@@ -755,6 +756,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
755756 nullptr ,
756757 nullptr );
757758 dx_a = dout_new, dx_b = y;
759+ dx_shape.is_y_need_broadcast = dout_shape.is_y_need_broadcast ;
758760 // dy = T(dout) * x
759761 dy_shape.InitFcInfo (dout_shape.bs ,
760762 dout_shape.n ,
@@ -766,6 +768,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
766768 nullptr ,
767769 nullptr );
768770 dy_a = dout_new, dy_b = x;
771+ dy_shape.is_y_need_broadcast = dout_shape.is_x_need_broadcast ;
769772 } else {
770773 // dx = dout * T(y)
771774 dx_shape.InitFcInfo (dout_shape.bs ,
@@ -778,6 +781,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
778781 nullptr ,
779782 nullptr );
780783 dx_a = dout_new, dx_b = y;
784+ dx_shape.is_y_need_broadcast = dout_shape.is_y_need_broadcast ;
781785 // dy = T(x) * dout
782786 dy_shape.InitFcInfo (dout_shape.bs ,
783787 dout_shape.k ,
0 commit comments