Skip to content

Commit f1b2aec

Browse files
btbtyler09 authored and claude committed
Fix GPTQ ROCm type conversion bug causing gibberish output
- Fix double type conversion bug in q_gemm.cu affecting all GPTQ models with tensor parallelism on ROCm
- Move half2 res2 declaration inside loop with proper zero initialization
- Remove problematic __half_as_ushort/__ushort_as_half conversions
- Fix false Triton flash attention warning for models with sliding window when VLLM_USE_TRITON_FLASH_ATTN=0
- Changes match upstream PR vllm-project#17583

This fixes silent data corruption that was causing GPTQ models to produce gibberish output on ROCm with tensor parallelism.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent fa6497e commit f1b2aec

File tree

2 files changed

+12
-24
lines changed

2 files changed

+12
-24
lines changed

csrc/quantization/gptq/q_gemm.cu

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,7 +1223,6 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
12231223
int k = 0;
12241224
int z_w = w / 8;
12251225
int z_mod = (w % 8) * 4;
1226-
half2 res2;
12271226
half res[BLOCK_M_SIZE_MAX] = {};
12281227

12291228
unsigned int tmp;
@@ -1248,12 +1247,7 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
12481247
zeros_tmp[tmp_k] = zero;
12491248
}
12501249
for (int m = 0; m < b_end; m++) {
1251-
#ifndef USE_ROCM
1252-
res2 = {};
1253-
#else
1254-
res2.x = __half_as_ushort(__float2half(0));
1255-
res2.y = __half_as_ushort(__float2half(0));
1256-
#endif
1250+
half2 res2{};
12571251
res2 = __hfma2(
12581252
__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]),
12591253
blockvec[m][k + 0], res2);
@@ -1266,12 +1260,7 @@ __global__ void gemm_half_q_half_alt_4bit_kernel(
12661260
res2 = __hfma2(
12671261
__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]),
12681262
blockvec[m][k + 3], res2);
1269-
#ifndef USE_ROCM
12701263
res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
1271-
#else
1272-
res[m] = __hadd(
1273-
res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
1274-
#endif
12751264
}
12761265
i += width;
12771266
k += 4;
@@ -1314,7 +1303,6 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
13141303
int k = 0;
13151304
int z_w = w / 4;
13161305
int z_mod = (w % 4) * 8;
1317-
half2 res2;
13181306
half res[BLOCK_M_SIZE_MAX] = {};
13191307

13201308
unsigned int tmp;
@@ -1339,12 +1327,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
13391327
zeros_tmp[tmp_k] = zero;
13401328
}
13411329
for (int m = 0; m < b_end; m++) {
1342-
#ifndef USE_ROCM
1343-
res2 = {};
1344-
#else
1345-
res2.x = __half_as_ushort(__float2half(0));
1346-
res2.y = __half_as_ushort(__float2half(0));
1347-
#endif
1330+
half2 res2{};
13481331
half2 v12 = __halves2half2(__int2half_rn(tmp & 0xFF),
13491332
__int2half_rn((tmp >> 8) & 0xFF));
13501333
res2 = __hfma2(__hfma2(v12, scales_tmp[0], zeros_tmp[0]),
@@ -1353,12 +1336,7 @@ __global__ void gemm_half_q_half_alt_8bit_kernel(
13531336
__int2half_rn((tmp >> 24) & 0xFF));
13541337
res2 = __hfma2(__hfma2(v34, scales_tmp[1], zeros_tmp[1]),
13551338
blockvec[m][k + 1], res2);
1356-
#ifndef USE_ROCM
13571339
res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
1358-
#else
1359-
res[m] = __hadd(
1360-
res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
1361-
#endif
13621340
}
13631341
i += width;
13641342
k += 2;

vllm/platforms/rocm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ def on_mi3xx() -> bool:
111111
return any(arch in GPU_ARCH for arch in ["gfx942", "gfx950"])
112112

113113

114+
@cache
115+
def on_mi100() -> bool:
116+
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
117+
return any(arch in GPU_ARCH for arch in ["gfx900", "gfx902", "gfx906", "gfx908"])
118+
119+
114120
@cache
115121
def on_gfx9() -> bool:
116122
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
@@ -328,6 +334,10 @@ def verify_model_arch(cls, model_arch: str) -> None:
328334

329335
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
330336
msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
337+
# Only show Triton-related warnings if Triton is actually being used
338+
if "Triton flash attention" in msg and not envs.VLLM_USE_TRITON_FLASH_ATTN:
339+
# Skip warning since Triton is not being used
340+
return
331341
logger.warning(
332342
"Model architecture '%s' is partially "
333343
"supported by ROCm: %s", model_arch, msg)

0 commit comments

Comments
 (0)