Skip to content

Conversation

zcbenz
Copy link
Collaborator

@zcbenz zcbenz commented Jul 10, 2025

This PR ports the Scan kernels from Metal to CUDA.

There are also a few changes to fix the LogAddExp op for complex numbers which caused the scan tests to fail:

  • A few C++ tests are added to show the failed case.
  • A cexpf.cuh file is added to port the complex number exp from thrust, we can not use thrust::exp directly unfortunately because it does not compile in JIT kernels when building with old CUDA versions (<= 12.2).
  • The complex number log1p is moved from utils.cuh to unary_ops.cuh so it is put together with other complex number functions.
  • The LogAddExp's implementation is updated to use PyTorch's algorithm.

@awni
Copy link
Member

awni commented Jul 10, 2025

Looks awesome! I'm wondering if those logaddexp tests will fail on Metal / CPU as well?

@zcbenz
Copy link
Collaborator Author

zcbenz commented Jul 10, 2025

The exp test fails on Metal, should I port the cexpf?

@awni
Copy link
Member

awni commented Jul 10, 2025

The exp test fails on Metal, should I port the cexpf?

Yes that would be super useful.. it looks like a bug as is.

@zcbenz
Copy link
Collaborator Author

zcbenz commented Jul 10, 2025

Updated to fix the exp on Metal.

@awni awni merged commit 8347575 into ml-explore:main Jul 10, 2025
5 checks passed
@zcbenz zcbenz deleted the cuda-scan branch July 13, 2025 00:48
Jckwind pushed a commit to TheProxyCompany/mlx that referenced this pull request Aug 28, 2025
* Contiguous scan

* Strided scan

* Enable tests

* Fix failing logaddexp test

* Use cexpf in Metal
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants