Skip to content

Commit 633f6e8

Browse files
authored
[Bug] Fix DeepGemm Init Error (vllm-project#21554)
Signed-off-by: yewentao256 <[email protected]>
1 parent b57296b commit 633f6e8

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

vllm/model_executor/layers/quantization/utils/fp8_utils.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -366,7 +366,7 @@ def per_token_group_quant_fp8(
366366
dtype: Optional[torch.dtype] = None,
367367
column_major_scales: bool = False,
368368
out_q: Optional[torch.Tensor] = None,
369-
use_ue8m0: bool = is_blackwell_deep_gemm_used(),
369+
use_ue8m0: Optional[bool] = None,
370370
) -> tuple[torch.Tensor, torch.Tensor]:
371371
"""Function to perform per-token-group quantization on an input tensor `x`.
372372
It converts the tensor values into signed float8 values and returns the
@@ -383,6 +383,10 @@ def per_token_group_quant_fp8(
383383
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
384384
scaling factor.
385385
"""
386+
# TODO(wentao): refactor this
387+
# use_ue8m0 should be a global flag that could be set by user
388+
if use_ue8m0 is None:
389+
use_ue8m0 = is_blackwell_deep_gemm_used()
386390
dtype = current_platform.fp8_dtype() if dtype is None else dtype
387391
assert (x.shape[-1] % group_size == 0), (
388392
f"the last dimension of `x` {x.shape[-1]} must be divisible "

0 commit comments

Comments
 (0)