Skip to content

Commit f78bc28

Browse files
zcbenzJckwind
authored andcommitted
[CUDA] Implement Scan kernel (ml-explore#2347)
* Contiguous scan * Strided scan * Enable tests * Fix failing logaddexp test * Use cexpf in Metal
1 parent 14a3c71 commit f78bc28

File tree

13 files changed

+815
-64
lines changed

13 files changed

+815
-64
lines changed

mlx/backend/cuda/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ target_sources(
3535
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
3636
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
3737
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
38+
${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
3839
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
3940
${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
4041
${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
@@ -67,6 +68,11 @@ target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
6768
target_compile_options(mlx
6869
PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")
6970

71+
# Enable calling host constexpr functions from device. This is needed because
72+
# the constexpr version of isnan is host only.
73+
target_compile_options(
74+
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")
75+
7076
# CUDA 12.8 emits warning #20280-D for copy kernels which is a false positive.
7177
# Explicitly pass this flag to suppress the warning, it is safe to set it to
7278
# true but the warning wouldn't be suppressed.

mlx/backend/cuda/device/binary_ops.cuh

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
// Copyright © 2025 Apple Inc.
22

3-
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
4-
#include "mlx/backend/cuda/device/fp16_math.cuh"
5-
#include "mlx/backend/cuda/device/utils.cuh"
3+
#include "mlx/backend/cuda/device/unary_ops.cuh"
64

7-
#include <cuComplex.h>
85
#include <cuda/std/array>
96

107
namespace mlx::core::cu {
@@ -114,36 +111,38 @@ struct LessEqual {
114111
struct LogAddExp {
115112
template <typename T>
116113
__device__ T operator()(T x, T y) {
117-
if (isnan(x) || isnan(y)) {
118-
return cuda::std::numeric_limits<T>::quiet_NaN();
114+
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
115+
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
116+
isnan(cuCimagf(y))) {
117+
return {
118+
cuda::std::numeric_limits<float>::quiet_NaN(),
119+
cuda::std::numeric_limits<float>::quiet_NaN()};
120+
}
121+
auto max = cuCrealf(x) > cuCrealf(y) ? x : y;
122+
auto min = cuCrealf(x) < cuCrealf(y) ? x : y;
123+
auto min_real = cuCrealf(min);
124+
auto max_real = cuCrealf(max);
125+
if (!isfinite(min_real) && (min_real == max_real)) {
126+
if (min_real < 0) {
127+
return min;
128+
} else {
129+
return Log{}(Exp{}(min) + Exp{}(max));
130+
}
131+
} else {
132+
return Log1p{}(Exp{}(min - max)) + max;
133+
}
134+
} else {
135+
if (isnan(x) || isnan(y)) {
136+
return cuda::std::numeric_limits<T>::quiet_NaN();
137+
}
138+
T maxval = max(x, y);
139+
T minval = min(x, y);
140+
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
141+
maxval == cuda::std::numeric_limits<T>::infinity())
142+
? maxval
143+
: T(float(maxval) + log1p(expf(minval - maxval)));
119144
}
120-
T maxval = max(x, y);
121-
T minval = min(x, y);
122-
return (minval == -cuda::std::numeric_limits<T>::infinity() ||
123-
maxval == cuda::std::numeric_limits<T>::infinity())
124-
? maxval
125-
: T(float(maxval) + log1p(expf(minval - maxval)));
126145
};
127-
128-
__device__ cuComplex operator()(cuComplex x, cuComplex y) {
129-
if (isnan(cuCrealf(x)) || isnan(cuCimagf(x)) || isnan(cuCrealf(y)) ||
130-
isnan(cuCimagf(y))) {
131-
return {
132-
cuda::std::numeric_limits<float>::quiet_NaN(),
133-
cuda::std::numeric_limits<float>::quiet_NaN()};
134-
}
135-
float inf = cuda::std::numeric_limits<float>::infinity();
136-
auto maxval = x > y ? x : y;
137-
auto minval = x < y ? x : y;
138-
if (cuCrealf(minval) == -inf || cuCrealf(maxval) == inf)
139-
return maxval;
140-
float m = exp(cuCrealf(minval) - cuCrealf(maxval));
141-
cuComplex dexp{
142-
m * cos(cuCimagf(minval) - cuCimagf(maxval)),
143-
m * sin(cuCimagf(minval) - cuCimagf(maxval)),
144-
};
145-
return maxval + log1p(dexp);
146-
}
147146
};
148147

149148
struct Maximum {

mlx/backend/cuda/device/cexpf.cuh

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// Copyright © 2025 Apple Inc.
2+
// Copyright © 2008-2013 NVIDIA Corporation
3+
// Copyright © 2013 Filipe RNC Maia
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
//
17+
// Forked from
18+
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
19+
20+
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
21+
// can not be used in JIT.
22+
23+
#pragma once
24+
25+
#include <cuComplex.h>
26+
#include <cuda/std/cstdint>
27+
28+
namespace mlx::core::cu::detail {
29+
30+
using ieee_float_shape_type = union {
31+
float value;
32+
uint32_t word;
33+
};
34+
35+
inline __device__ void get_float_word(uint32_t& i, float d) {
36+
ieee_float_shape_type gf_u;
37+
gf_u.value = (d);
38+
(i) = gf_u.word;
39+
}
40+
41+
inline __device__ void get_float_word(int32_t& i, float d) {
42+
ieee_float_shape_type gf_u;
43+
gf_u.value = (d);
44+
(i) = gf_u.word;
45+
}
46+
47+
inline __device__ void set_float_word(float& d, uint32_t i) {
48+
ieee_float_shape_type sf_u;
49+
sf_u.word = (i);
50+
(d) = sf_u.value;
51+
}
52+
53+
inline __device__ float frexp_expf(float x, int* expt) {
54+
const uint32_t k = 235;
55+
const float kln2 = 162.88958740F;
56+
57+
float exp_x;
58+
uint32_t hx;
59+
60+
exp_x = expf(x - kln2);
61+
get_float_word(hx, exp_x);
62+
*expt = (hx >> 23) - (0x7f + 127) + k;
63+
set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
64+
return exp_x;
65+
}
66+
67+
inline __device__ cuComplex ldexp_cexpf(cuComplex z, int expt) {
68+
float x, y, exp_x, scale1, scale2;
69+
int ex_expt, half_expt;
70+
71+
x = cuCrealf(z);
72+
y = cuCimagf(z);
73+
exp_x = frexp_expf(x, &ex_expt);
74+
expt += ex_expt;
75+
76+
half_expt = expt / 2;
77+
set_float_word(scale1, (0x7f + half_expt) << 23);
78+
half_expt = expt - half_expt;
79+
set_float_word(scale2, (0x7f + half_expt) << 23);
80+
81+
return cuComplex{
82+
cosf(y) * exp_x * scale1 * scale2, sinf(y) * exp_x * scale1 * scale2};
83+
}
84+
85+
inline __device__ cuComplex cexpf(const cuComplex& z) {
86+
float x, y, exp_x;
87+
uint32_t hx, hy;
88+
89+
const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;
90+
91+
x = cuCrealf(z);
92+
y = cuCimagf(z);
93+
94+
get_float_word(hy, y);
95+
hy &= 0x7fffffff;
96+
97+
/* cexp(x + I 0) = exp(x) + I 0 */
98+
if (hy == 0) {
99+
return cuComplex{expf(x), y};
100+
}
101+
get_float_word(hx, x);
102+
/* cexp(0 + I y) = cos(y) + I sin(y) */
103+
if ((hx & 0x7fffffff) == 0) {
104+
return cuComplex{cosf(y), sinf(y)};
105+
}
106+
if (hy >= 0x7f800000) {
107+
if ((hx & 0x7fffffff) != 0x7f800000) {
108+
/* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
109+
return cuComplex{y - y, y - y};
110+
} else if (hx & 0x80000000) {
111+
/* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
112+
return cuComplex{0.0, 0.0};
113+
} else {
114+
/* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
115+
return cuComplex{x, y - y};
116+
}
117+
}
118+
119+
if (hx >= exp_ovfl && hx <= cexp_ovfl) {
120+
/*
121+
* x is between 88.7 and 192, so we must scale to avoid
122+
* overflow in expf(x).
123+
*/
124+
return ldexp_cexpf(z, 0);
125+
} else {
126+
/*
127+
* Cases covered here:
128+
* - x < exp_ovfl and exp(x) won't overflow (common case)
129+
* - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
130+
* - x = +-Inf (generated by exp())
131+
* - x = NaN (spurious inexact exception from y)
132+
*/
133+
exp_x = expf(x);
134+
return cuComplex{exp_x * cosf(y), exp_x * sinf(y)};
135+
}
136+
}
137+
138+
} // namespace mlx::core::cu::detail

mlx/backend/cuda/device/unary_ops.cuh

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#pragma once
44

5+
#include "mlx/backend/cuda/device/cexpf.cuh"
6+
#include "mlx/backend/cuda/device/cucomplex_math.cuh"
57
#include "mlx/backend/cuda/device/fp16_math.cuh"
68
#include "mlx/backend/cuda/device/utils.cuh"
79

@@ -150,8 +152,7 @@ struct Exp {
150152
template <typename T>
151153
__device__ T operator()(T x) {
152154
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
153-
auto m = exp(cuCrealf(x));
154-
return {m * cos(cuCimagf(x)), m * sinh(cuCimagf(x))};
155+
return detail::cexpf(x);
155156
} else {
156157
return exp(x);
157158
}
@@ -228,8 +229,25 @@ struct Log10 {
228229

229230
struct Log1p {
230231
template <typename T>
231-
__device__ T operator()(T x) {
232-
return log1p(x);
232+
__device__ T operator()(T z) {
233+
if constexpr (cuda::std::is_same_v<T, cuComplex>) {
234+
float x = cuCrealf(z);
235+
float y = cuCimagf(z);
236+
float zabs = cuCrealf(Abs{}(z));
237+
float theta = atan2f(y, x + 1);
238+
if (zabs < 0.5f) {
239+
float r = x * (2 + x) + y * y;
240+
if (r == 0) { // handle underflow
241+
return {x, theta};
242+
}
243+
return {0.5f * log1pf(r), theta};
244+
} else {
245+
float z0 = hypotf(x + 1, y);
246+
return {logf(z0), theta};
247+
}
248+
} else {
249+
return log1p(z);
250+
}
233251
}
234252
};
235253

@@ -387,19 +405,19 @@ struct Tanh {
387405
}
388406
};
389407

390-
__device__ cuComplex ArcCos::operator()(cuComplex x) {
408+
inline __device__ cuComplex ArcCos::operator()(cuComplex x) {
391409
auto i = cuComplex{0.0, 1.0};
392410
auto y = Log{}(x + i * Sqrt{}(1.0 - x * x));
393411
return {cuCimagf(y), -cuCrealf(y)};
394412
};
395413

396-
__device__ cuComplex ArcSin::operator()(cuComplex x) {
414+
inline __device__ cuComplex ArcSin::operator()(cuComplex x) {
397415
auto i = cuComplex{0.0f, 1.0f};
398416
auto y = Log{}(i * x + Sqrt{}(1.0f - x * x));
399417
return {cuCimagf(y), -cuCrealf(y)};
400418
};
401419

402-
__device__ cuComplex ArcTan::operator()(cuComplex x) {
420+
inline __device__ cuComplex ArcTan::operator()(cuComplex x) {
403421
auto i = cuComplex{0.0f, 1.0f};
404422
auto ix = i * x;
405423
return (1.0f / cuComplex{0.0f, 2.0f}) * Log{}((1.0f + ix) / (1.0f - ix));

mlx/backend/cuda/device/utils.cuh

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -359,21 +359,4 @@ struct LoopedElemToLoc<1, false, OffsetT> {
359359
}
360360
};
361361

362-
inline __device__ cuComplex log1p(cuComplex in) {
363-
float x = cuCrealf(in);
364-
float y = cuCimagf(in);
365-
float zabs = sqrt(x * x + y * y);
366-
float theta = atan2f(y, x + 1);
367-
if (zabs < 0.5f) {
368-
float r = x * (2 + x) + y * y;
369-
if (r == 0) { // handle underflow
370-
return {x, theta};
371-
}
372-
return {0.5f * log1pf(r), theta};
373-
} else {
374-
auto z0 = sqrt((x + 1) * (x + 1) + y * y);
375-
return {log(z0), theta};
376-
}
377-
}
378-
379362
} // namespace mlx::core::cu

mlx/backend/cuda/jit_module.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ constexpr const char* g_include_names[] = {
161161
INCLUDE_PREFIX "atomic_ops.cuh",
162162
INCLUDE_PREFIX "binary_ops.cuh",
163163
INCLUDE_PREFIX "cast_op.cuh",
164+
INCLUDE_PREFIX "cexpf.cuh",
164165
INCLUDE_PREFIX "config.h",
165166
INCLUDE_PREFIX "cucomplex_math.cuh",
166167
INCLUDE_PREFIX "fp16_math.cuh",
@@ -177,6 +178,7 @@ constexpr const char* g_headers[] = {
177178
jit_source_atomic_ops,
178179
jit_source_binary_ops,
179180
jit_source_cast_op,
181+
jit_source_cexpf,
180182
jit_source_config,
181183
jit_source_cucomplex_math,
182184
jit_source_fp16_math,

mlx/backend/cuda/primitives.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ NO_GPU(Load)
8282
NO_GPU_MULTI(LUF)
8383
NO_GPU_MULTI(QRF)
8484
NO_GPU(QuantizedMatmul)
85-
NO_GPU(Scan)
8685
NO_GPU(SegmentedMM)
8786
NO_GPU_MULTI(SVD)
8887
NO_GPU(Inverse)

mlx/backend/cuda/reduce/reduce_utils.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <numeric>
66

7+
#include "mlx/backend/common/utils.h"
78
#include "mlx/backend/cuda/device/utils.cuh"
89

910
#include <cooperative_groups.h>

0 commit comments

Comments
 (0)