[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm #7210

Merged: 50 commits into vllm-project:main on Aug 16, 2024

Conversation

charlifu (Contributor) commented Aug 6, 2024:

This PR adds fp8 linear layer support on ROCm.

  • Use torch.float8_e4m3fnuz as the fp8 data type for ROCm, and update the fp8 conversion kernels accordingly.
  • Adjust the weight from torch.float8_e4m3fn to torch.float8_e4m3fnuz after weight loading, and adjust the scaling factor as well.
  • Since ROCm uses torch 2.5, in which _scaled_mm returns a single value, add a condition check when returning the result from _scaled_mm (a sketch of this check appears below).
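
A minimal sketch of that check, assuming recent torch APIs (the wrapper function and argument names here are illustrative, not the exact code in this PR):

```python
import torch
from typing import Optional

def fp8_scaled_mm(qinput: torch.Tensor,
                  weight: torch.Tensor,
                  x_scale: torch.Tensor,
                  weight_scale: torch.Tensor,
                  out_dtype: torch.dtype,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # torch < 2.5 returns an (output, amax) tuple from _scaled_mm, while
    # torch 2.5 (used by the ROCm build) returns only the output tensor.
    output = torch._scaled_mm(qinput,
                              weight,
                              out_dtype=out_dtype,
                              scale_a=x_scale,
                              scale_b=weight_scale,
                              bias=bias)
    if isinstance(output, tuple):
        output = output[0]
    return output
```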

Evaluation results:
Meta-Llama-3-8B-Instruct-FP8

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| Open LLM Leaderboard | N/A | | | | | |
| - arc_challenge | 1 | none | 25 | acc | 0.5776 | ± 0.0144 |
| | | none | 25 | acc_norm | 0.6203 | ± 0.0142 |
| - gsm8k | 3 | flexible-extract | 5 | exact_match | 0.7597 | ± 0.0118 |
| | | strict-match | 5 | exact_match | 0.7627 | ± 0.0117 |
| - hellaswag | 1 | none | 10 | acc | 0.5879 | ± 0.0049 |
| | | none | 10 | acc_norm | 0.7843 | ± 0.0041 |
| - mmlu | 2 | none | | acc | 0.6649 | ± 0.0038 |
| - truthfulqa_mc2 | 2 | none | 0 | acc | 0.5257 | ± 0.0153 |
| - winogrande | 1 | none | 5 | acc | 0.7601 | ± 0.0120 |

Qwen2-7B-Instruct-FP8

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| Open LLM Leaderboard | N/A | | | | | |
| - arc_challenge | 1 | none | 25 | acc | 0.5836 | ± 0.0144 |
| | | none | 25 | acc_norm | 0.6212 | ± 0.0142 |
| - gsm8k | 3 | flexible-extract | 5 | exact_match | 0.7779 | ± 0.0114 |
| | | strict-match | 5 | exact_match | 0.7066 | ± 0.0125 |
| - hellaswag | 1 | none | 10 | acc | 0.6097 | ± 0.0049 |
| | | none | 10 | acc_norm | 0.8108 | ± 0.0039 |
| - mmlu | 2 | none | | acc | 0.7008 | ± 0.0037 |
| - truthfulqa_mc2 | 2 | none | 0 | acc | 0.5695 | ± 0.0154 |
| - winogrande | 1 | none | 5 | acc | 0.7419 | ± 0.0123 |

Mixtral-8x7B-Instruct-v0.1-FP8

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| Open LLM Leaderboard | N/A | | | | | |
| - arc_challenge | 1 | none | 25 | acc | 0.6664 | ± 0.0138 |
| | | none | 25 | acc_norm | 0.6928 | ± 0.0135 |
| - gsm8k | 3 | flexible-extract | 5 | exact_match | 0.6270 | ± 0.0133 |
| | | strict-match | 5 | exact_match | 0.6224 | ± 0.0134 |
| - hellaswag | 1 | none | 10 | acc | 0.6800 | ± 0.0047 |
| | | none | 10 | acc_norm | 0.8718 | ± 0.0033 |
| - mmlu | 2 | none | | acc | 0.6981 | ± 0.0036 |
| - truthfulqa_mc2 | 2 | none | 0 | acc | 0.6434 | ± 0.0150 |
| - winogrande | 1 | none | 5 | acc | 0.8256 | ± 0.0107 |

github-actions bot commented Aug 6, 2024:

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which consists of a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of the default ones by unblocking the steps in your fast-check build in the Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

```python
# If rocm, use float8_e4m3fnuz
if isinstance(current_platform, rocm.RocmPlatform):
    weight_as_int8 = layer.weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == -128] = 0
```
HaiShaw (Contributor) commented Aug 6, 2024:

We do not expect NaN (bit patterns 127 and 255) in weights for inference, so we could flag an exception as a fail-safe. Resetting -0 to 0 looks good.

Collaborator:

0b10000000 is -0 in OCP, but a NaN in NANOO (fnuz), so it needs this special handling.

Member:

Could you pull out -128 as a named constant then? ROCM_FP8_NAN = -128 or whatever is accurate here

charlifu (Contributor, Author) commented Aug 6, 2024:

/ready

@github-actions bot added the ready label on Aug 6, 2024
HaiShaw (Contributor) left a review comment:

Please document the restrictions of this approach (with scaled_mm) in the PR message body, such as: no output scaling is possible when needed, etc.

```diff
@@ -21,20 +28,26 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
-#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
+#define FP8_E4M3_MAX std::numeric_limits<FP8_TYPE>::max()
```
Contributor:

N.B. AMD uses up to 224 not 240

Contributor (Author):

I suppose this should be handled by torch.

Member:

+1 it would be best to leave this up to torch if possible

Contributor:

It is a bit too strong to ask torch to make the change, since our current HW does support up to 240 and nothing blocks it from being used. Here in the inference stack, to maintain interop with OCP via manipulating scaling factors, we introduced scale-by-2; to assure that, I prefer we stick to 224 (which maps to the OCP max of 448).

We are not asking to change the real value of numeric_limits<c10::Float8_e4m3fnuz>::max(), just to have a redefinition here.

Member:

Okay I see the justification now, that is fine with me. Please just leave a comment explaining the special definition

> N.B. AMD uses up to 224 not 240

Correct me if I am wrong: as far as I know, this fp8 format is 0bSEEEEMMM, with no support for ±infinity and an exponent bias of 2^(4-1) = 8.

Note that NVIDIA uses 0b1111.111 to represent NaN, so they use 0b1111.110 to represent the fp8 max.

Hence, our fp8_max (max normal) = 0b01111111 with bias 8, computed as:

fp8_max = (-1)^S * 2^(e - bias) * 1.f = 2^(15 - 8) * (1 + 0.875) = 240

With 224, I guess you are computing with 0b01111110, hence 224 = 2^(15 - 8) * (1 + 0.75).

That would mean 0b01111111 is used to represent special numbers. However, in MLIR/ml_dtypes, the fp8 e4m3fnuz NaN is 0b10000000.

See https://github.com/jax-ml/ml_dtypes#float8_e4m3fnuz and https://github.com/pytorch/pytorch/blob/main/c10/util/Float8_e4m3fnuz.h
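
As a quick sanity check of the arithmetic above (illustrative only; the torch line assumes a build that exposes float8_e4m3fnuz):

```python
import torch

# Max normal with bias 8: 0b0_1111_111 -> 2^(15 - 8) * (1 + 0.875) = 240
# One step below:         0b0_1111_110 -> 2^(15 - 8) * (1 + 0.75)  = 224
print(2 ** (15 - 8) * (1 + 0.875))             # 240.0
print(2 ** (15 - 8) * (1 + 0.75))              # 224.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0, per torch's finfo
```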

```diff
@@ -74,8 +87,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale,
-                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+    atomicMaxFloat(scale, cache[0] / std::numeric_limits<FP8_TYPE>::max());
```
Contributor:

In case AMD: note 224 vs 240


Comment on lines 347 to 349

```python
# For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
out_dtype: torch.dtype = torch.float8_e4m3fnuz if torch.version.hip is not None \
    else torch.float8_e4m3fn
```
Member:

Could you please use is_hip() for this?

```python
from vllm.utils import is_hip

out_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
```


Comment on lines 218 to 224
```python
# If rocm, adjust the scaling factor
if isinstance(current_platform, rocm.RocmPlatform):
    layer.weight_scale = Parameter(layer.weight_scale * 2,
                                   requires_grad=False)
    if layer.input_scale is not None:
        layer.input_scale = Parameter(layer.input_scale * 2,
                                      requires_grad=False)
```
Member:

If you could leave a comment giving a reference for the * 2, it would be appreciated.
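
For context, a small standalone check (not code from this PR) illustrating why doubling the scale preserves the dequantized values, assuming a torch build with the float8 dtypes:

```python
import torch

# The same 8-bit pattern decodes to twice the value under e4m3fn (bias 7)
# as under e4m3fnuz (bias 8), so weight_fnuz * (2 * scale) reproduces
# weight_fn * scale after the dtype reinterpretation.
bits = torch.tensor([0x40], dtype=torch.uint8)      # arbitrary bit pattern
v_fn = bits.view(torch.float8_e4m3fn).float()       # 2.0
v_fnuz = bits.view(torch.float8_e4m3fnuz).float()   # 1.0
scale = 0.5
assert torch.allclose(v_fn * scale, v_fnuz * (scale * 2))
```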

```python
    scale_b=weight_scale,
    bias=bias)
# Since in torch 2.5, scaled_mm only returns single value
# This should be removed when vllm-nivida also moves to 2.5
```
Member:

Suggested change:

```diff
-# This should be removed when vllm-nivida also moves to 2.5
+# This should be removed when vllm-nvidia also moves to 2.5
```

mgoin (Member) commented Aug 7, 2024:

FYI you may want to look at #7233 to review the usage of hipcub for reduction

```python
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
if is_hip():
```
Collaborator:

Can we add a utility function in w8a8_utils.py called convert_to_e4m3fnuz for this?

```diff
@@ -182,6 +182,17 @@ def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint is fp8, handle that there are N scales for N
         # shards in a fused module
         else:
+            # If rocm, use float8_e4m3fnuz.
+            if is_hip():
```
robertgshaw2-redhat (Collaborator) commented Aug 7, 2024:

Same here: let's have a utility function for this (so that we can use it in compressed tensors too).

robertgshaw2-redhat (Collaborator) left a review comment:

@HaiShaw - are there any limitations for ops.scaled_fp8_quant on hip? Can we cover all cases?

  • static
  • dynamic per tensor
  • dynamic per token


```python
# If rocm, use float8_e4m3fnuz.
if is_hip():
    weight, weight_scale, input_scale = convert_to_e4m3fnuz(
```
Member:

please rename to normalize_e4m3fn_to_e4m3fnuz or convert_e4m3fn_to_e4m3fnuz
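
A minimal sketch of what the renamed helper could look like (the name follows the review suggestion above; the body is an illustration and may not match the merged implementation exactly):

```python
from typing import Optional, Tuple

import torch

def normalize_e4m3fn_to_e4m3fnuz(
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    # Reinterpret the e4m3fn bits as int8; 0b10000000 (-0 in OCP e4m3fn)
    # is NaN in e4m3fnuz, so reset it to zero before switching the view.
    ROCM_FP8_NAN_AS_INT = -128
    weight_as_int8 = weight.view(torch.int8)
    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
    # The same bits represent half the value under e4m3fnuz, so double
    # the scales to keep the dequantized weights/activations unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale
```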

HaiShaw (Contributor) commented Aug 15, 2024:

> [The PR description and evaluation results quoted above.]

Get numbers on FP16 models for comparison?

@charlifu requested a review from mgoin on August 15, 2024.
mgoin (Member) commented Aug 15, 2024:

@HaiShaw those numbers match up with the base fp16 and fp8 (on H100) evals I shared earlier for those models, so I think this is sufficient!

For reference:
neuralmagic/Meta-Llama-3-8B-Instruct-FP8
neuralmagic/Qwen2-7B-Instruct-FP8
neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8

HaiShaw (Contributor) commented Aug 15, 2024:

> @HaiShaw those numbers match up with the base fp16 and fp8 (on H100) evals I shared earlier for those models, so I think this is sufficient!
>
> For reference: neuralmagic/Meta-Llama-3-8B-Instruct-FP8, neuralmagic/Qwen2-7B-Instruct-FP8, neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8

@mgoin, thanks for your confirmation!
LGTM! @charlifu

mgoin (Member) left a review comment:

Looks good to me! It would be nice if you could add an FP8 model loading test to the AMD CI, so we are testing beyond just the kernel support. You are welcome to do this as a follow-up PR.

charlifu (Contributor, Author) commented:

> Looks good to me! It would be nice if you could add an FP8 model loading test to the AMD CI, so we are testing beyond just the kernel support

Thank you. We will add the test in a subsequent PR.

@simon-mo merged commit e837b62 into vllm-project:main on Aug 16, 2024 (66 of 70 checks passed).
@charlifu deleted the amd_fp8 branch on August 16, 2024.
@yiakwy-xpu-ml-framework-team

@mgoin @robertgshaw2-neuralmagic additionally, we plan to support https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 soon once FBGEMM-FP8 (dynamic per-token activations and per-channel weights) support is ready (we already support https://huggingface.co/amd/Meta-Llama-3.1-405B-Instruct-fp8-quark-vllm, though not via this PR), FYI.

Hi @HaiShaw, do you know how ROCm 6.2 supports fp8 kernels in vLLM? Are there any examples or work in progress on writing fp8 ops (GEMM) with the ROCm 6.2 SDK?

HaiShaw (Contributor) commented Sep 11, 2024:

@yiakwy-xpu-ml-framework-team ROCm 6.2 supports fp8 natively via hipBLASLt, Triton, and CK (not yet brought into vLLM use).
The current MI300 FP8 format is somewhat different from the OCP format; we introduced a max ceiling and a scaling-factor adjustment to make sure it receives OCP-standard FP8 data from external interfaces (checkpoints, etc.) and computes on native AMD HW.
In terms of work or tasks, we welcome all kinds of discussion and collaboration 😄

@yiakwy-xpu-ml-framework-team

> ROCm 6.2 supports fp8 natively via hipBLASLt, Triton, and CK (not yet brought into vLLM use). The current MI300 FP8 format is somewhat different from the OCP format; we introduced a max ceiling and a scaling-factor adjustment to make sure it receives OCP-standard FP8 data from external interfaces (checkpoints, etc.) and computes on native AMD HW. In terms of work or tasks, we welcome all kinds of discussion and collaboration 😄

Glad to see this. I was just curious about fp8_max in fp8_e4m3fnuz:

  • 240 for 0b01111111: (-1)^S * 2^(e - bias) * 1.f = 2^(15 - 8) * (1 + 0.875) = 240
  • 224 for 0b01111110: 2^(15 - 8) * (1 + 0.75) = 224

And I am looking for more information about fp8 features.

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)