Enable F8E4M3 conversions on Nvidia GPUs with sm < 89, and fix F8E5M2 conversions #7904
Motivation
Nvidia GPUs with sm < 89 are still widely used (see e.g. the Steam hardware survey). When running large AI models, a common pattern is to store the parameters in fp8 and cast them to fp16 for computation on hardware that has no native fp8 support. This reduces the memory requirement, even though it brings no speed advantage. This PR aims to enable `torch.compile` for this usage. We may refer to XLA's fallback mechanism for fp8 operations (see openxla/xla#23124), although I think we only need to support the conversions rather than all arithmetic operations.
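As an illustration of the kind of workload meant here, below is a minimal sketch (the module name `Fp8Linear` and the shapes are made up for this example, not part of the PR): weights are stored in F8E4M3 and cast to fp16 at compute time inside a `torch.compile`-d model.

```python
import torch

class Fp8Linear(torch.nn.Module):
    """Hypothetical layer: parameters stored in fp8, compute done in fp16."""
    def __init__(self, in_features, out_features):
        super().__init__()
        w = torch.randn(out_features, in_features, dtype=torch.float16)
        # Store the weight in F8E4M3 to halve its memory footprint.
        self.weight = torch.nn.Parameter(w.to(torch.float8_e4m3fn), requires_grad=False)

    def forward(self, x):
        # This cast is the F8E4M3 -> F16 conversion that must compile on sm < 89.
        return x @ self.weight.to(torch.float16).t()

model = torch.compile(Fp8Linear(4096, 4096).cuda())
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
y = model(x)
```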
Implementation
Before #2105, there was some PTX code for converting F8E4M3/F8E5M2 <-> F16/BF16, but it did not correctly handle denormalized values and round-to-nearest-even (RTNE). I've fixed these cases and added the code for F32 -> F8E4M3/F8E5M2.
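To make the RTNE behavior concrete, here is a small illustration (assuming, as the exhaustive comparison below relies on, that PyTorch's fp8 conversion rounds ties to even): near 1.0, F8E4M3 has a 3-bit mantissa, so the representable values step by 0.125, and an exact tie rounds to the neighbor with an even mantissa.

```python
import torch

# Near 1.0 the representable F8E4M3 values are 1.0, 1.125, 1.25, ... (3-bit mantissa).
# 1.0625 is exactly halfway between 1.0 and 1.125; RTNE picks the even mantissa (1.0).
# 1.1875 is exactly halfway between 1.125 and 1.25; RTNE picks the even mantissa (1.25).
x = torch.tensor([1.0625, 1.1875], dtype=torch.float32)
print(x.to(torch.float8_e4m3fn).to(torch.float32))  # expected: tensor([1.00, 1.25])
```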
I've verified that for all 2^8 F8E4M3/F8E5M2 values, all 2^16 F16/BF16 values, and all 2^32 F32 values, the conversion results are bitwise identical to the PyTorch implementation, except for some glitches around inf and nan (see the comments). The tests in `test_conversions.py` pass, and all unit tests pass on an RTX 3080 (sm86). There is no IR change for sm >= 90. For sm89, there is a minor change: previously F32 -> F8E4M3/F8E5M2 was implemented as F32 -> F16 -> F8E4M3/F8E5M2 without correct RTNE; now it is implemented directly with RTNE.
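For reference, a minimal sketch of such an exhaustive check for the F8E4M3 -> F16 direction could look like the following (the kernel and variable names here are made up for illustration; the real coverage lives in the unit tests and `test_conversions.py`). It enumerates all 256 bit patterns, casts them inside a Triton kernel, and compares bitwise against PyTorch's conversion.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def cast_f8e4m3_to_f16(src_ptr, dst_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    bits = tl.load(src_ptr + offs)                # raw uint8 bit patterns
    x = bits.to(tl.float8e4nv, bitcast=True)      # reinterpret as F8E4M3
    tl.store(dst_ptr + offs, x.to(tl.float16))    # conversion under test

bits = torch.arange(256, dtype=torch.uint8, device="cuda")
out = torch.empty(256, dtype=torch.float16, device="cuda")
cast_f8e4m3_to_f16[(1,)](bits, out, N=256)

# Reference: PyTorch's own F8E4M3 -> F16 conversion on the same bit patterns.
ref = bits.cpu().view(torch.float8_e4m3fn).to(torch.float16)

# Bitwise comparison (NaN encodings may legitimately differ, see the PR comments).
print(torch.equal(out.cpu().view(torch.int16), ref.view(torch.int16)))
```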
New contributor declaration
- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - I have added tests: `/test` for `lit` tests, `/unittest` for C++ tests, `/python/test` for end-to-end tests.
  - This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - I have not added any `lit` tests.
  - The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)