Enable F8E4M3 conversions on Nvidia GPUs with sm < 89, and fix F8E5M2 conversions #7904

Open: wants to merge 1 commit into main.
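For context, the conversions in question are plain element-wise casts between FP8 and wider float types. The sketch below shows what such a cast looks like from user code; it assumes a CUDA device and a PyTorch build that exposes torch.float8_e4m3fn, and the kernel and variable names are illustrative rather than taken from this PR.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def upcast_f8e4m3_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # Each program upcasts one BLOCK-sized slice of the input.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)                  # elements load as tl.float8e4nv
    tl.store(y_ptr + offs, x.to(tl.float16), mask=mask)   # the F8E4M3 -> FP16 conversion


n = 1024
BLOCK = 256
x = torch.randn(n, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
y = torch.empty(n, device="cuda", dtype=torch.float16)
upcast_f8e4m3_kernel[(triton.cdiv(n, BLOCK), )](x, y, n, BLOCK=BLOCK)
```

With this PR, a kernel like the one above is intended to compile on pre-sm_89 CUDA GPUs as well, rather than being rejected for using fp8e4nv.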
python/test/regression/test_cast_matmul.py (5 changes: 1 addition & 4 deletions)

@@ -15,10 +15,7 @@

 input_dtypes = ["bfloat16", "float16", "float32"]
 if is_cuda():
-    input_dtypes += ["int8", "float8_e5m2"]
-    cc = torch.cuda.get_device_capability(0)
-    if cc >= (8, 9):
-        input_dtypes += ["float8_e4m3fn"]
+    input_dtypes += ["int8", "float8_e5m2", "float8_e4m3fn"]
 elif is_hip_cdna3():
     input_dtypes += [
         "int8",
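The regression test above now exercises float8_e4m3fn inputs on every CUDA device. As a loose illustration of the reference path a cast-matmul test relies on (a sketch assuming PyTorch exposes the float8 dtypes; the function name is made up for this example), the fp8 operands are upcast on the framework side and multiplied in a wider type:

```python
import torch


def reference_matmul(a_fp8: torch.Tensor, b_fp8: torch.Tensor) -> torch.Tensor:
    # Upcast both fp8 operands to fp16 and matmul the upcast copies; the Triton
    # kernel under test should match this result within the test's tolerance.
    return a_fp8.to(torch.float16) @ b_fp8.to(torch.float16)


a = torch.randn(64, 32, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
b = torch.randn(32, 64, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn)
ref = reference_matmul(a, b)
```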
python/test/unit/language/test_compile_errors.py (4 changes: 1 addition & 3 deletions)

@@ -355,14 +355,12 @@ def kernel():
 @pytest.mark.parametrize("dtype", [tl.float8e5, tl.float8e5b16, tl.float8e4nv, tl.float8e4b8, tl.float8e4b15])
 def test_fp8_support(fresh_triton_cache, dtype):
     warning_dtypes = []
-    supported_dtypes = [tl.float8e5]
+    supported_dtypes = [tl.float8e5, tl.float8e4nv]
     if is_cuda():
         cc = torch.cuda.get_device_capability(0)
         supported_dtypes.append(tl.float8e4b15)
         if cc >= (9, 0):
             warning_dtypes.append(tl.float8e4b15)
-        if cc >= (8, 9):
-            supported_dtypes.append(tl.float8e4nv)
     elif is_hip():
         supported_dtypes += [tl.float8e4nv, tl.float8e4b8, tl.float8e5b16]
         if is_hip_cdna4():
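With fp8e4nv in the baseline supported set, the compile-error test only has to distinguish three outcomes per dtype. A rough sketch of that pattern (the helper name, the compile_fn callback, and the warning class are illustrative, not taken from the test):

```python
import pytest
import triton


def check_fp8_dtype(dtype, supported_dtypes, warning_dtypes, compile_fn):
    if dtype not in supported_dtypes:
        # Unsupported dtypes must be rejected at compile time.
        with pytest.raises(triton.CompilationError):
            compile_fn(dtype)
    elif dtype in warning_dtypes:
        # Deprecated-but-working dtypes should compile with a warning.
        with pytest.warns(UserWarning):
            compile_fn(dtype)
    else:
        # Fully supported dtypes, now including tl.float8e4nv on CUDA, compile cleanly.
        compile_fn(dtype)
```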
python/test/unit/language/test_conversions.py (9 changes: 1 addition & 8 deletions)

@@ -276,8 +276,7 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device):
     # On HIP, fp8e4nv upcasting to fp32 is only supported on CDNA4, and
     # fp8e4nv upcasting to bf16 and fp16 is only supported on CDNA3 and CDNA4.
     if is_cuda():
-        if ((src_dtype == 'float8e4nv' and torch.cuda.get_device_capability(0) < (8, 9))
-                or src_dtype in ('float8e4b8', 'float8e5b16')):
+        if src_dtype in ('float8e4b8', 'float8e5b16'):
             # If the dtype should error out in the given device, we assert that and return
             with pytest.raises(triton.CompilationError, match="not supported in this architecture"):
                 launch_exhaustive_populate(getattr(tl, src_dtype), 0, 65536, False, 8, 0x7f, device=device)
@@ -333,12 +332,6 @@ def test_typeconvert_upcast(src_dtype, dst_dtype, device):
 def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):

     if is_cuda():
-        if src_dtype != 'float32' and torch.cuda.get_device_capability(0) < (9, 0):
-            pytest.skip("non-float32 downcast tests only supported on NVGPU with compute capability 9.0+")
-
-        if dst_dtype in ('float8e5', 'float8e4nv') and rounding == 'rtne' and torch.cuda.get_device_capability(0) < (9, 0):
-            pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on NVGPU with compute capability 9.0+")
-
         if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne':
             pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")

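Since every FP8 format has only 256 bit patterns, the upcast test can be exhaustive. A small host-side sketch of the kind of reference such a check can be compared against (assuming PyTorch's torch.float8_e4m3fn dtype; the function name is illustrative):

```python
import torch


def f8e4m3_upcast_reference() -> torch.Tensor:
    # Enumerate all 256 F8E4M3 bit patterns, reinterpret the raw bytes as fp8,
    # and let PyTorch perform the reference upcast to fp32 (NaN patterns included).
    bits = torch.arange(256, dtype=torch.int32).to(torch.uint8)
    return bits.view(torch.float8_e4m3fn).to(torch.float32)
```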
python/test/unit/language/test_core.py (2 changes: 0 additions & 2 deletions)

@@ -1237,8 +1237,6 @@ def test_abs_fp8(in_dtype, device):
         cc = torch.cuda.get_device_capability()
         if in_dtype == tl.float8e4b15 and cc >= (9, 0):
             pytest.skip("float8e4b15 not supported on CUDA >= 9.0")
-        if in_dtype == tl.float8e4nv and cc < (8, 9):
-            pytest.skip("float8e4nv not supported on CUDA < 8.9")

     @triton.jit
     def abs_kernel(X, Z, SIZE: tl.constexpr):
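For test_abs_fp8 the semantics being checked are simple: the absolute value of an FP8 number is the value with its sign bit cleared, so no hardware FP8 arithmetic is needed to verify it. A hedged host-side illustration (the function name is made up for this sketch):

```python
import torch


def abs_f8e4m3_reference(x_fp8: torch.Tensor) -> torch.Tensor:
    # Reinterpret the fp8 storage as raw bytes, clear the sign bit (0x80),
    # and reinterpret back; this mirrors what |x| means for the fp8 encoding.
    bits = x_fp8.view(torch.uint8)
    return (bits & 0x7F).view(torch.float8_e4m3fn)
```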
third_party/nvidia/backend/compiler.py (4 changes: 1 addition & 3 deletions)

@@ -118,7 +118,7 @@ class CUDAOptions:
     enable_fp_fusion: bool = True
     launch_cooperative_grid: bool = False
     launch_pdl: bool = False
-    supported_fp8_dtypes: Tuple[str] = ("fp8e5", "fp8e4b15")
+    supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e4b15")
     deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
     default_dot_input_precision: str = "tf32"
     allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee")
@@ -181,8 +181,6 @@ def parse_options(self, opts) -> Any:

         if "supported_fp8_dtypes" not in args:
             supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
-            if capability >= 89:
-                supported_fp8_dtypes.add("fp8e4nv")
             args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

         if "deprecated_fp8_dot_operand_dtypes" not in args:
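The net effect of the backend change is that the default supported set no longer depends on the reported compute capability. A simplified illustration (only CUDAOptions.supported_fp8_dtypes and parse_options are real names from the diff; everything else is made up for the example):

```python
import torch

# Class-level default after this change; previously "fp8e4nv" was only added
# inside parse_options when the device reported capability >= 89.
default_fp8 = ("fp8e4nv", "fp8e5", "fp8e4b15")

cc = torch.cuda.get_device_capability(0)      # e.g. (8, 0) on an A100
supported = tuple(sorted(set(default_fp8)))   # mirrors parse_options' normalization
print(f"sm_{cc[0]}{cc[1]} -> supported fp8 dtypes: {supported}")
```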