14 files changed: +101 -26 lines changed

.circleci/config.yml
@@ -212,6 +212,29 @@ jobs:
               METAL_DEBUG_ERROR_MODE=0 \
               python -m xmlrunner discover -v python/tests -o test-results/gpu_jit

+  cuda_build_and_test:
+    machine:
+      image: linux-cuda-12:default
+    resource_class: gpu.nvidia.small.gen2
+    steps:
+      - checkout
+      - run:
+          name: Install Python package
+          command: |
+            sudo apt-get update
+            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
+            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
+            python -m venv env
+            source env/bin/activate
+            CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
+            CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
+            pip install -e ".[dev]"
+      - run:
+          name: Run Python tests
+          command: |
+            source env/bin/activate
+            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
+
   build_release:
     parameters:
       python_version:
@@ -348,6 +371,7 @@ workflows:
           parameters:
             macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test
+      - cuda_build_and_test
       - build_documentation

   build_pypi_release:
@@ -455,6 +479,8 @@ workflows:
             macosx_deployment_target: ["13.5", "14.0"]
       - linux_build_and_test:
           requires: [ hold ]
+      - cuda_build_and_test:
+          requires: [ hold ]
   nightly_build:
     when:
       and:

mlx/CMakeLists.txt
@@ -55,6 +55,9 @@ endif()

 if(MLX_BUILD_CUDA)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
+else()
+  target_sources(mlx
+                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
 endif()

 if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)

mlx/backend/cuda/CMakeLists.txt
@@ -12,6 +12,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu

mlx/backend/cuda/cuda.cpp (new file)
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cuda.h"
+
+namespace mlx::core::cu {
+
+bool is_available() {
+  return true;
+}
+
+} // namespace mlx::core::cu

mlx/backend/cuda/cuda.h (new file)
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+namespace mlx::core::cu {
+
+/* Check if the CUDA backend is available. */
+bool is_available();
+
+} // namespace mlx::core::cu

mlx/backend/cuda/no_cuda.cpp (new file)
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cuda.h"
+
+namespace mlx::core::cu {
+
+bool is_available() {
+  return false;
+}
+
+} // namespace mlx::core::cu
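
With this pair, exactly one definition of cu::is_available() is linked into libmlx: cuda.cpp when MLX_BUILD_CUDA=ON, no_cuda.cpp otherwise, so callers can probe the backend with no preprocessor guards. A minimal sketch of a hypothetical caller (illustrative, not part of this diff):

// Hypothetical caller, for illustration only. The branch taken is fixed
// at link time by which of cuda.cpp / no_cuda.cpp was compiled in.
#include <cstdio>

#include "mlx/backend/cuda/cuda.h"

int main() {
  if (mlx::core::cu::is_available()) {
    std::printf("built with MLX_BUILD_CUDA=ON\n");
  } else {
    std::printf("CUDA backend not compiled in\n");
  }
  return 0;
}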

mlx/backend/no_metal/metal.cpp
@@ -3,8 +3,11 @@
 #include <stdexcept>

 #include "mlx/backend/metal/metal.h"
+#include "mlx/fast.h"

-namespace mlx::core::metal {
+namespace mlx::core {
+
+namespace metal {

 bool is_available() {
   return false;
@@ -19,4 +22,21 @@ device_info() {
       "[metal::device_info] Cannot get device info without metal backend");
 };

-} // namespace mlx::core::metal
+} // namespace metal
+
+namespace fast {
+
+MetalKernelFunction metal_kernel(
+    const std::string&,
+    const std::vector<std::string>&,
+    const std::vector<std::string>&,
+    const std::string&,
+    const std::string&,
+    bool ensure_row_contiguous,
+    bool atomic_outputs) {
+  throw std::runtime_error("[metal_kernel] No GPU back-end.");
+}
+
+} // namespace fast
+
+} // namespace mlx::core
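
Moving the fast::metal_kernel stub into this no-Metal translation unit keeps it defined in any build without Metal, including CUDA builds where the no-GPU primitive stubs are not compiled. A sketch of what a caller sees in such a build (hypothetical example, assuming the declaration in mlx/fast.h):

// Illustrative only: in a build without Metal, metal_kernel still links,
// but reports the missing back-end at run time instead of at link time.
#include <iostream>
#include <stdexcept>

#include "mlx/fast.h"

int main() {
  try {
    auto kernel = mlx::core::fast::metal_kernel(
        "demo",             // kernel name
        {"x"},              // input names
        {"out"},            // output names
        "/* msl source */", // Metal source, unused by the stub
        "",                 // header
        true,               // ensure_row_contiguous
        false);             // atomic_outputs
  } catch (const std::runtime_error& e) {
    std::cerr << e.what() << "\n"; // "[metal_kernel] No GPU back-end."
  }
  return 0;
}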

mlx/backend/no_gpu/primitives.cpp
@@ -2,7 +2,6 @@

 #include "mlx/primitives.h"
 #include "mlx/distributed/primitives.h"
-#include "mlx/fast.h"
 #include "mlx/fast_primitives.h"

 #define NO_GPU_MULTI(func) \
@@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
 NO_GPU(ScaledDotProductAttention)
 NO_GPU_MULTI(AffineQuantize)
 NO_GPU_MULTI(CustomKernel)
-
-MetalKernelFunction metal_kernel(
-    const std::string&,
-    const std::vector<std::string>&,
-    const std::vector<std::string>&,
-    const std::string&,
-    const std::string&,
-    bool ensure_row_contiguous,
-    bool atomic_outputs) {
-  throw std::runtime_error("[metal_kernel] No GPU back-end.");
-}
-
 } // namespace fast

 namespace distributed {

mlx/mlx.h
@@ -3,6 +3,7 @@
 #pragma once

 #include "mlx/array.h"
+#include "mlx/backend/cuda/cuda.h"
 #include "mlx/backend/metal/metal.h"
 #include "mlx/compile.h"
 #include "mlx/device.h"

python/src/array.cpp
@@ -17,10 +17,7 @@
 #include "python/src/indexing.h"
 #include "python/src/utils.h"

-#include "mlx/device.h"
-#include "mlx/ops.h"
-#include "mlx/transforms.h"
-#include "mlx/utils.h"
+#include "mlx/mlx.h"

 namespace mx = mlx::core;
 namespace nb = nanobind;
@@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
       .def(
           "__dlpack_device__",
           [](const mx::array& a) {
+            // See
+            // https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
             if (mx::metal::is_available()) {
-              // Metal device is available
               return nb::make_tuple(8, 0);
+            } else if (mx::cu::is_available()) {
+              return nb::make_tuple(13, 0);
             } else {
               // CPU device
               return nb::make_tuple(1, 0);
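
The integers come from DLPack's DLDeviceType enum (linked in the new comment); __dlpack_device__ returns the pair (device_type, device_id). A reference sketch of the values used here:

// Illustrative subset of dlpack.h's DLDeviceType enum values.
enum DLDeviceTypeUsed {
  kDLCPU = 1,          // plain host memory
  kDLMetal = 8,        // Apple Metal
  kDLCUDAManaged = 13, // CUDA managed/unified memory
};

Returning 13 (kDLCUDAManaged) rather than 2 (kDLCUDA) suggests the CUDA backend allocates unified memory addressable from both host and device.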