Skip to content
This repository was archived by the owner on May 11, 2025. It is now read-only.

Commit bad253e

Browse files
committed
Windows support
1 parent 8907d18 commit bad253e

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

awq_ext/quantization/gemm_cuda_gen.cu

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -932,14 +932,14 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
932932
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
933933
{
934934
unsigned int addr;
935-
__asm__ __volatile__(
935+
asm volatile(
936936
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
937937
: "=r"(addr)
938938
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
939939
);
940940

941941

942-
__asm__ __volatile__(
942+
asm volatile(
943943
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
944944
"{%0, %1, %2, %3}, [%4];\n"
945945
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
@@ -950,12 +950,12 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
950950
for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
951951
{
952952
unsigned int addr;
953-
__asm__ __volatile__(
953+
asm volatile(
954954
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
955955
: "=r"(addr)
956956
: "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
957957
);
958-
__asm__ __volatile__(
958+
asm volatile(
959959
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
960960
"{%0, %1, %2, %3}, [%4];\n"
961961
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
@@ -966,47 +966,47 @@ __global__ void __launch_bounds__(64) group_gemm_forward_4bit_cuda_m16nXk32(
966966
for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
967967
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
968968
{
969-
__asm__ __volatile__(
969+
asm volatile(
970970
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
971971
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
972972
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
973973
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
974974
}
975975

976976
{
977-
__asm__ __volatile__(
977+
asm volatile(
978978
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
979979
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
980980
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
981981
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
982982
}
983983

984984
{
985-
__asm__ __volatile__(
985+
asm volatile(
986986
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
987987
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
988988
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
989989
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
990990
}
991991

992992
{
993-
__asm__ __volatile__(
993+
asm volatile(
994994
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
995995
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
996996
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
997997
: "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
998998
}
999999
#else
10001000
{
1001-
__asm__ __volatile__(
1001+
asm volatile(
10021002
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
10031003
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
10041004
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
10051005
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
10061006
}
10071007

10081008
{
1009-
__asm__ __volatile__(
1009+
asm volatile(
10101010
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
10111011
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
10121012
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])

0 commit comments

Comments
 (0)