1 file changed: 5 additions, 1 deletion
vllm/model_executor/layers/quantization/utils

@@ -366,7 +366,7 @@ def per_token_group_quant_fp8(
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
     out_q: Optional[torch.Tensor] = None,
-    use_ue8m0: bool = is_blackwell_deep_gemm_used(),
+    use_ue8m0: Optional[bool] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -383,6 +383,10 @@ def per_token_group_quant_fp8(
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
             scaling factor.
     """
+    # TODO(wentao): refactor this
+    # use_ue8m0 should be a global flag that could be set by user
+    if use_ue8m0 is None:
+        use_ue8m0 = is_blackwell_deep_gemm_used()
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "