
Commit 465cb97

Enable FA V3 (#157)
* Compress PA work: narrow pa test; ref works on most cases; inplace ref with new_kv; inplace paged attention; add pa ref; save; pa basic paged works; save; fix swa + causal in pa, also new_kv only on pa path passing; build fa v3; import interface from fa v3; copy fa tests; use v3 api; clean up; rename to match old test; support different head sizes; remove fp8; basic passing v3 cases; test_flash_attn_varlen_output v3 working; isolate bad case for kvcache; case passing; save; use decode if seqused/cache_seqlens is given; use decode if not varlen; basic kvcache v3 working; kvcache enable more cases; detect kvcache case if seqused_q is None and seqused_k is not None; skip failing test; find fp8 failing case; mha fp8 works; fix fp8 MQA/GQA bug; clean up; more clean up; clean up more; don't need fp8 dead code; remove train code with fp8 stuff; fp8 working in kvcache; paged + fp8 seems to be working; new_kv allowed
* clean up
* skip hopper race test
* clean up more
* fix paged + alibi
* similar inner paged api
* unify _attn_fwd_inner
1 parent 0607f30 commit 465cb97
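
The commit message above describes how the new code decides between the kvcache/decode path and the regular varlen path (use decode when seqused/cache_seqlens is given, treat the call as a kvcache case when seqused_q is None but seqused_k is not). Below is a minimal, hypothetical sketch of that dispatch heuristic; the function name and return values are illustrative and do not come from the repository.

```python
from typing import Optional

import torch


def pick_attention_path(
    seqused_q: Optional[torch.Tensor],
    seqused_k: Optional[torch.Tensor],
    is_varlen: bool,
) -> str:
    """Illustrative only: mirrors the dispatch rules described in the commit message."""
    # kvcache case: only the key/value sequence lengths are supplied
    if seqused_q is None and seqused_k is not None:
        return "kvcache_decode"
    # non-varlen calls fall back to the decode path
    if not is_varlen:
        return "decode"
    # everything else goes through the regular varlen prefill path
    return "varlen_prefill"
```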

21 files changed: +3021 −2658 lines

.github/workflows/amd_tests.yml

Lines changed: 0 additions & 65 deletions
This file was deleted.

README.md

Lines changed: 7 additions & 7 deletions
@@ -160,7 +160,6 @@ Then install Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` s
 
 ```
 cd flash-attention
-git checkout main_perf
 FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
 ```
 
@@ -184,16 +183,17 @@ WORKDIR /workspace
 # install triton
 RUN pip install triton==3.3.0
 
-# install flash attention
-ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
-
-RUN git clone https://github.com/ROCm/flash-attention.git &&\
+# build flash attention with triton backend
+RUN git clone https://github.com/Dao-AILab/flash-attention &&\
 cd flash-attention &&\
-git checkout main_perf &&\
-python setup.py install
+FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" python setup.py install
 
 # set working dir
 WORKDIR /workspace/flash-attention
+
+# set env variable to use triton backend
+ENV FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
+
 ```
 
 To build the docker file
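
As a quick sanity check after either install route above, a tiny forward call exercises the Triton backend. This is a sketch, assuming the upstream flash_attn_func API (fp16/bf16 tensors shaped (batch, seqlen, nheads, headdim) on a GPU); setting the environment variable in-process before the import is an assumption, since the README only sets it at build time.

```python
# Smoke test for the Triton-backend build (sketch; see assumptions above).
import os

os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"  # assumption: also selects the backend at runtime

import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) inputs, fp16 on the GPU
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # expected: torch.Size([1, 128, 8, 64])
```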

flash_attn/flash_attn_triton_amd/.gitignore

Lines changed: 0 additions & 2 deletions
This file was deleted.

flash_attn/flash_attn_triton_amd/Dockerfile

Lines changed: 0 additions & 17 deletions
This file was deleted.

flash_attn/flash_attn_triton_amd/README.md

Lines changed: 0 additions & 113 deletions
This file was deleted.

flash_attn/flash_attn_triton_amd/bwd_prefill_fused_atomics.py

File mode changed: 100644 → 100755
Lines changed: 7 additions & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 import triton
 import triton.language as tl
-from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors
+from flash_attn.flash_attn_triton_amd.utils import compute_fp8_scaling_factors, DEBUG, is_fp8
 
 from typing import Optional, Tuple
 
@@ -1503,11 +1503,17 @@ def attention_prefill_backward_triton_fused_atomics_impl(
     descale_v: Optional[torch.Tensor] = None,
     descale_do: Optional[torch.Tensor] = None,
     fused: bool = False,
+    # seqused for FA v3 (currently ignored in this implementation)
+    seqused_q: Optional[torch.Tensor] = None,
+    seqused_k: Optional[torch.Tensor] = None,
 ):
     IS_FP8 = is_fp8(q)
     if IS_FP8:
         FP8_MAX = torch.finfo(q.dtype).max
         descale_strides = (descale_q.stride(0),descale_k.stride(0),descale_v.stride(0),descale_do.stride(0) )
+
+        if DEBUG:
+            print(f"FP8 path triggered in bwd_prefill_fused_atomics.py")
     else:
         FP8_MAX = None
         stride_descale_q_z = stride_descale_k_z = stride_descale_v_z = stride_descale_do_z = None
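
For context on the FP8 branch above: compute_fp8_scaling_factors is imported from the package's utils and FP8_MAX comes from torch.finfo(q.dtype).max. The sketch below shows one common way to derive a per-tensor scale/descale pair; it is illustrative only, not the repository's implementation, and it assumes a PyTorch build that provides torch.float8_e4m3fn.

```python
import torch


def fp8_scaling_factors(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    """Illustrative per-tensor FP8 scaling (not the repo's compute_fp8_scaling_factors)."""
    fp8_max = torch.finfo(fp8_dtype).max          # 448.0 for float8_e4m3fn
    amax = x.abs().amax().clamp(min=1e-12)        # guard against all-zero tensors
    scale = fp8_max / amax                        # multiply by this before casting to fp8
    descale = amax / fp8_max                      # multiply by this after the fp8 matmul
    return scale, descale


x = torch.randn(4, 128, 64)
scale, descale = fp8_scaling_factors(x)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)       # quantize
x_restored = x_fp8.to(torch.float32) * descale    # approximate reconstruction
```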
