Skip to content

Update CUDA sdpa #2468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 7, 2025
Merged

Update CUDA sdpa #2468

merged 9 commits into from
Aug 7, 2025

Conversation

jagrit06
Copy link
Member

@jagrit06 jagrit06 commented Aug 6, 2025

Proposed changes

  • Add one and 2 pass vector sdpa impelmentations in Cuda
  • Add cudnn for matrix attention in supported types / hardware

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@jagrit06 jagrit06 requested a review from awni August 6, 2025 22:08
Comment on lines 811 to 831
inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
switch (dtype) {
case int8:
return fe::DataType_t::INT8;
case int32:
return fe::DataType_t::INT32;
case uint8:
return fe::DataType_t::UINT8;
case float16:
return fe::DataType_t::HALF;
case bfloat16:
return fe::DataType_t::BFLOAT16;
case float32:
return fe::DataType_t::FLOAT;
case float64:
return fe::DataType_t::DOUBLE;
default:
throw std::runtime_error(fmt::format(
"Unsupported dtype in SDPA: {}.", dtype_to_string(dtype)));
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can refactor that into a shared header cuddn.h that gets reused by conv.cpp?

Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks awesome! Let's merge it after you fix the compile issue and the tests clear

@angeloskath angeloskath merged commit a9bdd67 into main Aug 7, 2025
6 checks passed
@angeloskath angeloskath deleted the sdpav-base branch August 7, 2025 04:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants