@@ -932,14 +932,14 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
   for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
     {
       unsigned int addr;
-      __asm__ __volatile__(
+      asm volatile(
        "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
        : "=r"(addr)
        : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
      );


-      __asm__ __volatile__(
+      asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
        "{%0, %1, %2, %3}, [%4];\n"
        : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
@@ -950,12 +950,12 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
     for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
       {
         unsigned int addr;
-        __asm__ __volatile__(
+        asm volatile(
          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
          : "=r"(addr)
          : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
        );
-        __asm__ __volatile__(
+        asm volatile(
          "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
          "{%0, %1, %2, %3}, [%4];\n"
          : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
@@ -966,47 +966,47 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
     for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
       {
-        __asm__ __volatile__(
+        asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
          : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
          : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
       }

       {
-        __asm__ __volatile__(
+        asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
          : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
          : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
       }

       {
-        __asm__ __volatile__(
+        asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
          : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
          : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
       }

       {
-        __asm__ __volatile__(
+        asm volatile(
          "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
          "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
          : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
          : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
       }
 #else
       {
-        __asm__ __volatile__(
+        asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
          "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
          : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
          : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
       }

       {
-        __asm__ __volatile__(
+        asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
          "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
          : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])