Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions benchmarks/python/comparative/bench_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time

import torch
import torch.cuda
import torch.mps


Expand Down Expand Up @@ -44,8 +45,10 @@ def bench(f, *args):


def sync_if_needed(x):
if x.device != torch.device("cpu"):
if x.device == torch.device("mps"):
torch.mps.synchronize()
elif x.device == torch.device("cuda"):
torch.cuda.synchronize()


@torch.no_grad()
Expand Down Expand Up @@ -99,6 +102,14 @@ def reduction(op, axis, x):
sync_if_needed(x)


@torch.no_grad()
def sum_and_add(axis, x, y):
z = x.sum(axis=axis, keepdims=True)
for i in range(50):
z = (z + y).sum(axis=axis, keepdims=True)
sync_if_needed(x)


@torch.no_grad()
def softmax(axis, x):
ys = []
Expand Down Expand Up @@ -340,7 +351,11 @@ def selu(x):
args.axis.pop(0)

torch.set_num_threads(1)
device = "cpu" if args.cpu else "mps"
device = "mps"
if torch.cuda.is_available():
device = "cuda"
if args.cpu:
device = "cpu"

types = args.dtype
if not types:
Expand Down Expand Up @@ -460,5 +475,8 @@ def selu(x):
elif args.benchmark == "selu":
print(bench(selu, x))

elif args.benchmark == "sum_and_add":
print(bench(sum_and_add, axis, *xs))

else:
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
15 changes: 11 additions & 4 deletions mlx/backend/common/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
namespace mlx::core {

std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
Shape shape,
Strides strides,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();

for (int i = axes.size() - 1; i >= 0; i--) {
int a = axes[i];
shape.erase(shape.begin() + a);
Expand All @@ -19,6 +17,15 @@ std::pair<Shape, Strides> shapes_without_reduction_axes(
return std::make_pair(shape, strides);
}

std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes) {
auto shape = x.shape();
auto strides = x.strides();
return shapes_without_reduction_axes(
std::move(shape), std::move(strides), axes);
}

ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
// The data is all there and we are reducing over everything
if (x.size() == x.data_size() && axes.size() == x.ndim() &&
Expand Down
4 changes: 4 additions & 0 deletions mlx/backend/common/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,9 @@ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
const array& x,
const std::vector<int>& axes);
std::pair<Shape, Strides> shapes_without_reduction_axes(
Shape shape,
Strides strides,
const std::vector<int>& axes);

} // namespace mlx::core
3 changes: 2 additions & 1 deletion mlx/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
${CMAKE_CURRENT_SOURCE_DIR}/random.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/reduce/segmented_reduce.cu
${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cuda/binary_two.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ void binary_op_gpu_inplace(
if (ndim <= 3) {
MLX_SWITCH_1_2_3(ndim, NDIM, {
auto kernel =
&cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
cu::binary_g_nd<Op, InType, OutType, IdxT, NDIM>;
auto [num_blocks, block_dims] =
get_launch_args(kernel, out_a, large);
kernel<<<num_blocks, block_dims, 0, stream>>>(
Expand Down
38 changes: 16 additions & 22 deletions mlx/backend/cuda/reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,11 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(!axes_.empty());
assert(out.size() != in.size());

out.set_data(allocator::malloc(out.nbytes()));

auto& s = stream();
auto& encoder = cu::get_command_encoder(s);
encoder.set_input_array(in);
encoder.set_output_array(out);

// Fill out with init value.
if (in.size() == 0) {
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(in.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type_, OP, {
using InType = cuda_type_t<CTYPE>;
using OutType = cu::ReduceResult<OP, InType>::type;
thrust::fill_n(
cu::thrust_policy(stream),
thrust::device_pointer_cast(out.data<OutType>()),
out.data_size(),
cu::ReduceInit<OP, InType>::value());
});
});
});
init_reduce(encoder, in, out, reduce_type_);
return;
}

Expand All @@ -51,17 +34,28 @@ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {

// If it is a general reduce then copy the input to a contiguous array and
// recompute the plan.
if (plan.type == GeneralReduce) {
//
// TODO: Instead of copying we can use elem-to-loc to deal with broadcasting
// like we do in Metal. When it comes to broadcasted reduction axes
// some can be ignored eg for min/max.
bool broadcasted = false;
for (int i = 0, j = 0; i < in.ndim() && !broadcasted; i++) {
if (j < axes_.size() && axes_[j] == i) {
j++;
} else {
broadcasted = in.strides(i) == 0;
}
}
if (plan.type == GeneralReduce || broadcasted || !in.flags().contiguous) {
array in_copy(in.shape(), in.dtype(), nullptr, {});
copy_gpu(in, in_copy, CopyType::General, s);
encoder.add_temporary(in_copy);
in = in_copy;
plan = get_reduction_plan(in, axes_);
}

if ((plan.type == ContiguousAllReduce) ||
(plan.type == ContiguousReduce && plan.shape.size() == 1)) {
segmented_reduce(encoder, in, out, reduce_type_, axes_, plan);
if (plan.type == ContiguousAllReduce) {
all_reduce(encoder, in, out, reduce_type_);
return;
}

Expand Down
141 changes: 141 additions & 0 deletions mlx/backend/cuda/reduce/all_reduce.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/device.h"
#include "mlx/backend/cuda/reduce/reduce.cuh"

#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/block/block_load.cuh>

namespace mlx::core {

namespace cu {

namespace cg = cooperative_groups;

template <typename T, typename U, typename ReduceOp, int N = 4>
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) {
// TODO: Process multiple "rows" in each thread
constexpr int M = 1;

auto grid = cg::this_grid();
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);

const U init = cu::ReduceInit<ReduceOp, T>::value();
ReduceOp op;

T vals[N];
U accs[M];
accs[0] = init;

size_t start = grid.block_rank() * block_step;
size_t end = start + block_step;
size_t check = min(end, size);

size_t i = start;
for (; i + block.size() * N <= check; i += block.size() * N) {
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals);
for (int j = 0; j < N; j++) {
accs[0] = op(accs[0], __cast<U, T>(vals[j]));
}
}

if (i < check) {
cub::LoadDirectBlocked(
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init));
for (int i = 0; i < N; i++) {
accs[0] = op(accs[0], __cast<U, T>(vals[i]));
}
}

__shared__ U shared_accumulators[32];
block_reduce(block, warp, accs, shared_accumulators, op, init);

if (block.thread_rank() == 0) {
out[grid.block_rank()] = accs[0];
}
}

} // namespace cu

void all_reduce(
cu::CommandEncoder& encoder,
const array& in,
array& out,
Reduce::ReduceType reduce_type) {
constexpr int N_READS = 8;

out.set_data(allocator::malloc(out.nbytes()));

auto get_args = [](size_t size, int N) {
int threads = std::min(512UL, (size + N - 1) / N);
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int reductions_per_step = threads * N;
size_t steps_needed =
(size + reductions_per_step - 1) / reductions_per_step;

int blocks;
if (steps_needed < 32) {
blocks = 1;
} else if (steps_needed < 128) {
blocks = 32;
} else if (steps_needed < 512) {
blocks = 128;
} else if (steps_needed < 1024) {
blocks = 512;
} else {
blocks = 1024;
}

size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
size_t block_step = steps_per_block * reductions_per_step;

return std::make_tuple(blocks, threads, block_step);
};

int blocks, threads;
size_t block_step;
array x = in;

// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int blocks, threads;
size_t block_step;
array x = in;
// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
array x = in;
// Large array so allocate an intermediate and accumulate there
auto [blocks, threads, block_step] = get_args(x.size(), N_READS);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is that capturing structured bindings is a C++26 feature. I also changed the one in col_reduce to remove the warnings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is that capturing structured bindings is a C++26 feature

Ahh 😞 I forgot about that.

if (blocks > 1) {
array intermediate({blocks}, out.dtype(), nullptr, {});
intermediate.set_data(allocator::malloc(intermediate.nbytes()));
encoder.add_temporary(intermediate);
encoder.set_input_array(x);
encoder.set_output_array(intermediate);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), intermediate.data<U>(), block_step, x.size());
});
});
});

// Set the input for the next step and recalculate the blocks
x = intermediate;
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
}

encoder.set_input_array(x);
encoder.set_output_array(out);
encoder.launch_kernel([&](cudaStream_t stream) {
MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
using T = cuda_type_t<CTYPE>;
using U = cu::ReduceResult<OP, T>::type;
auto kernel = cu::all_reduce<T, U, OP, N_READS>;
kernel<<<blocks, threads, 0, stream>>>(
x.data<T>(), out.data<U>(), block_step, x.size());
});
});
});
}

} // namespace mlx::core
Loading