Add torch.compile support to flash attention 3 #1769
base: main
Conversation
hopper/flash_api.cpp (outdated)
@@ -1511,7 +1511,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tenso
         softmax_d.zero_();
     }

-    return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
+    return { dq.clone(), dk.clone(), dv.clone(), softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
What's going on here? Adding additional clones is generally bad for performance
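A possible reason for the clones (not confirmed in this thread, so treat it as an assumption): torch.library.custom_op requires that an op's outputs do not alias its inputs, so if dq/dk/dv can be caller-provided buffers, the wrapped op has to return fresh tensors. A minimal sketch of that constraint, using a toy op:

import torch
from torch.library import custom_op

# Outputs of a torch.library custom op must not alias its inputs.
@custom_op("mylib::passthrough", mutates_args=())
def passthrough(x: torch.Tensor) -> torch.Tensor:
    # return x        # invalid: the output would alias the input
    return x.clone()  # valid: the output is a freshly allocated tensor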
hopper/build.sh (outdated)
# Flash Attention Minimal Build Script for PHI-1 Reproducer
# Uses subshell to automatically clean up environment variables

# Run in subshell - variables are automatically cleaned up when it exits
do you have more context for this file? I don't see a mention of PHI-1 elsewhere
I'll remove this file once the PR gets approved. I'm only keeping it because it makes it easier to build and test flash attention.
hopper/test_flash_attn.py (outdated)
if not DISABLE_FAKE_CHECK:
    flash_attn_func = run_fake_check(flash_attn_func)
    flash_attn_varlen_func = run_fake_check(flash_attn_varlen_func)
A bit jank but I'm fine with it
@tridao, are you ok with running fake_check as part of the tests? I can invert the flag to be opt-in instead of opt-out.
would that slow down the tests much? Rn it takes 30-50mins to run all the tests
I suspect it could double the time to run the tests, as fake_check runs the function twice to compare the fake version with the actual implementation. I can change the flag to be opt-in instead, so fake_check only runs when it is explicitly enabled.
        fake_check(fn, args, kwargs)
        return fn(*args, **kwargs)
    return wrapper
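For reference, a minimal sketch of how an opt-in flag could gate the check; the flag name and the fake_check body below are illustrative, not the PR's actual code. The fake pass re-runs the wrapped function under FakeTensorMode and compares output metadata with the real run, which is why enabling it roughly doubles the test time:

import functools
import os

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils._pytree import tree_flatten

# Illustrative opt-in switch; the real flag/environment variable may differ.
RUN_FAKE_CHECK = os.environ.get("FLASH_ATTN_RUN_FAKE_CHECK", "0") == "1"

def fake_check(fn, args, kwargs):
    # Run the function for real, then again under FakeTensorMode, and
    # compare output metadata (shape/dtype) to validate the fake impl.
    real_out = fn(*args, **kwargs)
    with FakeTensorMode() as mode:
        fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a
                     for a in args]
        fake_kwargs = {k: mode.from_tensor(v) if isinstance(v, torch.Tensor) else v
                       for k, v in kwargs.items()}
        fake_out = fn(*fake_args, **fake_kwargs)
    for real, fake in zip(tree_flatten(real_out)[0], tree_flatten(fake_out)[0]):
        if isinstance(real, torch.Tensor):
            assert real.shape == fake.shape and real.dtype == fake.dtype

def run_fake_check(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if RUN_FAKE_CHECK:
            fake_check(fn, args, kwargs)
        return fn(*args, **kwargs)
    return wrapper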
Do we have a sense of how comprehensive the existing tests in this file are? Are they good at exercising a variety of inputs?
I guess @tridao would be a better person to answer this one.
we do test a lot of input shapes and different attn options (~100k tests iirc)
Looks good to me, modulo the prints
@guilhermeleobas Any updates? I tried your commits and found there are still some shape errors in the backward function. Here is my test code:

batch_size = 16
sequence_length = 256
test = EfficienctMultiHeadAttention(embedding_dim, num_heads=16).cuda().bfloat16()
out = test(input_tensor)
Hi @lantudou, thanks for the reproducer. Could you try it again?
@tridao, could you take a look one more time once you have some cycles to spare?
I think the torch custom ops should be registered to a separate namespace.
Is flash attention 3 an independent package from FA2, in the sense that in the future FA2 will be deprecated in favor of FA3?
idk what the future plan is, but at least right now transformers thinks it can have both at the same time
Got it. I changed the namespace to a separate one.
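For illustration, here is what registering an op under its own namespace looks like; the namespace and op name below are placeholders, not necessarily the ones used in this PR:

import torch
from torch.library import custom_op

# "flash_attn_3" is a placeholder namespace for this example.
@custom_op("flash_attn_3::scale", mutates_args=())
def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

@scale.register_fake
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    # Fake (meta) implementation: only output metadata, no real compute.
    return torch.empty_like(x)

# The op is then reachable as torch.ops.flash_attn_3.scale, kept apart
# from whatever FA2 registers under its own namespace.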
Thank you for the great work @guilhermeleobas. Just wondering, have you tried to run training on the exported FX graph? It seems like this implementation is still missing register_autograd, which is needed to run the training loop.
Thanks for the feedback, @Tomcli. As for the second question: it probably won't work. I based my implementation on what is implemented for FA2, which doesn't use register_autograd. Edit: added in this one.
Hi @Tomcli, could you try the last commit, please?
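To make the register_autograd point above concrete, a minimal sketch with a toy op (not the FA3 kernels) showing how a backward is attached to a torch.library custom op so the compiled/exported graph can be trained through:

import torch
from torch.library import custom_op

@custom_op("flash_attn_3_demo::mul_by", mutates_args=())
def mul_by(x: torch.Tensor, s: float) -> torch.Tensor:
    return x * s

@mul_by.register_fake
def _(x: torch.Tensor, s: float) -> torch.Tensor:
    return torch.empty_like(x)

def _setup_context(ctx, inputs, output):
    _, s = inputs
    ctx.s = s

def _backward(ctx, grad):
    # Gradient w.r.t. x; the non-tensor float argument gets None.
    return grad * ctx.s, None

mul_by.register_autograd(_backward, setup_context=_setup_context)

x = torch.randn(4, requires_grad=True)
torch.ops.flash_attn_3_demo.mul_by(x, 2.0).sum().backward()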
hopper/flash_attn_interface.py (outdated)
is_varlen_q = bool(cu_seqlens_q)
is_varlen_k = bool(cu_seqlens_k)
is_varlen = is_varlen_q or is_varlen_k or bool(sequed_q) or bool(sequed_k)
Suggested change:
-is_varlen_q = bool(cu_seqlens_q)
-is_varlen_k = bool(cu_seqlens_k)
-is_varlen = is_varlen_q or is_varlen_k or bool(sequed_q) or bool(sequed_k)
+is_varlen_q = cu_seqlens_q is not None
+is_varlen_k = cu_seqlens_k is not None
+is_varlen = is_varlen_q or is_varlen_k or sequed_q is not None or sequed_k is not None
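A quick note on why the `is not None` form matters: `bool()` on a cu_seqlens tensor (which holds batch_size + 1 offsets) raises at runtime, and tensor truthiness is a data-dependent operation that FakeTensor/torch.compile tracing cannot handle, whereas the None check never touches tensor data. The values below are illustrative:

import torch

# Illustrative cu_seqlens: batch_size + 1 cumulative sequence offsets.
cu_seqlens_q = torch.tensor([0, 128, 256], dtype=torch.int32)

try:
    bool(cu_seqlens_q)
except RuntimeError as e:
    print(e)  # Boolean value of Tensor with more than one element is ambiguous

# The None check only inspects the Python object, never the tensor data.
is_varlen_q = cu_seqlens_q is not None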
Force-pushed (…se around `torch.library.custom_op` usage) from 871a3cf to e52508f
Enable torch.compile support for FlashAttention and improve testing
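For context, the usage this PR is meant to enable looks roughly like the sketch below; the import path and argument choices are illustrative and follow the hopper interface:

import torch
from flash_attn_interface import flash_attn_func  # hopper/flash_attn_interface.py

@torch.compile(fullgraph=True)
def attn(q, k, v):
    # With the kernels exposed as torch.library custom ops (plus fake and
    # autograd registrations), torch.compile can trace through this call
    # without graph breaks.
    return flash_attn_func(q, k, v, causal=True)

q = torch.randn(2, 256, 16, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = attn(q, k, v)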