
Enable F8E4M3 conversions on Nvidia GPUs with sm < 89, and fix F8E5M2 conversions #7904

Open
wants to merge 1 commit into main

Conversation

@woct0rdho (Contributor) commented Aug 19, 2025

Motivation

Nvidia GPUs with sm < 89 are still widely used; see e.g. the Steam hardware survey. When running large AI models, a common practice is to store the parameters in fp8 and cast them to fp16 for computation on hardware without native fp8 support. This reduces the memory requirement, even though it brings no speed advantage. This PR aims to enable torch.compile for this usage.
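To make the pattern concrete, here is a minimal PyTorch sketch of the usage described above (illustration only, not code from this PR; the Fp8Linear module name is made up for the example):

    import torch

    class Fp8Linear(torch.nn.Module):
        """Weights stored in fp8 to save memory, upcast to fp16 right before use."""

        def __init__(self, weight_fp16: torch.Tensor):
            super().__init__()
            # Storage-only fp8: the cast back to fp16 below is the conversion
            # this PR wants Triton to be able to emit on sm < 89.
            self.register_buffer("weight_fp8", weight_fp16.to(torch.float8_e4m3fn))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            w = self.weight_fp8.to(torch.float16)  # dequantize on the fly
            return x @ w.t()

    # Under torch.compile, the generated Triton kernels need the fp8 -> fp16 cast.
    layer = torch.compile(Fp8Linear(torch.randn(256, 256, dtype=torch.float16)))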

We may refer to XLA's fallback mechanism for fp8 operations (see openxla/xla#23124), although I think we only need to support the conversions rather than all arithmetic operations.

Implementation

Before #2105, there was some PTX code for converting F8E4M3/F8E5M2 <-> F16/BF16, but it did not correctly handle denormalized values or round-to-nearest-even (RTNE). I've fixed these cases and added code for F32 -> F8E4M3/F8E5M2.

I've verified that for all 2^8 F8E4M3/F8E5M2 values, all 2^16 F16/BF16 values, and all 2^32 F32 values, the conversion results are bitwise identical to the PyTorch implementation, except for some corner cases involving inf and nan (see the comments). The tests in test_conversions.py pass.
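For reference, a check along these lines takes only a few lines of PyTorch (a sketch of the idea, not the PR's actual test code; kernel_out stands for the F16 results produced by the Triton conversion kernel):

    import torch

    def check_e4m3_to_f16(kernel_out: torch.Tensor) -> bool:
        # Enumerate all 2^8 F8E4M3 bit patterns, convert them with PyTorch,
        # and compare bit patterns rather than values so nan encodings count too.
        bits = torch.arange(256, dtype=torch.uint8)
        ref = bits.view(torch.float8_e4m3fn).to(torch.float16)
        return torch.equal(kernel_out.view(torch.int16), ref.view(torch.int16))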

I've checked that all unit tests pass on an RTX 3080 (sm86). There is no IR change for sm >= 90. For sm89, there is a minor change: previously F32 -> F8E4M3/F8E5M2 was implemented as F32 -> F16 -> F8E4M3/F8E5M2 without correct RTNE; now it is implemented directly with RTNE.

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@ThomasRaoux (Collaborator) left a comment:

The main reason we haven't been supporting e4m3 on those targets is that it is inefficient, so we don't want to give users the impression that it is natively supported or has an efficient conversion.
Conversion can be done at the kernel level for users who want to support this format, so I'm not sure there is a strong motivation to have the emulation in Triton.

PTX conversion code under review:

    "add.u32 a0, a0, 0x00800080; \n" // a0 += 0x00800080
    "add.u32 a1, a1, 0x00800080; \n" // (round to nearest)
    "prmt.b32 $0, a0, a1, 0x7531; \n\t" // output = a1a0
    ".reg .b32 a<2>, b<2>; \n"
A collaborator commented:
Changing the conversion of e5m2 will have a large impact on backward compatibility, precision, and performance. This is going to be a problem for OAI internally; I don't think we want to change it, at least not in this PR.

@woct0rdho (Contributor, Author) commented Aug 19, 2025

Thank you, and I understand your concern. Do you think there is a way to 'inject' these kernels when doing triton.jit or torch.compile, or to add a custom pass, without modifying Triton itself? (And preferably avoid graph breaks in torch.compile and retain graph-level optimizations?)

Also, I'd say non-standard rounding may cause surprising compatibility issues, such as 'why do I suddenly get noise when enabling torch.compile', although I can't yet show a concrete case.

@ThomasRaoux (Collaborator) replied:

It should be easy to write a pass on ttir or ttgir that transforms it into supported IR using elementwise_inline_asm.
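As a rough illustration of the kernel-level workaround mentioned above (not part of this PR, and using plain Triton bit manipulation rather than elementwise_inline_asm): the widening E5M2 -> F16 direction is exact because E5M2 is a truncated binary16, so it reduces to a byte shift, assuming the fp8 tensor is handed to the kernel as uint8 storage (e.g. via x.view(torch.uint8)):

    import triton
    import triton.language as tl

    @triton.jit
    def e5m2_to_f16_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        # Raw E5M2 bytes, loaded from a uint8 view of the fp8 tensor.
        b = tl.load(x_ptr + offs, mask=mask, other=0)
        # E5M2 is the top byte of an IEEE binary16, so widening is a byte shift.
        h = (b.to(tl.uint16) << 8).to(tl.float16, bitcast=True)
        tl.store(y_ptr + offs, h, mask=mask)

The narrowing direction (F16/F32 -> E5M2/E4M3) is where the RTNE and denormal handling discussed in this PR come in, and E4M3 needs exponent rebiasing even for widening, so a sketch like this only covers the easy case.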
