feat: cutlass fp8 gemm bringup for SM120 & SM121 #1610
Conversation
self,
inputs: List[torch.Tensor],
tactic: int = -1,
do_preparation: bool = False,
seems like this parameter is unused?
Good catch. They’re part of TunableRunner, keeping them for consistency with the others.
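For context, a minimal sketch of the kind of runner these parameters belong to. The class name and body below are illustrative assumptions rather than code from this PR; only the forward signature mirrors the diff above.

```python
from typing import List

import torch


class Fp8GemmRunner:  # hypothetical stand-in for a TunableRunner subclass
    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic: int = -1,
        do_preparation: bool = False,  # unused here, kept so every runner shares one signature
    ) -> torch.Tensor:
        a, b = inputs  # assumes two operand tensors for this sketch
        # tactic would normally select a kernel configuration; -1 means the default.
        return a @ b
```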
Looks good, no further comments from me.
csrc/gemm_groupwise_sm120.cu
Outdated
constexpr int SCALE_GRANULARITY_M = 1;   /* Always 1 for SM120 */ \
constexpr int SCALE_GRANULARITY_K = 128; /* Always 128 for SM120 per CUTLASS requirement */ \
if (scale_granularity_m != 1) { \
  TORCH_CHECK(false, "SM120 only supports scale_granularity_m=1"); \
Didn't see this constraint in https://github.com/NVIDIA/cutlass/blob/b2dd65dc864e09688245b316ac46c4a6cd07e15c/include/cutlass/gemm/collective/sm120_mma_tma_blockwise_scaling.hpp. What's the error message if you set it to 128?
It will run into a static assertion failure: "Scale Granularity M must evenly divide the tile shape M."
This error should only happen when ScaleGranularityM == 128 and TileShapeM == 64, which is the case when MmaSM == 1 here: https://github.com/flashinfer-ai/flashinfer/pull/1610/files#diff-0977093a8d2429e66dab4cc40f31563717098cb5aca4354a814e4208f58f068bR78, which you disabled in https://github.com/flashinfer-ai/flashinfer/pull/1610/files#diff-68929275a79ec730031c1d5bec894f35ba6e932a08841fdae63087a6937c0f4fR70.
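A minimal sketch of the arithmetic behind that assertion, using the values discussed in this thread (the constants are illustrative, not the actual CUTLASS template parameters):

```python
# "Scale Granularity M must evenly divide the tile shape M."
tile_shape_m_one_sm = 64   # TileShapeM in the MmaSM == 1 case, per the thread
tile_shape_m_other = 128   # a 128-wide tile, for comparison
scale_granularity_m = 128

print(tile_shape_m_one_sm % scale_granularity_m)  # 64 -> does not divide evenly, assertion fires
print(tile_shape_m_other % scale_granularity_m)   # 0  -> the rule is satisfied
```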
Right. I used a standalone test (not in this PR) to trigger that error message. Will go with the suggestion in #1610 (comment).
There might be some misunderstanding of MmaSM here: we use it in the sm100 gemm because sm100 supports tcgen05 and 2-cta mode (where 2 ctas cooperatively perform an mma computation). However, sm120 does not have tcgen05 or 2-cta mma, so MmaSM doesn't make sense here; it should always be 1 on sm120. The condition used to separate the two cases in https://github.com/flashinfer-ai/flashinfer/pull/1610/files#diff-0977093a8d2429e66dab4cc40f31563717098cb5aca4354a814e4208f58f068bR78 should not be based on MmaSM. Please consider the following changes:
Thanks, @yzh119, @nvmbreughe, @aleozlx for the helpful and insightful comments! I’ve incorporated them. Please take a look. For the PingPong gemm, I left it as a todo for now; the current default in the cutlass examples is cooperative gemm.
csrc/gemm_groupwise_sm120.cu
Outdated
constexpr int SCALE_GRANULARITY_K = 128; /* equal tile K dimension */ \
if (scale_granularity_m != 1) { \
  TORCH_CHECK(false, \
              "SM120 only supports scale_granularity_m=1 to ensure compatibility with all " \
Is this still the case after your changes? If not, let's add 128 back.
Good catch! A divisor of 128 should be valid. I added scale_granularity_m=1 and scale_granularity_m=128; the change is here: https://github.com/flashinfer-ai/flashinfer/pull/1610/files#diff-68929275a79ec730031c1d5bec894f35ba6e932a08841fdae63087a6937c0f4fR44-R52
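A hedged sketch of what the relaxed validation amounts to; it mirrors the intent of the TORCH_CHECK in csrc/gemm_groupwise_sm120.cu, but the function name and error text are illustrative rather than the exact code in the linked diff:

```python
def check_sm120_scale_granularity(scale_granularity_m: int, scale_granularity_k: int) -> None:
    # The two values enabled in this PR, 1 and 128, are accepted for the M granularity.
    if scale_granularity_m not in (1, 128):
        raise ValueError(
            f"SM120 supports scale_granularity_m of 1 or 128, got {scale_granularity_m}"
        )
    # The K granularity is fixed at 128 to match the tile K dimension.
    if scale_granularity_k != 128:
        raise ValueError(
            f"SM120 requires scale_granularity_k=128, got {scale_granularity_k}"
        )
```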
📌 Description
This PR depends on #1608, mainly for the cutlass fp8 gemm support for sm120/121; it will be rebased after #1608 lands.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
I have installed pre-commit by running pip install pre-commit (or used your preferred method).
I have installed the hooks with pre-commit install.
I have run pre-commit run --all-files and fixed any reported issues.

🧪 Tests

All tests are passing (unittest, etc.).

Reviewer Notes