Skip to content

Commit 8e7e0a1

Browse files
authored
Merge branch 'ml-explore:main' into junpeiz/add-export-tests
2 parents e64f29c + b405591 commit 8e7e0a1

36 files changed

+795
-592
lines changed

.circleci/config.yml

Lines changed: 58 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -81,30 +81,32 @@ jobs:
8181
export DEBIAN_FRONTEND=noninteractive
8282
export NEEDRESTART_MODE=a
8383
sudo apt-get update
84-
sudo apt-get upgrade -y
85-
pip install --upgrade cmake
8684
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
8785
sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
86+
curl -LsSf https://astral.sh/uv/install.sh | sh
8887
- run:
8988
name: Install Python package
9089
command: |
91-
pip install -e ".[dev]"
90+
uv venv
91+
uv pip install cmake
92+
uv pip install -e ".[dev]" -v
9293
- run:
9394
name: Generate package stubs
9495
command: |
95-
echo "stubs"
96-
pip install typing_extensions
97-
python setup.py generate_stubs
96+
uv pip install typing_extensions
97+
uv run --no-project setup.py generate_stubs
9898
- run:
9999
name: Run Python tests
100100
command: |
101+
source .venv/bin/activate
101102
python -m unittest discover python/tests -v
102103
mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
103104
mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
104105
if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
105106
- run:
106107
name: Build CPP only
107108
command: |
109+
source .venv/bin/activate
108110
mkdir -p build && cd build
109111
cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
110112
make -j `nproc`
@@ -130,33 +132,30 @@ jobs:
130132
- run:
131133
name: Install dependencies
132134
command: |
133-
brew install python@3.9
134-
brew install openmpi
135-
python3.9 -m venv env
136-
source env/bin/activate
137-
pip install --upgrade pip
138-
pip install --upgrade cmake
139-
pip install nanobind==2.4.0
140-
pip install numpy
141-
pip install torch
142-
pip install tensorflow
143-
pip install unittest-xml-reporting
135+
HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
136+
brew install openmpi uv
144137
- run:
145138
name: Install Python package
146139
command: |
147-
source env/bin/activate
140+
uv venv --python 3.9
141+
uv pip install \
142+
nanobind==2.4.0 \
143+
cmake \
144+
numpy \
145+
torch \
146+
tensorflow \
147+
unittest-xml-reporting
148148
DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
149-
pip install -e . -v
149+
uv pip install -e . -v
150150
- run:
151151
name: Generate package stubs
152152
command: |
153-
source env/bin/activate
154-
pip install typing_extensions
155-
python setup.py generate_stubs
153+
uv pip install typing_extensions
154+
uv run --no-project setup.py generate_stubs
156155
- run:
157156
name: Run Python tests
158157
command: |
159-
source env/bin/activate
158+
source .venv/bin/activate
160159
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
161160
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
162161
mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
@@ -165,16 +164,15 @@ jobs:
165164
- run:
166165
name: Build example extension
167166
command: |
168-
source env/bin/activate
169167
cd examples/extensions
170-
pip install -r requirements.txt
171-
python setup.py build_ext -j8
168+
uv pip install -r requirements.txt
169+
uv run --no-project setup.py build_ext -j8
172170
- store_test_results:
173171
path: test-results
174172
- run:
175173
name: Build CPP only
176174
command: |
177-
source env/bin/activate
175+
source .venv/bin/activate
178176
mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
179177
- run:
180178
name: Run CPP tests
@@ -183,7 +181,7 @@ jobs:
183181
- run:
184182
name: Build small binary
185183
command: |
186-
source env/bin/activate
184+
source .venv/bin/activate
187185
cd build/
188186
cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
189187
-DBUILD_SHARED_LIBS=ON \
@@ -195,12 +193,13 @@ jobs:
195193
- run:
196194
name: Run Python tests with JIT
197195
command: |
198-
source env/bin/activate
199196
CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
200-
pip install -e . -v
197+
uv pip install -e .
201198
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
202199
METAL_DEBUG_ERROR_MODE=0 \
203-
python -m xmlrunner discover -v python/tests -o test-results/gpu_jit
200+
uv run --no-project python -m xmlrunner discover \
201+
-v python/tests \
202+
-o test-results/gpu_jit
204203
205204
cuda_build_and_test:
206205
parameters:
@@ -212,22 +211,42 @@ jobs:
212211
resource_class: gpu.nvidia.small.gen2
213212
steps:
214213
- checkout
214+
- restore_cache:
215+
keys:
216+
- cuda-<< parameters.image_date >>-{{ arch }}-
215217
- run:
216-
name: Install Python package
218+
name: Install dependencies
217219
command: |
218220
sudo apt-get update
219221
sudo apt-get install libcudnn9-dev-cuda-12
220222
sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
221-
python3 -m venv env
222-
source env/bin/activate
223+
curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
224+
sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
225+
rm -rf ccache-4.11.3-linux-x86_64
226+
curl -LsSf https://astral.sh/uv/install.sh | sh
227+
- run:
228+
name: Install Python package
229+
command: |
230+
uv venv
223231
CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
224-
pip install -e ".[dev]"
232+
uv pip install -e ".[dev]" -v
225233
- run:
226234
name: Run Python tests
227235
command: |
228-
source env/bin/activate
236+
source .venv/bin/activate
229237
LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
230238
LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
239+
- run:
240+
name: CCache report
241+
command: |
242+
ccache --show-stats
243+
ccache --zero-stats
244+
ccache --max-size 400MB
245+
ccache --cleanup
246+
- save_cache:
247+
key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
248+
paths:
249+
- /home/circleci/.cache/ccache
231250

232251
build_release:
233252
parameters:
@@ -323,14 +342,10 @@ jobs:
323342
export DEBIAN_FRONTEND=noninteractive
324343
export NEEDRESTART_MODE=a
325344
sudo apt-get update
326-
sudo apt-get upgrade -y
327345
TZ=Etc/UTC sudo apt-get -y install tzdata
328-
sudo apt-get install -y apt-utils
329-
sudo apt-get install -y software-properties-common
330346
sudo add-apt-repository -y ppa:deadsnakes/ppa
331347
sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
332348
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
333-
sudo apt-get install -y build-essential git
334349
$PYTHON -m venv env
335350
source env/bin/activate
336351
pip install --upgrade pip
@@ -555,6 +570,9 @@ workflows:
555570
requires: [ hold ]
556571
- cuda_build_and_test:
557572
requires: [ hold ]
573+
matrix:
574+
parameters:
575+
image_date: ["2023.11.1", "2025.05.1"]
558576
nightly_build:
559577
when:
560578
and:

CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
4141
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
4242
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
4343
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
44+
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
4445
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
4546

4647
# --------------------- Processor tests -------------------------
@@ -68,6 +69,15 @@ else()
6869
set(MLX_BUILD_METAL OFF)
6970
endif()
7071

72+
if(MLX_USE_CCACHE)
73+
find_program(CCACHE_PROGRAM ccache)
74+
if(CCACHE_PROGRAM)
75+
set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
76+
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
77+
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
78+
endif()
79+
endif()
80+
7181
# ----------------------------- Lib -----------------------------
7282

7383
include(FetchContent)

mlx/backend/cuda/CMakeLists.txt

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
target_sources(
77
mlx
88
PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
9+
${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
910
${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
1011
${CMAKE_CURRENT_SOURCE_DIR}/binary.cu
1112
${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
@@ -29,7 +30,7 @@ target_sources(
2930
${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
3031
${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
3132
${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
32-
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
33+
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
3334
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
3435
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
3536
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
@@ -45,7 +46,8 @@ target_sources(
4546
${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
4647
${CMAKE_CURRENT_SOURCE_DIR}/unary.cu
4748
${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
48-
${CMAKE_CURRENT_SOURCE_DIR}/quantized.cu
49+
${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
50+
${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
4951
${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
5052

5153
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
@@ -105,11 +107,11 @@ if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
105107
mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
106108
endif()
107109

108-
# Compute capability 7 is required for synchronization between CPU/GPU with
109-
# managed memory. TODO: Add more architectures for potential performance gain.
110-
set(MLX_CUDA_ARCHITECTURES
111-
"70;80"
112-
CACHE STRING "CUDA architectures")
110+
# Compute capability >= 7.0 is required for synchronization between CPU/GPU with
111+
# managed memory.
112+
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
113+
set(MLX_CUDA_ARCHITECTURES "native")
114+
endif()
113115
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
114116
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
115117
"${MLX_CUDA_ARCHITECTURES}")

mlx/backend/cuda/arange.cu

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,55 @@
1+
// Copyright © 2025 Apple Inc.
2+
3+
#include "mlx/backend/cuda/device.h"
4+
#include "mlx/backend/cuda/device/fp16_math.cuh"
5+
#include "mlx/backend/cuda/kernel_utils.cuh"
6+
#include "mlx/dtype_utils.h"
7+
#include "mlx/primitives.h"
8+
9+
#include <nvtx3/nvtx3.hpp>
10+
#include <thrust/device_ptr.h>
11+
#include <thrust/transform.h>
12+
13+
namespace mlx::core {
14+
15+
namespace cu {
16+
17+
template <typename T>
18+
struct Arange {
19+
const T start;
20+
const T step;
21+
22+
__device__ T operator()(uint32_t i) const {
23+
return start + i * step;
24+
}
25+
};
26+
27+
} // namespace cu
28+
29+
void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
30+
nvtx3::scoped_range r("Arange::eval_gpu");
31+
if (out.size() == 0) {
32+
return;
33+
}
34+
out.set_data(allocator::malloc(out.nbytes()));
35+
36+
auto& encoder = cu::get_command_encoder(stream());
37+
encoder.set_output_array(out);
38+
39+
auto capture = encoder.capture_context();
40+
dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
41+
using CTYPE = MLX_GET_TYPE(type_tag);
42+
using OutType = cuda_type_t<CTYPE>;
43+
CTYPE step =
44+
static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_);
45+
thrust::transform(
46+
cu::thrust_policy(encoder.stream()),
47+
thrust::counting_iterator<uint32_t>(0),
48+
thrust::counting_iterator<uint32_t>(out.data_size()),
49+
thrust::device_pointer_cast(out.data<OutType>()),
50+
cu::Arange<OutType>{
51+
static_cast<OutType>(start_), static_cast<OutType>(step)});
52+
});
53+
}
54+
55+
} // namespace mlx::core

mlx/backend/cuda/arg_reduce.cu

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,11 @@ struct ArgMin {
4444
}
4545

4646
template <int N>
47-
__device__ IndexValPair<T>
48-
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
47+
__device__ IndexValPair<T> reduce_many(
48+
IndexValPair<T> best,
49+
const AlignedVector<T, N>& vals,
50+
uint32_t offset) {
51+
#pragma unroll
4952
for (int i = 0; i < N; i++) {
5053
if (vals[i] < best.val) {
5154
best.val = vals[i];
@@ -74,8 +77,11 @@ struct ArgMax {
7477
}
7578

7679
template <int N>
77-
__device__ IndexValPair<T>
78-
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
80+
__device__ IndexValPair<T> reduce_many(
81+
IndexValPair<T> best,
82+
const AlignedVector<T, N>& vals,
83+
uint32_t offset) {
84+
#pragma unroll
7985
for (int i = 0; i < N; i++) {
8086
if (vals[i] > best.val) {
8187
best.val = vals[i];
@@ -106,16 +112,15 @@ __global__ void arg_reduce_general(
106112

107113
int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
108114
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
115+
in += in_idx;
109116

110117
Op op;
111118
T init = op.init();
112119
IndexValPair<T> best{0, init};
113120

114121
for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
115-
T vals[N_READS];
116122
auto tid = r * BLOCK_DIM + block.thread_index().x;
117-
cub::LoadDirectBlocked(
118-
tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
123+
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
119124
best = op.reduce_many(best, vals, tid * N_READS);
120125
}
121126

0 commit comments

Comments
 (0)