Add torch.compile support to flash attention 3 #1769
base: main
Conversation
hopper/flash_api.cpp
Outdated
     }
-    return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
+    return { dq.clone(), dk.clone(), dv.clone(), softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
What's going on here? Adding additional clones is generally bad for performance
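For context, an assumption about the motivation rather than something stated in this thread: custom ops registered via torch.library.custom_op are not allowed to return tensors that alias their inputs, so wrappers around kernels that write into caller-provided buffers sometimes return clones. A minimal illustration with a hypothetical op:

```python
import torch
from torch.library import custom_op

# Hypothetical op, only to illustrate the no-aliasing rule for outputs.
@custom_op("demo::passthrough", mutates_args=())
def passthrough(x: torch.Tensor) -> torch.Tensor:
    # Returning `x` itself would make the output alias the input, which the
    # custom-op machinery rejects; cloning returns a fresh, non-aliasing tensor.
    return x.clone()
```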
hopper/build.sh
Outdated
# Flash Attention Minimal Build Script for PHI-1 Reproducer
# Uses subshell to automatically clean up environment variables

# Run in subshell - variables are automatically cleaned up when it exits
do you have more context for this file? I don't see a mention of PHI-1 elsewhere
I'll remove this file once the PR gets approved. I'm only keeping it because it makes it easier to build and test flash attention.
hopper/test_flash_attn.py
Outdated
if not DISABLE_FAKE_CHECK:
    flash_attn_func = run_fake_check(flash_attn_func)
    flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func)
A bit jank but I'm fine with it
@tridao, are you ok with running fake_check as part of the tests? I can invert the flag to be opt-in instead of opt-out.
would that slow down the tests much? Rn it takes 30-50mins to run all the tests
I suspect it could double the time to run the tests, since fake_check runs the function twice to compare the fake version with the actual implementation. I can change the flag to be opt-in instead, so fake_check only runs when explicitly enabled.
        fake_check(fn, args, kwargs)
        return fn(*args, **kwargs)
    return wrapper
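For reference, the decorator that these lines close could look like the sketch below; only its tail is visible in the diff, and `fake_check` is assumed to be a helper imported or defined elsewhere in the test file that cross-checks the op's fake (meta) implementation against the real kernel.

```python
import functools

def run_fake_check(fn):
    # Minimal sketch reconstructed from the visible tail of the wrapper above.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # Cross-check the fake/meta implementation against the real kernel...
        fake_check(fn, args, kwargs)
        # ...then run the real kernel so the test still validates real outputs.
        return fn(*args, **kwargs)
    return wrapper
```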
Do we have a sense of how comprehensive the existing tests in this file are? Are they good at exercising a variety of inputs?
I guess @tridao would be a better person to answer this one.
we do test a lot of input shapes and different attn options (~100k tests iirc)
Looks good to me, modulo the prints
@guilhermeleobas Any update? I tried your commits and found there are still some shape errors in the backward function; here is my test code.
batch_size = 16
sequence_length = 256
test = EfficienctMultiHeadAttention(embedding_dim, num_heads=16).cuda().bfloat16()
out = test(input_tensor)
Hi @lantudou, thanks for the reproducer. Could you try it again?
@tridao, could you take a look one more time once you have some cycles to spare?
I think the torch custom ops should be registered to a dedicated namespace.
Is flash attention 3 an independent package from FA2? In the sense that in the future FA2 will be deprecated in favor of FA3?
idk what the future plan is but at least right now transformers thinks it can have both at the same time
Got it. I changed the namespace to
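For illustration, a minimal sketch of what registering the forward under a dedicated namespace with torch.library.custom_op can look like; the namespace flash_attn_3, the signature, and the SDPA stand-in body are assumptions, not necessarily the PR's actual code.

```python
import torch
from torch.library import custom_op, register_fake

@custom_op("flash_attn_3::fwd", mutates_args=())
def fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Stand-in body; the PR would dispatch into the compiled FA3 extension here.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

@register_fake("flash_attn_3::fwd")
def _(q, k, v):
    # The fake (meta) kernel only describes output metadata, which is what
    # torch.compile needs to trace through the op without running it.
    return torch.empty_like(q)
```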
Thank you for the great work @guilhermeleobas. Just wondering, have you tried to run training on the exported FX graph? It seems like this implementation is still missing register_autograd, which is needed to run the training loop.
Thanks for the feedback, @Tomcli. As for the second question: it probably won't work. I based my implementation on what is implemented for FA2, which doesn't use register_autograd.
Edit: added in this one.
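As a hedged sketch of the register_autograd plumbing being discussed (the op name continues the hypothetical flash_attn_3::fwd above, and the SDPA-based backward is a stand-in for the real FA3 backward kernel):

```python
import torch

def _setup_context(ctx, inputs, output):
    q, k, v = inputs
    ctx.save_for_backward(q, k, v)

def _backward(ctx, grad_out):
    q, k, v = ctx.saved_tensors
    # Sketch only: recompute with a differentiable reference and let autograd
    # produce the gradients; a real implementation calls the FA3 bwd kernel.
    with torch.enable_grad():
        q_, k_, v_ = (t.detach().requires_grad_(True) for t in (q, k, v))
        out = torch.nn.functional.scaled_dot_product_attention(q_, k_, v_)
        dq, dk, dv = torch.autograd.grad(out, (q_, k_, v_), grad_out)
    return dq, dk, dv

torch.library.register_autograd(
    "flash_attn_3::fwd", _backward, setup_context=_setup_context
)
```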
Hi @Tomcli, could you try the last commit, please?
…se around `torch.library.custom_op` usage
thanks for fixing it @guilhermeleobas, the compile now works, but the compiled artifacts still break for me when using AOTInductor packaging.
Thanks for trying this PR @OutofAi. Do you have a reproducer for this error?
I just want to share that for my workflow, based only on
Hi @OutofAi, I believe I fixed the torch.export bug you're seeing. Could you try again, please?
I tried torch.compile with flash_attn_with_kvcache, and received:
I wonder if flash_attn_with_kvcache is supported in this commit?
hopper/flash_attn_interface.py
Outdated
     window_size_right: int = -1,
     softcap: float = 0.0,
-    deterministic: bool = False,
+    deterministic: bool= False,
This looks malformatted
Thanks!
@guilhermeleobas just saw your messages, it seems to be working now with your new changes. I gave it a try with my sort-of unit test and it works now, thanks. For reference, this is the code:
Hi @yijianggit,
No, it is not. Implementing the flash_attn call with num_splits < 0 would require implementing the split-selection heuristic on the Python side. The heuristic for determining the number of splits depends on many variables, so replicating the exact C++ logic would be quite involved. While it's technically feasible, a more practical approach might be to call the C++ code directly.
Thanks for the repro @OutofAi
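Purely as an illustrative sketch of the limitation (an assumption, not code from this PR), a Python-side wrapper could refuse the automatic path instead of replicating the C++ heuristic:

```python
def resolve_num_splits(num_splits: int) -> int:
    # Hypothetical helper: the automatic choice (num_splits < 0) is made by a
    # heuristic in the C++ code, so this sketch simply asks callers to pass an
    # explicit value rather than replicating that logic here.
    if num_splits < 0:
        raise NotImplementedError(
            "automatic num_splits selection lives in the C++ heuristic; "
            "pass an explicit non-negative num_splits"
        )
    return num_splits
```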
Hi @guilhermeleobas, here is the code snippet to reproduce it:
Please let me know if there's anything I can do to help support this. Thanks!
Hi @tridao, could you take a look at this PR again? It seems to be working for most use cases.
@guilhermeleobas This change does not work properly with
My code snippet:
The output error:
Now that #1791 has landed, can you also update the stable file to keep them in sync?
done!
@haohaibo, thanks for the repro, but your code is missing some things. What is
Unfortunately the flash_api_stable.cpp code does not compile for me on any nightly version. I wrote a fork here that is both ABI compatible and includes the changes from this PR. You can view the diff here.
EDIT: I've also packaged/built my fork and uploaded it on Huggingface for each Torch version here: https://huggingface.co/varunneal/flash-attention-3, just in case it is useful for anyone.
@varunneal when you say flash_api_stable.cpp does not compile for you, do you mean on main or on this branch? If on main, that would be concerning and I'd like to find out more about your setup so we could fix it.
@janeyx99 Just on this branch
Having this would be quite helpful; what's blocking it from being merged, @guilhermeleobas?
On it. I'm on PTO this week and can work on this next week.
Thanks for testing it. I was working on a different branch without the latest changes, so the failures slipped past me.
@varunneal could you give this PR another try? Building
hopper/flash_attn_interface.py
Outdated
is_local = (window_size_left >= 0 or window_size_right >= 0) and not is_causal

if arch < 90:
hi. I do not think flash attention 3 relies on Hopper. It works fine on older GPUs like the A100. Can you remove this?
Done!
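For context, `arch` in checks like the one above is typically derived from the CUDA compute capability; a small sketch (an assumption about how this file computes it, not confirmed by the diff):

```python
import torch

# Compute capability (major, minor), e.g. (8, 0) on A100, (9, 0) on H100.
major, minor = torch.cuda.get_device_capability()
arch = major * 10 + minor  # 80 for sm_80 (A100), 90 for sm_90 (H100)
```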
Enable torch.compile support for FlashAttention and improve testing
flash_attn_config.py