Skip to content

Commit 13348c5

Browse files
authored
[ARM CPU] hgemm optimized for gqa (#23107)
### Description

Add fp16 kernels for GQA matmul on ARM CPU. The kernels are MLAS HGEMM kernels computing C = alpha * A x B' + beta * C.

### Motivation and Context

Add fp16 support for GQA to speed up the operator and reduce memory usage.

__Token Generation__

| | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
|---------------------------------|--------------------|--------------------|--------------|
| M:1/N:4096/K:4096 | 251551 | 1775905 | 85.84 |
| M:1/N:11008/K:4096 | 892507 | 4649145 | 80.80 |
| M:1/N:4096/K:11008 | 866860 | 3240015 | 73.25 |
| M:1/N:11008/K:11008 | 2631615 | 8783877 | 70.04 |

__Prompting__

| | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
|---------------------------------|--------------------|--------------------|--------------|
| M:1024/N:4096/K:4096 | 90508701 | 111283029 | 18.67 |
| M:2048/N:4096/K:4096 | 181307522 | 240211107 | 24.52 |
| M:1024/N:11008/K:4096 | 241120234 | 307707933 | 21.64 |
| M:2048/N:11008/K:4096 | 481091232 | 648921367 | 25.86 |
| M:1024/N:4096/K:11008 | 241736343 | 310129880 | 22.05 |
| M:2048/N:4096/K:11008 | 480456703 | 644814999 | 25.49 |
| M:1024/N:11008/K:11008 | 642121440 | 847925766 | 24.27 |
| M:2048/N:11008/K:11008 | 1276097154 | 1731314509 | 26.29 |
1 parent c89a798 commit 13348c5

File tree

13 files changed

+2594
-34
lines changed

13 files changed

+2594
-34
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ function(setup_mlas_source_for_windows)
9595
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
9696
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
9797
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
98+
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
99+
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
98100
)
99101

100102
set(mlas_platform_preprocess_srcs
@@ -374,6 +376,7 @@ else()
374376
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
375377
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
376378
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
379+
${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp
377380
)
378381
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
379382
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@@ -394,6 +397,7 @@ else()
394397
${MLAS_SRC_DIR}/cast_kernel_neon.cpp
395398
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
396399
${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
400+
${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp
397401
)
398402
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
399403
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -406,6 +410,7 @@ else()
406410
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
407411
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
408412
set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
413+
set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
409414
endif()
410415

411416
if(ONNXRUNTIME_MLAS_MULTI_ARCH)

onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class GQAAttentionBase {
7575
int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);
7676

7777
// Compute the attention score.
78+
// TODO(fajin): type depends on kernel supportability
7879
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
7980
auto attention_probs = allocator->Alloc(bytes);
8081
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
@@ -198,6 +199,11 @@ class GQAAttentionBase {
198199
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
199200
static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/,
200201
output, static_cast<int>(present_buffer_sequence_length), nullptr);
202+
// TODO(fajin): update later
203+
// } else if (MlasHGemmSupported(CblasNoTrans, CblasTrans)) {
204+
// MlasGemm(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size,
205+
// q, static_cast<int>(head_size), k, static_cast<int>(head_size), output,
206+
// static_cast<int>(present_buffer_sequence_length), alpha, 0.0f /*beta*/, nullptr);
201207
} else {
202208
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
203209
auto q_k_fp32 = allocator->Alloc(bytes);

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1458,7 +1458,107 @@ MlasRotaryEmbedOneRow(
14581458
T* output
14591459
);
14601460

1461-
/**
1461+
/**
1462+
* @brief Supply matrices data information to half precision gemm functions
*
* One instance describes the operands of a single multiplication of the form
* C = alpha * op(A) * op(B) + beta * C. The alpha and beta scalars are stored
* as raw 16-bit integers holding the FP16 encoding, not as a floating type.
1463+
*/
1464+
struct MLAS_HGEMM_DATA_PARAMS {
1465+
const MLAS_FP16* A; /**< Supplies the address of matrix A */
1466+
size_t lda; /**< Supplies the first dimension of matrix A. */
1467+
const MLAS_FP16* B; /**< Supplies the address of matrix B */
1468+
size_t ldb; /**< Supplies the first dimension of matrix B. */
1469+
MLAS_FP16* C; /**< Supplies the address of matrix C */
1470+
size_t ldc; /**< Supplies the first dimension of matrix C. */
1471+
uint16_t alpha; /**< Supplies the scalar alpha multiplier (see GEMM definition). FP16 encoding. */
1472+
uint16_t beta; /**< Supplies the scalar beta multiplier (see GEMM definition). FP16 encoding. */
1473+
};
1474+
1475+
/**
1476+
* @brief Check whether current CPU supports half precision gemm.
*
* @param TransA Supplies the transpose operation for matrix A.
* @param TransB Supplies the transpose operation for matrix B.
* @return NOTE(review): presumably true when an HGEMM kernel is available on the
*         current CPU for this TransA/TransB combination; the implementation is
*         not visible from this header - confirm.
1477+
*/
1478+
bool
1479+
MLASCALL
1480+
MlasHGemmSupported(
1481+
CBLAS_TRANSPOSE TransA,
1482+
CBLAS_TRANSPOSE TransB
1483+
);
1484+
1485+
/**
1486+
* @brief Batched half precision matrix/matrix multiply operation (HGEMM)
*
* The alpha and beta multipliers are taken from each MLAS_HGEMM_DATA_PARAMS
* entry rather than being passed as separate arguments.
1487+
*
1488+
* @param TransA Supplies the transpose operation for matrix A.
1489+
* @param TransB Supplies the transpose operation for matrix B.
1490+
* @param M Supplies the number of rows of matrix A and matrix C.
1491+
* @param N Supplies the number of columns of matrix B and matrix C.
1492+
* @param K Supplies the number of columns of matrix A and the number of rows of matrix B.
1493+
* @param Data An array of matrix data parameters (presumably BatchSize entries,
*             one per multiplication - confirm against the implementation)
1494+
* @param BatchSize Supplies number of multiplications in this batch
1495+
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the
1496+
*                   base library threading support should be used.
1497+
*/
1498+
void
1499+
MLASCALL
1500+
MlasGemmBatch(
1501+
CBLAS_TRANSPOSE TransA,
1502+
CBLAS_TRANSPOSE TransB,
1503+
size_t M,
1504+
size_t N,
1505+
size_t K,
1506+
const MLAS_HGEMM_DATA_PARAMS* Data,
1507+
size_t BatchSize,
1508+
MLAS_THREADPOOL* ThreadPool
1509+
);
1510+
1511+
/**
1512+
* @brief half precision matrix/matrix multiply operation (HGEMM)
1513+
* C = alpha * op(A) * op(B) + beta * C
*
* Convenience wrapper for a single multiplication: the arguments are packed
* into one MLAS_HGEMM_DATA_PARAMS and forwarded to MlasGemmBatch with a batch
* size of 1.
1514+
*
1515+
* @param TransA Supplies the transpose operation for matrix A. Currently only supports CblasNoTrans.
1516+
* @param TransB Supplies the transpose operation for matrix B. Currently only supports CblasTrans.
1517+
* @param M Supplies the number of rows of matrix A and matrix C.
1518+
* @param N Supplies the number of columns of matrix B and matrix C.
1519+
* @param K Supplies the number of columns of matrix A and the number of rows of matrix B.
1520+
* @param A Supplies the address of matrix A
1521+
* @param lda Supplies the first dimension of matrix A.
1522+
* @param B Supplies the address of matrix B
1523+
* @param ldb Supplies the first dimension of matrix B.
1524+
* @param C Supplies the address of matrix C
1525+
* @param ldc Supplies the first dimension of matrix C.
1526+
* @param alpha Supplies the scalar alpha multiplier (see GEMM definition), as the FP16 encoding in a uint16_t
1527+
* @param beta Supplies the scalar beta multiplier (see GEMM definition), as the FP16 encoding in a uint16_t
1528+
* @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support
1529+
* should be used.
1530+
*/
1531+
inline
1532+
void
1533+
MlasGemm(
1534+
CBLAS_TRANSPOSE TransA,
1535+
CBLAS_TRANSPOSE TransB,
1536+
size_t M,
1537+
size_t N,
1538+
size_t K,
1539+
const MLAS_FP16* A,
1540+
size_t lda,
1541+
const MLAS_FP16* B,
1542+
size_t ldb,
1543+
MLAS_FP16* C,
1544+
size_t ldc,
1545+
uint16_t alpha,
1546+
uint16_t beta,
1547+
MLAS_THREADPOOL* ThreadPool
1548+
) {
// Gather the per-matrix arguments into the parameter struct the batch API expects.
// Every field of Data is assigned below before it is used.
1549+
MLAS_HGEMM_DATA_PARAMS Data;
1550+
Data.A = A;
1551+
Data.lda = lda;
1552+
Data.B = B;
1553+
Data.ldb = ldb;
1554+
Data.C = C;
1555+
Data.ldc = ldc;
1556+
Data.alpha = alpha;
1557+
Data.beta = beta;
// A single GEMM is issued as a batch of one.
1558+
MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool);
1559+
}
1560+
1561+
/**
14621562
* @brief Whether current CPU supports FP16 acceleration.
14631563
*/
14641564
bool MLASCALL

onnxruntime/core/mlas/lib/fp16_common.h

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,4 +349,103 @@ MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT
349349
return vbsl_f16(select, ones, zeros);
350350
}
351351

352+
// Transposes, in place, the 8x8 tile of fp16 values whose rows are v0..v7.
// On return v<i> holds what was column i of the input tile.
// The transpose is built from three rounds of lane interleaving at 16-bit,
// 32-bit and 64-bit granularity; the layout diagrams show the rows after each round.
MLAS_FORCEINLINE
353+
void
354+
Transpose8x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3,
355+
MLAS_FLOAT16X8& v4, MLAS_FLOAT16X8& v5, MLAS_FLOAT16X8& v6, MLAS_FLOAT16X8& v7)
356+
{
// Input layout (v<r><c> = row r, column c):
357+
// |v00|v01|v02|v03|v04|v05|v06|v07|
358+
// |v10|v11|v12|v13|v14|v15|v16|v17|
359+
// |v20|v21|v22|v23|v24|v25|v26|v27|
360+
// |v30|v31|v32|v33|v34|v35|v36|v37|
361+
// |v40|v41|v42|v43|v44|v45|v46|v47|
362+
// |v50|v51|v52|v53|v54|v55|v56|v57|
363+
// |v60|v61|v62|v63|v64|v65|v66|v67|
364+
// |v70|v71|v72|v73|v74|v75|v76|v77|
// Round 1: interleave 16-bit lanes of adjacent row pairs (0/1, 2/3, 4/5, 6/7).
365+
float16x8x2_t t01 = vtrnq_f16(v0, v1);
366+
float16x8x2_t t23 = vtrnq_f16(v2, v3);
367+
float16x8x2_t t45 = vtrnq_f16(v4, v5);
368+
float16x8x2_t t67 = vtrnq_f16(v6, v7);
369+
// |v00|v10|v02|v12|v04|v14|v06|v16|
370+
// |v01|v11|v03|v13|v05|v15|v07|v17|
371+
// |v20|v30|v22|v32|v24|v34|v26|v36|
372+
// |v21|v31|v23|v33|v25|v35|v27|v37|
373+
// |v40|v50|v42|v52|v44|v54|v46|v56|
374+
// |v41|v51|v43|v53|v45|v55|v47|v57|
375+
// |v60|v70|v62|v72|v64|v74|v66|v76|
376+
// |v61|v71|v63|v73|v65|v75|v67|v77|
// Round 2: reinterpret each fp16 pair as one 32-bit lane and interleave again.
377+
float32x4x2_t t02 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]));
378+
float32x4x2_t t13 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]));
379+
float32x4x2_t t46 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[0]), vreinterpretq_f32_f16(t67.val[0]));
380+
float32x4x2_t t57 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[1]), vreinterpretq_f32_f16(t67.val[1]));
381+
// |v00|v10|v20|v30|v04|v14|v24|v34|
382+
// |v01|v11|v21|v31|v05|v15|v25|v35|
383+
// |v02|v12|v22|v32|v06|v16|v26|v36|
384+
// |v03|v13|v23|v33|v07|v17|v27|v37|
385+
// |v40|v50|v60|v70|v44|v54|v64|v74|
386+
// |v41|v51|v61|v71|v45|v55|v65|v75|
387+
// |v42|v52|v62|v72|v46|v56|v66|v76|
388+
// |v43|v53|v63|v73|v47|v57|v67|v77|
// Round 3: combine 64-bit halves of the upper-row and lower-row groups and
// write the results back to v0..v7.
389+
v0 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0])));
390+
v4 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0])));
391+
v2 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1])));
392+
v6 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1])));
393+
v1 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0])));
394+
v5 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0])));
395+
v3 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1])));
396+
v7 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1])));
// Final layout: row i now holds the original column i.
397+
// |v00|v10|v20|v30|v40|v50|v60|v70|
398+
// |v01|v11|v21|v31|v41|v51|v61|v71|
399+
// |v02|v12|v22|v32|v42|v52|v62|v72|
400+
// |v03|v13|v23|v33|v43|v53|v63|v73|
401+
// |v04|v14|v24|v34|v44|v54|v64|v74|
402+
// |v05|v15|v25|v35|v45|v55|v65|v75|
403+
// |v06|v16|v26|v36|v46|v56|v66|v76|
404+
// |v07|v17|v27|v37|v47|v57|v67|v77|
405+
}
406+
407+
// Transposes, in place, a 4x8 tile of fp16 values held in v0..v3 as two
// independent 4x4 blocks: columns 0-3 are transposed into lanes 0-3 and
// columns 4-7 into lanes 4-7 of the outputs (see the layout diagram below).
MLAS_FORCEINLINE
408+
void
409+
Transpose4x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3)
410+
{
411+
// |v00|v01|v02|v03|v04|v05|v06|v07|
412+
// |v10|v11|v12|v13|v14|v15|v16|v17|
413+
// |v20|v21|v22|v23|v24|v25|v26|v27|
414+
// |v30|v31|v32|v33|v34|v35|v36|v37|
415+
// =>
416+
// |v00|v10|v20|v30|v04|v14|v24|v34|
417+
// |v01|v11|v21|v31|v05|v15|v25|v35|
418+
// |v02|v12|v22|v32|v06|v16|v26|v36|
419+
// |v03|v13|v23|v33|v07|v17|v27|v37|
// Round 1: interleave 16-bit lanes of rows 0/1 and rows 2/3.
420+
float16x8x2_t t01 = vtrnq_f16(v0, v1);
421+
float16x8x2_t t23 = vtrnq_f16(v2, v3);
422+
// Round 2: treat each fp16 pair as one 32-bit lane and interleave the two
// intermediate results; even 32-bit lanes give rows 0/1, odd lanes rows 2/3.
423+
v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])));
424+
v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])));
425+
v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])));
426+
v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])));
427+
}
428+
429+
// Transposes, in place, the 4x4 tile of fp16 values whose rows are v0..v3.
// On return v<i> holds what was column i of the input tile.
MLAS_FORCEINLINE
430+
void
431+
Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FLOAT16X4& v3)
432+
{
433+
// |v00|v01|v02|v03|
434+
// |v10|v11|v12|v13|
435+
// |v20|v21|v22|v23|
436+
// |v30|v31|v32|v33|
437+
// =>
438+
// |v00|v10|v20|v30|
439+
// |v01|v11|v21|v31|
440+
// |v02|v12|v22|v32|
441+
// |v03|v13|v23|v33|
// Round 1: interleave 16-bit lanes of rows 0/1 and rows 2/3.
442+
float16x4x2_t t01 = vtrn_f16(v0, v1);
443+
float16x4x2_t t23 = vtrn_f16(v2, v3);
444+
// Round 2: treat each fp16 pair as one 32-bit lane; the first 32-bit lanes
// form output rows 0/1 and the second 32-bit lanes form output rows 2/3.
445+
v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0])));
446+
v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1])));
447+
v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0])));
448+
v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1])));
449+
}
450+
352451
#endif // fp16 vector intrinsic supported

0 commit comments

Comments
 (0)