Commit d3668a7

[XPU] support xblas::fc_fusion for fc_xpu(bfp16) (#69942)

1 parent 973dadd commit d3668a7

File tree

1 file changed: +134 −0 lines changed

paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc

Lines changed: 134 additions & 0 deletions
@@ -16,9 +16,17 @@
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/core/kernel_registry.h"
 
+#ifdef PADDLE_WITH_XPU_XRE5
+#include "xblas/cublasLt.h"
+namespace xblas = baidu::xpu::xblas;
+#endif
+
 namespace phi {
 namespace fusion {
 
+using XPUTypeFP16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
+using XPUTypeBF16 = typename XPUTypeTrait<phi::dtype::bfloat16>::Type;
+
 template <typename T_X,
           typename T_W,
           typename T_OUT,
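Note: `XPUTypeTrait` maps a framework dtype to the device-side type the XPU kernels operate on. A minimal sketch of that trait pattern — the stand-in types below are illustrative, not Paddle's actual definitions:

#include <cstdint>

// Stand-ins for phi::dtype::bfloat16 and the XPU device-side bfloat16.
struct FrameworkBF16 { uint16_t bits; };
struct DeviceBF16 { uint16_t bits; };

// Default: the device type is the framework type itself.
template <typename T>
struct TypeTraitSketch {
  using Type = T;
};

// Specialization: bfloat16 resolves to the device representation.
template <>
struct TypeTraitSketch<FrameworkBF16> {
  using Type = DeviceBF16;
};

// Mirrors the "using XPUTypeBF16 = typename XPUTypeTrait<...>::Type;" above.
using SketchBF16 = typename TypeTraitSketch<FrameworkBF16>::Type;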
@@ -77,6 +85,131 @@ void FcXPUKernelImpl(const Context& ctx,
   } else if (act_type == xpu::Activation_t::HARD_SIGMOID) {
     act.hard_sigmoid_slope = act_alpha;
   }
+// only for xpu3
+#ifdef PADDLE_WITH_XPU_XRE5
+  if constexpr (std::is_same<T_X, bfloat16>::value &&
+                std::is_same<T_W, bfloat16>::value &&
+                std::is_same<T_OUT, bfloat16>::value) {
+    // use xte to speed up bfloat16 calc
+    // whether to enable this feature requires a trade-off between
+    // performance and precision
+    if (std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") != nullptr) {
+      xpu::ctx_guard RAII_GUARD(ctx.x_context());
+      const int MAXPTR_N = ctx.x_context()->max_ptr_size();
+      int x_len = m * k;
+      XPUTypeFP16* x_data_fp16 = nullptr;
+      x_data_fp16 = RAII_GUARD.alloc_l3_or_gm<XPUTypeFP16>(x_len);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(x_data_fp16);
+      int w_len = k * n;
+      XPUTypeFP16* w_data_fp16 = nullptr;
+      w_data_fp16 = RAII_GUARD.alloc_l3_or_gm<XPUTypeFP16>(w_len);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(w_data_fp16);
+
+      float* xte_scale_x = nullptr;
+      float* xte_scale_w = nullptr;
+      xte_scale_x = RAII_GUARD.alloc_l3_or_gm<float>(1);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(xte_scale_x);
+      xte_scale_w = RAII_GUARD.alloc_l3_or_gm<float>(1);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(xte_scale_w);
+
+      float* xte_x_maxptr = nullptr;
+      float* xte_w_maxptr = nullptr;
+      if (x_max_data == nullptr) {
+        xte_x_maxptr = RAII_GUARD.alloc_l3_or_gm<float>(MAXPTR_N);
+        PADDLE_ENFORCE_XDNN_NOT_NULL(xte_x_maxptr);
+        int r = xpu::findmax(ctx.x_context(), x_data, xte_x_maxptr, x_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_findmax");
+        r = xpu::cast_te(ctx.x_context(),
+                         x_data,
+                         xte_x_maxptr,
+                         x_data_fp16,
+                         xte_scale_x,
+                         x_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
+      } else {
+        int r = xpu::cast_te(ctx.x_context(),
+                             x_data,
+                             x_max_data,
+                             x_data_fp16,
+                             xte_scale_x,
+                             x_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
+      }
+      if (w_max_data == nullptr) {
+        xte_w_maxptr = RAII_GUARD.alloc_l3_or_gm<float>(MAXPTR_N);
+        PADDLE_ENFORCE_XDNN_NOT_NULL(xte_w_maxptr);
+        int r = xpu::findmax(ctx.x_context(), w_data, xte_w_maxptr, w_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_findmax");
+        r = xpu::cast_te(ctx.x_context(),
+                         w_data,
+                         xte_w_maxptr,
+                         w_data_fp16,
+                         xte_scale_w,
+                         w_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
+      } else {
+        int r = xpu::cast_te(ctx.x_context(),
+                             w_data,
+                             w_max_data,
+                             w_data_fp16,
+                             xte_scale_w,
+                             w_len);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
+      }
+      int r =
+          xblas::fc_fusion<XPUTypeFP16, XPUTypeFP16, XPUTypeBF16, XPUTypeFP16>(
+              ctx.x_context(),
+              x_data_fp16,
+              w_data_fp16,
+              out_data,
+              m,
+              n,
+              k,
+              transpose_x,
+              true,
+              x_max_data ? x_max_data : xte_x_maxptr,
+              w_max_data ? w_max_data : xte_w_maxptr,
+              out_max_data,
+              transpose_x ? m : k,
+              k,
+              n,
+              alpha,
+              beta,
+              bias_data,
+              act,
+              xte_scale_x,
+              xte_scale_w);
+
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion");
+    }
+  }
+  if (std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") == nullptr) {
+    int r = xpu::
+        fc_fusion<XPUTypeX, XPUTypeW, XPUTypeOut, T_GEMM>(  // TX/TW/TY/TGEMM
+            ctx.x_context(),      // ctx
+            x_data,               // x
+            w_data,               // w
+            out_data,             // y
+            m,                    // m
+            n,                    // n
+            k,                    // k
+            transpose_x,          // x_trans
+            true,                 // w_trans
+            x_max_data,           // x_maxptr
+            w_max_data,           // w_maxptr
+            out_max_data,         // y_maxptr
+            transpose_x ? m : k,  // ldx
+            k,                    // ldw
+            n,                    // ldy
+            alpha,                // alpha
+            beta,                 // beta
+            bias_data,            // bias
+            act,                  // act
+            scale_max_data);      // scale
+
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
+  }
+#else
   int r =
       xpu::fc_fusion<XPUTypeX, XPUTypeW, XPUTypeOut, T_GEMM>(  // TX/TW/TY/TGEMM
          ctx.x_context(),  // ctx
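The XTE path above works in three steps: xpu::findmax records each tensor's absolute maximum (when the caller did not supply one), xpu::cast_te narrows the bfloat16 data to float16 while emitting a per-tensor scale, and xblas::fc_fusion runs the GEMM on the float16 operands, taking the max pointers and both scales so the bfloat16 output keeps the original magnitude. A rough host-side sketch of the scale-then-cast idea — the code below is illustrative only, not the XDNN cast_te API, and the exact scaling rule is an assumption:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const float kFp16Max = 65504.0f;  // largest finite float16 value

  // bfloat16 covers roughly the float32 exponent range, so values can
  // exceed what float16 represents; pretend these came from a bf16 tensor.
  std::vector<float> x = {1.5e5f, -3.0e4f, 2.0e2f};

  // findmax step: per-tensor absolute maximum.
  float amax = 0.0f;
  for (float v : x) amax = std::max(amax, std::fabs(v));

  // cast_te step (sketched): pick a scale so the scaled data fits the
  // float16 range, then narrow; the scale is kept for later correction.
  float scale = amax > kFp16Max ? kFp16Max / amax : 1.0f;
  for (float& v : x) v *= scale;  // a real kernel would cast to fp16 here

  // fc_fusion step: the GEMM runs on the scaled fp16 operands; dividing
  // the result by scale_x * scale_w restores the original magnitude.
  std::printf("amax = %g, scale = %g\n", amax, scale);
  return 0;
}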
@@ -101,6 +234,7 @@ void FcXPUKernelImpl(const Context& ctx,
           scale_max_data);  // scale
 
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
+#endif
 }
 
 #define FC_XPU_KERNEL_IMPL(x_dtype_, w_dtype_, out_dtype_, gemm_dtype_) \
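The new code is gated twice: at compile time by PADDLE_WITH_XPU_XRE5 (without it, the #else branch keeps the original xpu::fc_fusion call unconditionally), and at run time by the XPU_PADDLE_FC_BFLOAT16_XTE environment variable, since the faster path trades some precision. A minimal check mirroring that runtime gate:

#include <cstdio>
#include <cstdlib>

int main() {
  // Setting the variable to any value opts bfloat16 FC into the XTE path.
  if (std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") != nullptr) {
    std::puts("bf16 fc_xpu: xblas::fc_fusion (XTE) path");
  } else {
    std::puts("bf16 fc_xpu: default xpu::fc_fusion path");
  }
  return 0;
}

Non-bfloat16 dtypes never enter the if constexpr branch, so they always take the default call regardless of the environment variable.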
