
Conversation

@yicwang commented Sep 28, 2025

This PR adds support for using fp8 (e4m3) as the KV cache in the FA4 (cute) attention path. At a high level, it relaxes the dtype checks and upcasts the fp8 KV cache to bf16 so the rest of the computation proceeds unchanged. A unit test is also included (output below, followed by a usage sketch):

# pytest -s tests/cute/test_flash_attn.py::test_flash_attn_mixed_precision_q_bf16_kv_fp8
================================================== test session starts ===================================================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /sgl-workspace/flash-attention/tests
configfile: pyproject.toml
plugins: typeguard-4.4.4, anyio-4.10.0
collected 8 items

tests/cute/test_flash_attn.py ........

==================================================== warnings summary ====================================================
<...>
=========================================== 8 passed, 1194 warnings in 19.13s ===========================================
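For reference, a minimal usage sketch of what the PR enables, assuming the cute entry point is importable as flash_attn_func (the import path, tensor shapes, and keyword arguments here are illustrative, not taken from the PR):

    import torch

    # Hypothetical import path; adjust to wherever the cute interface lives in your install.
    from flash_attn.cute.interface import flash_attn_func

    batch, seqlen, nheads, headdim = 2, 1024, 8, 128

    # Q stays in bf16; K/V are stored in fp8 (e4m3), matching the unit test above.
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    k = torch.randn_like(q).to(torch.float8_e4m3fn)
    v = torch.randn_like(q).to(torch.float8_e4m3fn)

    # With this PR, fp8 K/V pass the dtype checks and are upcast to bf16 before the kernel runs.
    out = flash_attn_func(q, k, v, causal=True)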

@tridao (Member) commented Sep 28, 2025

This just converts K & V from fp8 -> bf16 outside the kernel (i.e. at the PyTorch level), so there'll be no speed benefit.

@yicwang (Author) commented Sep 29, 2025

Yes, you are right! My understanding was that this converts the dtype inside the kernel, so it would still be faster than doing the conversion in, say, SGLang. But the unit test does show no performance difference...

I see that our GEMM is written in PTX via inline_asm, and I am not sure whether NVIDIA actually supports native mixed-precision GEMM there. Do you have any pointers? I would be happy to dive deeper and implement this with better performance.

@tridao (Member) commented Sep 29, 2025

Huh, this conversion happens in PyTorch:

    def to_cute(t, align=16, leading_dim=None):
        if t is None: return None
        leading_dim = t.ndim - 1 if leading_dim is None else leading_dim
        t = t.to(torch.bfloat16) if t.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] else t

So this is equivalent to calling flash_attn_func(q, k.to(torch.bfloat16), v.to(torch.bfloat16)).
This PR does not change the kernel and hence will not change the speed (compared to just casting K & V to bf16 in PyTorch and then calling flash_attn_func).
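In other words (a short sketch of the equivalence being described; the import path and tensors are illustrative):

    import torch
    from flash_attn.cute.interface import flash_attn_func  # hypothetical import path

    q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
    k_fp8 = torch.randn_like(q).to(torch.float8_e4m3fn)
    v_fp8 = torch.randn_like(q).to(torch.float8_e4m3fn)

    # PR path: the Python wrapper upcasts fp8 K/V to bf16 before launching the kernel.
    out_pr = flash_attn_func(q, k_fp8, v_fp8)

    # Equivalent: cast K/V to bf16 in user code and call the unmodified function.
    out_eq = flash_attn_func(q, k_fp8.to(torch.bfloat16), v_fp8.to(torch.bfloat16))

    # The kernel receives bf16 K/V either way, so the speed is the same.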
