Skip to content

Commit e9a2190

Browse files
committed
Use cexpf in Metal
1 parent 3564913 commit e9a2190

File tree

2 files changed

+136
-2
lines changed

2 files changed

+136
-2
lines changed

mlx/backend/metal/kernels/cexpf.h

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Copyright © 2025 Apple Inc.
2+
// Copyright © 2008-2013 NVIDIA Corporation
3+
// Copyright © 2013 Filipe RNC Maia
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
//
17+
// Forked from
18+
// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h
19+
20+
// TODO: We should use thrust::exp but the thrust header in old CUDA versions
21+
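// cannot be used in JIT. NOTE(review): this TODO refers to thrust/CUDA and
// looks inherited from a CUDA variant of this file; it does not apply to
// Metal -- confirm and remove.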
22+
23+
#pragma once
24+
25+
#include <metal_math>
26+
27+
// Overlay of a 32-bit float and its raw bit pattern. The float-word helpers
// in this header write one member and read the other to reinterpret bits.
union ieee_float_shape_type {
  float value; // the number viewed as a float
  uint32_t word; // the same 32 bits viewed as an unsigned integer
};
31+
32+
// Stores the raw 32-bit pattern of the float `d` into `i` (unsigned overload).
inline void get_float_word(thread uint32_t& i, float d) {
  ieee_float_shape_type bits;
  bits.value = d;
  i = bits.word;
}
37+
38+
// Stores the raw 32-bit pattern of the float `d` into `i` (signed overload).
// The unsigned word is converted to int32_t on assignment.
inline void get_float_word(thread int32_t& i, float d) {
  ieee_float_shape_type bits;
  bits.value = d;
  i = static_cast<int32_t>(bits.word);
}
43+
44+
// Writes into `d` the float whose raw 32-bit pattern is `i`.
inline void set_float_word(thread float& d, uint32_t i) {
  ieee_float_shape_type bits;
  bits.word = i;
  d = bits.value;
}
49+
50+
// Computes exp(x) as a (scaled value, exponent) pair so that callers can
// handle arguments for which exp(x) alone would overflow a float.
//
// x:    argument of the exponential.
// expt: out-parameter; receives an exponent such that
//       exp(x) == return_value * 2^(*expt).
//
// Returns the significand of exp(x - k*ln2) with its biased exponent field
// forced to 0x7f + 127 (i.e. scaled by 2^127).
inline float frexp_expf(float x, thread int* expt) {
  // Reduction constant and its product with ln(2): 235 * ln(2) ~= 162.8896.
  const uint32_t k = 235;
  const float kln2 = 162.88958740F;

  // exp(x - k*ln2) == exp(x) * 2^-k. Use the precise variant explicitly, as
  // the other exponential kernels do, so accuracy does not depend on the
  // fast-math setting.
  float exp_x = metal::precise::exp(x - kln2);

  uint32_t hx;
  get_float_word(hx, exp_x);

  // Biased exponent of the reduced result, rebased to the forced exponent
  // (0x7f + 127) and corrected by k to undo the argument reduction.
  *expt = (hx >> 23) - (0x7f + 127) + k;

  // Keep the 23 mantissa bits and overwrite the exponent field.
  set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23));
  return exp_x;
}
63+
64+
// Computes exp(z) * 2^expt for complex z = x + i*y without overflowing in
// the intermediate exp(x).
//
// z:    complex argument (real part x, imaginary part y).
// expt: additional power-of-two scale applied to the result.
//
// Returns {cos(y) * exp(x) * 2^expt, sin(y) * exp(x) * 2^expt}.
inline complex64_t ldexp_cexpf(complex64_t z, int expt) {
  const float x = z.real;
  const float y = z.imag;

  // exp(x) == exp_x * 2^ex_expt; fold that exponent into the requested one.
  int ex_expt;
  const float exp_x = frexp_expf(x, &ex_expt);
  expt += ex_expt;

  // Split the total power of two into two factors, each built directly from
  // its biased exponent field, so the scaling is applied in two steps.
  float scale1, scale2;
  int half_expt = expt / 2;
  set_float_word(scale1, (0x7f + half_expt) << 23);
  half_expt = expt - half_expt;
  set_float_word(scale2, (0x7f + half_expt) << 23);

  // Precise trig variants are requested explicitly, matching the rest of the
  // unary kernels, so the result does not depend on the fast-math setting.
  return complex64_t{
      metal::precise::cos(y) * exp_x * scale1 * scale2,
      metal::precise::sin(y) * exp_x * scale1 * scale2};
}
82+
83+
// Complex exponential exp(z) for z = x + i*y, with explicit handling of
// zero, infinite and NaN components and of real parts for which exp(x)
// overflows a float while the final product may still be representable.
//
// z: complex argument.
//
// Returns exp(x) * (cos(y) + i*sin(y)), except for the special cases
// commented inline below.
inline complex64_t cexpf(const thread complex64_t& z) {
  // Bit patterns bounding the range of x (about 88.7 to 192) in which exp(x)
  // must be computed in scaled form.
  const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074;

  const float x = z.real;
  const float y = z.imag;

  // |y| as raw bits (sign bit cleared).
  uint32_t hy;
  get_float_word(hy, y);
  hy &= 0x7fffffff;

  /* cexp(x + I 0) = exp(x) + I 0 (the sign of the zero in y is preserved) */
  if (hy == 0) {
    return complex64_t{metal::precise::exp(x), y};
  }

  uint32_t hx;
  get_float_word(hx, x);

  /* cexp(0 + I y) = cos(y) + I sin(y) */
  if ((hx & 0x7fffffff) == 0) {
    return complex64_t{metal::precise::cos(y), metal::precise::sin(y)};
  }

  // y is Inf or NaN; y - y yields NaN in both cases.
  if (hy >= 0x7f800000) {
    if ((hx & 0x7fffffff) != 0x7f800000) {
      /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */
      return complex64_t{y - y, y - y};
    } else if (hx & 0x80000000) {
      /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */
      return complex64_t{0.0f, 0.0f};
    } else {
      /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */
      return complex64_t{x, y - y};
    }
  }

  // hx still carries the sign bit, so every negative x compares above
  // cexp_ovfl and takes the plain path.
  if (hx >= exp_ovfl && hx <= cexp_ovfl) {
    /*
     * x is between 88.7 and 192, so we must scale to avoid
     * overflow in expf(x).
     */
    return ldexp_cexpf(z, 0);
  }

  /*
   * Cases covered here:
   *  - x < exp_ovfl and exp(x) won't overflow (common case)
   *  - x > cexp_ovfl, so exp(x) * s overflows for all s > 0
   *  - x = +-Inf (generated by exp())
   *  - x = NaN (spurious inexact exception from y)
   *
   * Precise variants are requested explicitly, as elsewhere in the unary
   * kernels, so accuracy does not depend on the fast-math setting.
   */
  const float exp_x = metal::precise::exp(x);
  return complex64_t{
      exp_x * metal::precise::cos(y), exp_x * metal::precise::sin(y)};
}

mlx/backend/metal/kernels/unary_ops.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <metal_integer>
66
#include <metal_math>
77

8+
#include "mlx/backend/metal/kernels/cexpf.h"
89
#include "mlx/backend/metal/kernels/erf.h"
910
#include "mlx/backend/metal/kernels/expm1f.h"
1011

@@ -178,8 +179,7 @@ struct Exp {
178179
return metal::precise::exp(x);
179180
};
180181
complex64_t operator()(complex64_t x) {
181-
auto m = metal::precise::exp(x.real);
182-
return {m * metal::precise::cos(x.imag), m * metal::precise::sin(x.imag)};
182+
return cexpf(x);
183183
}
184184
};
185185

0 commit comments

Comments
 (0)