Skip to content

Commit c35f4d0

Browse files
authored
start cuda circle config (#2256)
* rebase * fix metal kernel linking issue on cuda * start cuda circle config
1 parent 8590c09 commit c35f4d0

File tree

14 files changed

+101
-26
lines changed

14 files changed

+101
-26
lines changed

.circleci/config.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,29 @@ jobs:
212212
METAL_DEBUG_ERROR_MODE=0 \
213213
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
214214
215+
# CircleCI job: build MLX with the CUDA backend enabled and run the Python tests.
# - Runs on the `linux-cuda-12:default` machine image, `gpu.nvidia.small.gen2` class.
# - Step "Install Python package": installs BLAS/LAPACK/LAPACKE and OpenMPI packages
#   with apt, creates a virtualenv in ./env, then does an editable pip install with
#   CMAKE_ARGS set to -DMLX_BUILD_CUDA=ON and CMAKE_CUDA_COMPILER set to the nvcc
#   found on PATH; CMAKE_BUILD_PARALLEL_LEVEL is set to the output of `nproc`.
# - Step "Run Python tests": re-activates ./env and runs unittest discovery over
#   python/tests with LOW_MEMORY=1.
# - NOTE(review): the test step sets DEVICE=cpu, so the suite presumably exercises
#   only the CPU path even though the build enables CUDA -- confirm this is an
#   intentional placeholder while the CUDA backend is brought up.
cuda_build_and_test:
216+
machine:
217+
image: linux-cuda-12:default
218+
resource_class: gpu.nvidia.small.gen2
219+
steps:
220+
- checkout
221+
- run:
222+
name: Install Python package
223+
command: |
224+
sudo apt-get update
225+
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
226+
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
227+
python -m venv env
228+
source env/bin/activate
229+
CMAKE_BUILD_PARALLEL_LEVEL=`nproc` \
230+
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
231+
pip install -e ".[dev]"
232+
- run:
233+
name: Run Python tests
234+
command: |
235+
source env/bin/activate
236+
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
237+
215238
build_release:
216239
parameters:
217240
python_version:
@@ -348,6 +371,7 @@ workflows:
348371
parameters:
349372
macosx_deployment_target: ["13.5", "14.0"]
350373
- linux_build_and_test
374+
- cuda_build_and_test
351375
- build_documentation
352376

353377
build_pypi_release:
@@ -455,6 +479,8 @@ workflows:
455479
macosx_deployment_target: ["13.5", "14.0"]
456480
- linux_build_and_test:
457481
requires: [ hold ]
482+
- cuda_build_and_test:
483+
requires: [ hold ]
458484
nightly_build:
459485
when:
460486
and:

mlx/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ endif()
5555

5656
# Pick the CUDA backend sources. With MLX_BUILD_CUDA the real backend directory is
# added; otherwise only the no_cuda.cpp stub is compiled into the `mlx` target, so
# mlx::core::cu::is_available() has a definition in either configuration.
if(MLX_BUILD_CUDA)
5757
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda)
58+
else()
59+
target_sources(mlx
60+
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/backend/cuda/no_cuda.cpp)
5861
endif()
5962

6063
if(MLX_BUILD_METAL OR MLX_BUILD_CUDA)

mlx/backend/cuda/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ target_sources(
1212
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
1313
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
1414
${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
15+
${CMAKE_CURRENT_SOURCE_DIR}/cuda.cpp
1516
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
1617
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
1718
${CMAKE_CURRENT_SOURCE_DIR}/event.cu

mlx/backend/cuda/cuda.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include "mlx/backend/cuda/cuda.h"
4+
5+
namespace mlx::core::cu {
6+
7+
// CUDA-enabled build variant of is_available(): unconditionally returns true.
// This file is listed in backend/cuda/CMakeLists.txt, which is only added when
// MLX_BUILD_CUDA is set. No runtime device probing is performed here.
bool is_available() {
8+
return true;
9+
}
10+
11+
} // namespace mlx::core::cu

mlx/backend/cuda/cuda.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#pragma once
4+
5+
namespace mlx::core::cu {
6+
7+
/* Check if the CUDA backend is available.
 *
 * Returns true when the library was built with the CUDA backend (definition in
 * backend/cuda/cuda.cpp) and false otherwise (definition in
 * backend/cuda/no_cuda.cpp). Both definitions return a constant, so the answer
 * reflects the build configuration rather than a runtime device check. */
8+
bool is_available();
9+
10+
} // namespace mlx::core::cu

mlx/backend/cuda/no_cuda.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include "mlx/backend/cuda/cuda.h"
4+
5+
namespace mlx::core::cu {
6+
7+
// Stub variant of is_available() for builds without the CUDA backend:
// unconditionally returns false. mlx/CMakeLists.txt compiles this file only in
// the else() branch of if(MLX_BUILD_CUDA).
bool is_available() {
8+
return false;
9+
}
10+
11+
} // namespace mlx::core::cu

mlx/backend/metal/no_metal.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
#include <stdexcept>
44

55
#include "mlx/backend/metal/metal.h"
6+
#include "mlx/fast.h"
67

7-
namespace mlx::core::metal {
8+
namespace mlx::core {
9+
10+
namespace metal {
811

912
bool is_available() {
1013
return false;
@@ -19,4 +22,21 @@ device_info() {
1922
"[metal::device_info] Cannot get device info without metal backend");
2023
};
2124

22-
} // namespace mlx::core::metal
25+
} // namespace metal
26+
27+
namespace fast {
28+
29+
// Stub of fast::metal_kernel for builds without the Metal backend. Every argument
// is ignored and the call always throws std::runtime_error; nothing is returned.
// NOTE(review): the message says "No GPU back-end." but this stub now lives in the
// no-Metal file, which presumably is also compiled for CUDA builds -- confirm
// whether the wording should mention Metal specifically.
MetalKernelFunction metal_kernel(
30+
const std::string&,
31+
const std::vector<std::string>&,
32+
const std::vector<std::string>&,
33+
const std::string&,
34+
const std::string&,
35+
bool ensure_row_contiguous,
36+
bool atomic_outputs) {
37+
throw std::runtime_error("[metal_kernel] No GPU back-end.");
38+
}
39+
40+
} // namespace fast
41+
42+
} // namespace mlx::core

mlx/backend/no_gpu/primitives.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
#include "mlx/primitives.h"
44
#include "mlx/distributed/primitives.h"
5-
#include "mlx/fast.h"
65
#include "mlx/fast_primitives.h"
76

87
#define NO_GPU_MULTI(func) \
@@ -156,18 +155,6 @@ NO_GPU_USE_FALLBACK(RoPE)
156155
NO_GPU(ScaledDotProductAttention)
157156
NO_GPU_MULTI(AffineQuantize)
158157
NO_GPU_MULTI(CustomKernel)
159-
160-
MetalKernelFunction metal_kernel(
161-
const std::string&,
162-
const std::vector<std::string>&,
163-
const std::vector<std::string>&,
164-
const std::string&,
165-
const std::string&,
166-
bool ensure_row_contiguous,
167-
bool atomic_outputs) {
168-
throw std::runtime_error("[metal_kernel] No GPU back-end.");
169-
}
170-
171158
} // namespace fast
172159

173160
namespace distributed {

mlx/mlx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#pragma once
44

55
#include "mlx/array.h"
6+
#include "mlx/backend/cuda/cuda.h"
67
#include "mlx/backend/metal/metal.h"
78
#include "mlx/compile.h"
89
#include "mlx/device.h"

python/src/array.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717
#include "python/src/indexing.h"
1818
#include "python/src/utils.h"
1919

20-
#include "mlx/device.h"
21-
#include "mlx/ops.h"
22-
#include "mlx/transforms.h"
23-
#include "mlx/utils.h"
20+
#include "mlx/mlx.h"
2421

2522
namespace mx = mlx::core;
2623
namespace nb = nanobind;
@@ -461,9 +458,12 @@ void init_array(nb::module_& m) {
461458
.def(
462459
"__dlpack_device__",
463460
[](const mx::array& a) {
461+
// See
462+
// https://github.com/dmlc/dlpack/blob/5c210da409e7f1e51ddf445134a4376fdbd70d7d/include/dlpack/dlpack.h#L74
464463
if (mx::metal::is_available()) {
465-
// Metal device is available
466464
return nb::make_tuple(8, 0);
465+
} else if (mx::cu::is_available()) {
466+
return nb::make_tuple(13, 0);
467467
} else {
468468
// CPU device
469469
return nb::make_tuple(1, 0);

0 commit comments

Comments
 (0)