134 changes: 134 additions & 0 deletions paddle/phi/kernels/fusion/xpu/fc_xpu_kernel.cc
@@ -16,9 +16,17 @@
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

#ifdef PADDLE_WITH_XPU_XRE5
#include "xblas/cublasLt.h"
namespace xblas = baidu::xpu::xblas;
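// xblas (baidu::xpu::xblas) provides the fc_fusion overload used by the
// XTE bfloat16 fast path below.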
#endif

namespace phi {
namespace fusion {

using XPUTypeFP16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
using XPUTypeBF16 = typename XPUTypeTrait<phi::dtype::bfloat16>::Type;

template <typename T_X,
typename T_W,
typename T_OUT,
@@ -77,6 +85,131 @@ void FcXPUKernelImpl(const Context& ctx,
} else if (act_type == xpu::Activation_t::HARD_SIGMOID) {
act.hard_sigmoid_slope = act_alpha;
}
// This path is only available on XPU3 (XRE5).
#ifdef PADDLE_WITH_XPU_XRE5
if constexpr (std::is_same<T_X, bfloat16>::value &&
std::is_same<T_W, bfloat16>::value &&
std::is_same<T_OUT, bfloat16>::value) {
    // Use XTE to speed up bfloat16 computation. Enabling this feature is a
    // trade-off between performance and precision, so it is gated behind an
    // environment variable.
if (std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") != nullptr) {
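      // XTE flow: compute per-tensor maxima when the caller did not supply
      // them, cast the bf16 inputs to fp16 with xpu::cast_te (recording
      // per-tensor scales), then run the GEMM through xblas::fc_fusion with
      // a bf16 output.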
xpu::ctx_guard RAII_GUARD(ctx.x_context());
const int MAXPTR_N = ctx.x_context()->max_ptr_size();
int x_len = m * k;
XPUTypeFP16* x_data_fp16 = nullptr;
x_data_fp16 = RAII_GUARD.alloc_l3_or_gm<XPUTypeFP16>(x_len);
PADDLE_ENFORCE_XDNN_NOT_NULL(x_data_fp16);
int w_len = k * n;
XPUTypeFP16* w_data_fp16 = nullptr;
w_data_fp16 = RAII_GUARD.alloc_l3_or_gm<XPUTypeFP16>(w_len);
PADDLE_ENFORCE_XDNN_NOT_NULL(w_data_fp16);

float* xte_scale_x = nullptr;
float* xte_scale_w = nullptr;
xte_scale_x = RAII_GUARD.alloc_l3_or_gm<float>(1);
PADDLE_ENFORCE_XDNN_NOT_NULL(xte_scale_x);
xte_scale_w = RAII_GUARD.alloc_l3_or_gm<float>(1);
PADDLE_ENFORCE_XDNN_NOT_NULL(xte_scale_w);

float* xte_x_maxptr = nullptr;
float* xte_w_maxptr = nullptr;
if (x_max_data == nullptr) {
xte_x_maxptr = RAII_GUARD.alloc_l3_or_gm<float>(MAXPTR_N);
PADDLE_ENFORCE_XDNN_NOT_NULL(xte_x_maxptr);
int r = xpu::findmax(ctx.x_context(), x_data, xte_x_maxptr, x_len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_findmax");
r = xpu::cast_te(ctx.x_context(),
x_data,
xte_x_maxptr,
x_data_fp16,
xte_scale_x,
x_len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
} else {
int r = xpu::cast_te(ctx.x_context(),
x_data,
x_max_data,
x_data_fp16,
xte_scale_x,
x_len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
}
if (w_max_data == nullptr) {
xte_w_maxptr = RAII_GUARD.alloc_l3_or_gm<float>(MAXPTR_N);
PADDLE_ENFORCE_XDNN_NOT_NULL(xte_w_maxptr);
int r = xpu::findmax(ctx.x_context(), w_data, xte_w_maxptr, w_len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_findmax");
r = xpu::cast_te(ctx.x_context(),
w_data,
xte_w_maxptr,
w_data_fp16,
xte_scale_w,
w_len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
} else {
int r = xpu::cast_te(ctx.x_context(),
w_data,
w_max_data,
w_data_fp16,
xte_scale_w,
w_len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "xpu_cast_te");
}
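      // Template arguments follow the <TX, TW, TY, TGEMM> convention of
      // xpu::fc_fusion below: fp16 inputs, bf16 output, fp16 GEMM type.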
int r =
xblas::fc_fusion<XPUTypeFP16, XPUTypeFP16, XPUTypeBF16, XPUTypeFP16>(
ctx.x_context(),
x_data_fp16,
w_data_fp16,
out_data,
m,
n,
k,
transpose_x,
true,
x_max_data ? x_max_data : xte_x_maxptr,
w_max_data ? w_max_data : xte_w_maxptr,
out_max_data,
transpose_x ? m : k,
k,
n,
alpha,
beta,
bias_data,
act,
xte_scale_x,
xte_scale_w);

PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_fusion");
}
}
  // Fall back to the standard fc_fusion path whenever the XTE path above was
  // not taken: either the dtypes are not bfloat16, or the env var is unset.
  // (Checking only the env var would skip the GEMM entirely for non-bf16
  // instantiations when the variable is set.)
  if (!(std::is_same<T_X, bfloat16>::value &&
        std::is_same<T_W, bfloat16>::value &&
        std::is_same<T_OUT, bfloat16>::value) ||
      std::getenv("XPU_PADDLE_FC_BFLOAT16_XTE") == nullptr) {
int r = xpu::
fc_fusion<XPUTypeX, XPUTypeW, XPUTypeOut, T_GEMM>( // TX/TW/TY/TGEMM
ctx.x_context(), // ctx
x_data, // x
w_data, // w
out_data, // y
m, // m
n, // n
k, // k
transpose_x, // x_trans
true, // w_trans
x_max_data, // x_maxptr
w_max_data, // w_maxptr
out_max_data, // y_maxptr
transpose_x ? m : k, // ldx
k, // ldw
n, // ldy
alpha, // alpha
beta, // beta
bias_data, // bias
act, // act
scale_max_data); // scale

PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
}
#else
int r =
xpu::fc_fusion<XPUTypeX, XPUTypeW, XPUTypeOut, T_GEMM>( // TX/TW/TY/TGEMM
ctx.x_context(), // ctx
@@ -101,6 +234,7 @@ void FcXPUKernelImpl(const Context& ctx,
scale_max_data); // scale

PADDLE_ENFORCE_XDNN_SUCCESS(r, "fc_xpu");
#endif
}

#define FC_XPU_KERNEL_IMPL(x_dtype_, w_dtype_, out_dtype_, gemm_dtype_) \
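Note: the XTE path is selected purely from the XPU_PADDLE_FC_BFLOAT16_XTE
environment variable at kernel execution time. A minimal sketch of opting in
from the host process (illustrative only, not part of this diff; POSIX setenv
is assumed):

#include <cstdlib>

int main() {
  // The fused FC kernel checks this variable with std::getenv, so it must be
  // set before the op runs.
  setenv("XPU_PADDLE_FC_BFLOAT16_XTE", "1", /*overwrite=*/1);
  // ... construct and execute the Paddle program containing fc_xpu ...
  return 0;
}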