 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 
+#ifdef PADDLE_WITH_XPU_XRE5
+#include "xblas/cublasLt.h"
+namespace xblas = baidu::xpu::xblas;
+#endif
+
 namespace phi {
 namespace fusion {
 
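+// device-side representations of fp16/bf16; XPUTypeTrait maps the phi dtypes
+// to the types the XPU kernels operate on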
+using XPUTypeFP16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
+using XPUTypeBF16 = typename XPUTypeTrait<phi::dtype::bfloat16>::Type;
+
 template <typename T_X,
           typename T_W,
           typename T_OUT,
@@ -77,6 +85,131 @@ void FcXPUKernelImpl(const Context& ctx,
   } else if (act_type == xpu::Activation_t::HARD_SIGMOID) {
     act.hard_sigmoid_slope = act_alpha;
   }
+  // only for xpu3
+#ifdef PADDLE_WITH_XPU_XRE5
+  if constexpr (std::is_same<T_X, bfloat16>::value &&
+                std::is_same<T_W, bfloat16>::value &&
+                std::is_same<T_OUT, bfloat16>::value) {
+    // use xte to speed up bfloat16 calc
+    // whether to enable this feature requires a trade-off between
+    // performance and precision
+    if (std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") != nullptr) {
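+      // XTE path: cast the bf16 inputs to fp16 with per-tensor scales
+      // (cast_te), run the GEMM in fp16 through xblas, and produce a bf16
+      // output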
+      xpu::ctx_guard RAII_GUARD(ctx.x_context());
+      const int MAXPTR_N = ctx.x_context()->max_ptr_size();
+      int x_len = m * k;
+      XPUTypeFP16* x_data_fp16 = nullptr;
+      x_data_fp16 = RAII_GUARD.alloc_l3_or_gm<XPUTypeFP16>(x_len);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(x_data_fp16);
+      int w_len = k * n;
+      XPUTypeFP16* w_data_fp16 = nullptr;
+      w_data_fp16 = RAII_GUARD.alloc_l3_or_gm<XPUTypeFP16>(w_len);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(w_data_fp16);
+
+      float* xte_scale_x = nullptr;
+      float* xte_scale_w = nullptr;
+      xte_scale_x = RAII_GUARD.alloc_l3_or_gm<float>(1);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(xte_scale_x);
+      xte_scale_w = RAII_GUARD.alloc_l3_or_gm<float>(1);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(xte_scale_w);
+
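+      // when no precomputed max is available, first compute it on device with
+      // findmax; cast_te consumes the max to derive the fp16 cast scale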
+      float* xte_x_maxptr = nullptr;
+      float* xte_w_maxptr = nullptr;
+      if (x_max_data == nullptr) {
+        xte_x_maxptr = RAII_GUARD.alloc_l3_or_gm<float>(MAXPTR_N);
+        PADDLE_ENFORCE_XDNN_NOT_NULL(xte_x_maxptr);
+        int r = xpu::findmax(ctx.x_context(), x_data, xte_x_maxptr, x_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_findmax");
+        r = xpu::cast_te(ctx.x_context(),
+                         x_data,
+                         xte_x_maxptr,
+                         x_data_fp16,
+                         xte_scale_x,
+                         x_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
+      } else {
+        int r = xpu::cast_te(ctx.x_context(),
+                             x_data,
+                             x_max_data,
+                             x_data_fp16,
+                             xte_scale_x,
+                             x_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
+      }
+      if (w_max_data == nullptr) {
+        xte_w_maxptr = RAII_GUARD.alloc_l3_or_gm<float>(MAXPTR_N);
+        PADDLE_ENFORCE_XDNN_NOT_NULL(xte_w_maxptr);
+        int r = xpu::findmax(ctx.x_context(), w_data, xte_w_maxptr, w_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_findmax");
+        r = xpu::cast_te(ctx.x_context(),
+                         w_data,
+                         xte_w_maxptr,
+                         w_data_fp16,
+                         xte_scale_w,
+                         w_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
+      } else {
+        int r = xpu::cast_te(ctx.x_context(),
+                             w_data,
+                             w_max_data,
+                             w_data_fp16,
+                             xte_scale_w,
+                             w_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
+      }
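+      // fp16 x fp16 GEMM with a bf16 output (TX/TW/TY/TGEMM =
+      // fp16/fp16/bf16/fp16); the cast_te scales let xblas rescale the result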
+      int r =
+          xblas::fc_fusion<XPUTypeFP16, XPUTypeFP16, XPUTypeBF16, XPUTypeFP16>(
+              ctx.x_context(),
+              x_data_fp16,
+              w_data_fp16,
+              out_data,
+              m,
+              n,
+              k,
+              transpose_x,
+              true,
+              x_max_data ? x_max_data : xte_x_maxptr,
+              w_max_data ? w_max_data : xte_w_maxptr,
+              out_max_data,
+              transpose_x ? m : k,
+              k,
+              n,
+              alpha,
+              beta,
+              bias_data,
+              act,
+              xte_scale_x,
+              xte_scale_w);
+
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion");
+    }
+  }
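+  // fall back to the plain fc_fusion path whenever the XTE path above did not
+  // run (non-bf16 dtypes, or the env toggle unset)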
+  if (!(std::is_same<T_X, bfloat16>::value &&
+        std::is_same<T_W, bfloat16>::value &&
+        std::is_same<T_OUT, bfloat16>::value) ||
+      std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") == nullptr) {
+    int r =
+        xpu::fc_fusion<XPUTypeX, XPUTypeW, XPUTypeOut, T_GEMM>(  // TX/TW/TY/TGEMM
+            ctx.x_context(),      // ctx
+            x_data,               // x
+            w_data,               // w
+            out_data,             // y
+            m,                    // m
+            n,                    // n
+            k,                    // k
+            transpose_x,          // x_trans
+            true,                 // w_trans
+            x_max_data,           // x_maxptr
+            w_max_data,           // w_maxptr
+            out_max_data,         // y_maxptr
+            transpose_x ? m : k,  // ldx
+            k,                    // ldw
+            n,                    // ldy
+            alpha,                // alpha
+            beta,                 // beta
+            bias_data,            // bias
+            act,                  // act
+            scale_max_data);      // scale
+
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
+  }
+#else
   int r =
       xpu::fc_fusion<XPUTypeX, XPUTypeW, XPUTypeOut, T_GEMM>(  // TX/TW/TY/TGEMM
           ctx.x_context(),  // ctx
@@ -101,6 +234,7 @@ void FcXPUKernelImpl(const Context& ctx,
           scale_max_data);  // scale
 
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
+#endif
 }
 
 #define FC_XPU_KERNEL_IMPL(x_dtype_, w_dtype_, out_dtype_, gemm_dtype_) \
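
Usage note: the XTE fast path is opt-in via the `XPU_PADDLE_FC_BFLOAT16_XTE` environment variable checked above. Below is a minimal sketch of toggling it from a C++ driver, assuming a POSIX environment; `run_fc_workload` is a hypothetical placeholder for whatever exercises the fc_xpu kernel, not a Paddle API:

```cpp
#include <cstdlib>

// Hypothetical stand-in for code that runs the fused FC on XPU.
void run_fc_workload() { /* exercise the fc_xpu kernel here */ }

int main() {
  // Opt in to the fp16 XTE path for bf16 FC (faster, slightly lower precision).
  setenv("XPU_PADDLE_FC_BFLOAT16_XTE", "1", /*overwrite=*/1);
  run_fc_workload();

  // Unset to take the default bf16 fc_fusion path instead.
  unsetenv("XPU_PADDLE_FC_BFLOAT16_XTE");
  run_fc_workload();
  return 0;
}
```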