[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm #7210
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
# If rocm, use float8_e4m3fnuz
if isinstance(current_platform, rocm.RocmPlatform):
    weight_as_int8 = layer.weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == -128] = 0
We do not expect NaN (`127` and `255`) in weights for inference, so we could raise an exception as a fail-safe. Resetting `-0` to `0` looks good.
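A minimal sketch of such a fail-safe check (illustrative only, not code from this PR; the helper name is hypothetical):

```python
import torch


def assert_no_fp8_nan(weight: torch.Tensor) -> None:
    """Raise if an e4m3fn tensor contains NaN-encoded values."""
    # 0x7F (127) and 0xFF (255, i.e. -1 when viewed as int8) are the NaN
    # encodings of float8_e4m3fn; finding them in a checkpoint would mean
    # corrupted weights, so fail loudly instead of silently converting.
    weight_as_int8 = weight.view(torch.int8)
    if bool(((weight_as_int8 == 127) | (weight_as_int8 == -1)).any()):
        raise ValueError("fp8 checkpoint contains NaN-encoded weight values")
```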
0b10000000 is -0 in OCP, but it is NaN in NANOO, so it needs this special handling.
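For illustration, the same bit pattern decodes differently under the two formats (a quick check, assuming a PyTorch build that exposes both fp8 dtypes):

```python
import torch

bits = torch.tensor([-128], dtype=torch.int8)     # 0b10000000
print(bits.view(torch.float8_e4m3fn).float())     # tensor([-0.])  OCP: negative zero
print(bits.view(torch.float8_e4m3fnuz).float())   # tensor([nan])  NANOO: NaN
```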
Could you pull out `-128` as a named constant then? `ROCM_FP8_NAN = -128` or whatever is accurate here.
/ready
Please describe the restrictions of this approach (with `scaled_mm`) in the PR message body, such as: no output scaling is possible when needed, etc.
csrc/quantization/fp8/common.cu (Outdated)

@@ -21,20 +28,26 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }

-#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
+#define FP8_E4M3_MAX std::numeric_limits<FP8_TYPE>::max()
N.B. AMD uses up to 224, not 240.
I suppose this should be handled by torch.
+1 it would be best to leave this up to torch if possible
It is a bit too strong to ask torch to make the change, as our current HW does support up to 240; nothing blocks it from being used. Here in the inference stack, to maintain interop with OCP by manipulating the scaling factors, we introduced scale-by-2; to ensure that, I prefer we stick to 224 (which maps to the OCP max of 448).
I am not asking us to change the real value of `numeric_limits<c10::Float8_e4m3fnuz>::max()`, just to have a redef here.
Okay I see the justification now, that is fine with me. Please just leave a comment explaining the special definition
> N.B. AMD uses up to 224, not 240

Correct me if I am wrong: as far as I know, this fp8 format is 0bSEEEEMMM, with no support for +/- infinity and an exponent bias of 2^(4-1) = 8.
Note that NVIDIA uses 0bS_1111_111 to represent NaN, so they use 0bS_1111_110 to represent their fp8 max.
Hence, our fp8_max (max normal) is 0b0_1111_111 with bias 8, computed as:

fp8_max = (-1)^S * 2^{e - bias} * 1.f = 2^{15 - 8} * (1 + 0.875) = 240

With 224, I guess you are computing with 0b0_1111_110, hence 224 = 2^{15 - 8} * (1 + 0.75). That would mean 0b0_1111_111 is used to represent special numbers. However, in MLIR/ml_dtypes this fp8 format's NaN is 0b10000000.

See https://github.com/jax-ml/ml_dtypes#float8_e4m3fnuz and https://github.com/pytorch/pytorch/blob/main/c10/util/Float8_e4m3fnuz.h
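The limits in question can also be checked directly (assuming a PyTorch build that exposes both fp8 dtypes):

```python
import torch

print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (OCP e4m3, bias 7)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (NANOO e4m3, bias 8)
```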
csrc/quantization/fp8/common.cu (Outdated)

@@ -74,8 +87,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale,
-                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+    atomicMaxFloat(scale, cache[0] / std::numeric_limits<FP8_TYPE>::max());
In case AMD: note 224 vs 240.
vllm/_custom_ops.py (Outdated)

# For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz if torch.version.hip is not None \
    else torch.float8_e4m3fn
Could you please use `is_hip()` for this?

from vllm.utils import is_hip
out_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
# If rocm, adjust the scaling factor
if isinstance(current_platform, rocm.RocmPlatform):
    layer.weight_scale = Parameter(layer.weight_scale * 2,
                                   requires_grad=False)
    if layer.input_scale is not None:
        layer.input_scale = Parameter(layer.input_scale * 2,
                                      requires_grad=False)
If you could leave a comment describing a reference for the `* 2`, it would be appreciated.
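For reference, the factor of 2 comes from the bias change: reinterpreting e4m3fn bits as e4m3fnuz halves every finite, nonzero value (exponent bias 7 -> 8), so the scale must double for `q * scale` to stay unchanged. A quick check of that reasoning, assuming a PyTorch build with both fp8 dtypes (not code from this PR):

```python
import torch

x = torch.tensor([1.5, -3.0, 104.0])
scale = torch.tensor(0.5)

q_fn = (x / scale).to(torch.float8_e4m3fn)   # OCP-quantized values
q_fnuz = q_fn.view(torch.float8_e4m3fnuz)    # same bits, half the value

# Dequantizing with a doubled scale recovers the same numbers.
assert torch.equal(q_fn.float() * scale, q_fnuz.float() * (2 * scale))
```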
    scale_b=weight_scale,
    bias=bias)
# Since in torch 2.5, scaled_mm only returns single value
# This should be removed when vllm-nivida also moves to 2.5
Suggested change:
-# This should be removed when vllm-nivida also moves to 2.5
+# This should be removed when vllm-nvidia also moves to 2.5
FYI you may want to look at #7233 to review the usage of hipcub for reduction.
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
if is_hip():
Can we add a utility function in `w8a8_utils.py` called `convert_to_e4m3fnuz` for this?
@@ -182,6 +182,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
    # If checkpoint is fp8, handle that there are N scales for N
    # shards in a fused module
    else:
        # If rocm, use float8_e4m3fnuz.
        if is_hip():
Same here ---> let's have a utility function for this (so that we can use it in compressed tensors too).
@HaiShaw - are there any limitations for `ops.scaled_fp8_quant` on hip? Can we cover all cases (see the sketch of these modes below)?
- static
- dynamic per tensor
- dynamic per token
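For clarity, this is roughly what the three modes mean, sketched in plain PyTorch rather than the actual kernel API (the 224 cap mirrors the OCP-interop discussion above; names are illustrative):

```python
import torch

FP8_MAX = 224.0  # usable e4m3fnuz max chosen for OCP interop (hardware max is 240)


def quant_static(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static: the scale is known ahead of time (e.g. calibrated offline).
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fnuz)


def quant_dynamic_per_tensor(x: torch.Tensor):
    scale = x.abs().max() / FP8_MAX                        # one scale for the whole tensor
    return quant_static(x, scale), scale


def quant_dynamic_per_token(x: torch.Tensor):
    scale = x.abs().amax(dim=-1, keepdim=True) / FP8_MAX   # one scale per token/row
    return quant_static(x, scale), scale
```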
# If rocm, use float8_e4m3fnuz.
if is_hip():
    weight, weight_scale, input_scale = convert_to_e4m3fnuz(
Please rename to `normalize_e4m3fn_to_e4m3fnuz` or `convert_e4m3fn_to_e4m3fnuz`
* asserting input is e4m3fn
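A rough sketch of what such a helper could look like (the exact signature and the `ROCM_FP8_NAN_AS_BYTE` name are assumptions, not necessarily what the PR landed):

```python
from typing import Optional, Tuple

import torch

ROCM_FP8_NAN_AS_BYTE = -128  # int8 view of 0b10000000, NaN in float8_e4m3fnuz


def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # The e4m3fn -0 pattern (0b10000000) decodes as NaN in e4m3fnuz,
    # so reset it to 0 before reinterpreting the bits.
    weight_as_int8 = weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_BYTE] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
    # Reinterpreting the bits halves every value (bias 7 -> 8), so double
    # the scales to keep the dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
```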
Get numbers on FP16 models for comparison?
@HaiShaw those numbers match up with the base fp16 and fp8 (on H100) evals I shared earlier for those models, so I think this is sufficient! For reference:
Looks good to me! It would be nice if you could add an FP8 model loading test to the AMD CI, so we are testing beyond just the kernel support. You are welcome to do this as a followup PR
Thank you. We will add the test in a subsequent PR.
Hi @HaiShaw, do you know how ROCm 6.2 supports fp8 kernels in vLLM? Are there any examples or work-in-progress efforts on writing fp8 ops (GEMM) with the ROCm 6.2 SDK?
@yiakwy-xpu-ml-framework-team ROCm 6.2 supports fp8 natively via hipBLASLt, Triton, and CK (not yet brought into vLLM use).
Glad to see this. I was just curious about fp8_max in fp8_e4m3fnuz, and I am looking for more about fp8 features.
This PR adds fp8 linear layer support on Rocm.
- Use `torch.float8_e4m3fnuz` as the fp8 data type for rocm, and update the fp8 conversion kernels accordingly.
- Convert weights from `torch.float8_e4m3fn` to `torch.float8_e4m3fnuz` after weight loading, and adjust the scaling factor as well.
- Since in torch 2.5 `_scaled_mm` only returns a single value, do a condition check when returning the result from `_scaled_mm` (sketched below).
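For the last point, the condition check might look roughly like this (a sketch against the private `torch._scaled_mm` API, whose return type changed around torch 2.5; the wrapper name is hypothetical):

```python
import torch


def fp8_scaled_mm(qinput, weight, x_scale, weight_scale, out_dtype, bias=None):
    output = torch._scaled_mm(qinput,
                              weight,
                              out_dtype=out_dtype,
                              scale_a=x_scale,
                              scale_b=weight_scale,
                              bias=bias)
    # Older torch returns (output, amax); torch >= 2.5 returns only the output.
    if isinstance(output, tuple) and len(output) == 2:
        output = output[0]
    return output
```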
Evaluation results:
Meta-Llama-3-8B-Instruct-FP8
Qwen2-7B-Instruct-FP8
Mixtral-8x7B-Instruct-v0.1-FP8