Skip to content

Commit 48fa6ae

Browse files
committed
Fix dtype_to_cuda_type
1 parent 4aca093 commit 48fa6ae

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

mlx/backend/cuda/utils.cpp

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,38 @@ void check_cuda_error(const char* name, cudaError_t err) {
2525
}
2626

2727
const char* dtype_to_cuda_type(const Dtype& dtype) {
28-
if (dtype == float16) {
29-
return "__half";
28+
switch (dtype) {
29+
case bool_:
30+
return "bool";
31+
case int8:
32+
return "int8_t";
33+
case int16:
34+
return "int16_t";
35+
case int32:
36+
return "int32_t";
37+
case int64:
38+
return "int64_t";
39+
case uint8:
40+
return "uint8_t";
41+
case uint16:
42+
return "uint16_t";
43+
case uint32:
44+
return "uint32_t";
45+
case uint64:
46+
return "uint64_t";
47+
case float16:
48+
return "__half";
49+
case bfloat16:
50+
return "__nv_bfloat16";
51+
case float32:
52+
return "float";
53+
case float64:
54+
return "double";
55+
case complex64:
56+
return "cuComplex";
57+
default:
58+
return "unknown";
3059
}
31-
if (dtype == bfloat16) {
32-
return "__nv_bfloat16";
33-
}
34-
if (dtype == complex64) {
35-
return "cuComplex";
36-
}
37-
return dtype_to_string(dtype);
3860
}
3961

4062
} // namespace mlx::core

mlx/dtype_utils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ const char* dtype_to_string(Dtype arg) {
3434
return "float64";
3535
case complex64:
3636
return "complex64";
37+
default:
38+
return "unknown";
3739
}
38-
return "unknown";
3940
}
4041

4142
} // namespace mlx::core

0 commit comments

Comments
 (0)