Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions mlx/backend/cuda/arg_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ struct ArgMin {
}

template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
for (int i = 0; i < N; i++) {
if (vals[i] < best.val) {
best.val = vals[i];
Expand Down Expand Up @@ -74,8 +77,11 @@ struct ArgMax {
}

template <int N>
__device__ IndexValPair<T>
reduce_many(IndexValPair<T> best, T (&vals)[N], uint32_t offset) {
__device__ IndexValPair<T> reduce_many(
IndexValPair<T> best,
const AlignedVector<T, N>& vals,
uint32_t offset) {
#pragma unroll
for (int i = 0; i < N; i++) {
if (vals[i] > best.val) {
best.val = vals[i];
Expand Down Expand Up @@ -106,16 +112,15 @@ __global__ void arg_reduce_general(

int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
in += in_idx;

Op op;
T init = op.init();
IndexValPair<T> best{0, init};

for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
T vals[N_READS];
auto tid = r * BLOCK_DIM + block.thread_index().x;
cub::LoadDirectBlocked(
tid, StridedIterator(in + in_idx, axis_stride), vals, axis_size, init);
auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
best = op.reduce_many(best, vals, tid * N_READS);
}

Expand Down
14 changes: 0 additions & 14 deletions mlx/backend/cuda/device/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,6 @@ inline __device__ void store_vector(
}
}

// Thin indexing adaptor over an indexable object (e.g. a raw pointer):
// logical element `i` of the adaptor is element `i * stride` of the wrapped
// `it`. Usable from both host and device code.
template <typename T>
struct StridedIterator {
  T it; // Underlying indexable object that is actually dereferenced.
  int64_t stride; // Distance, in elements of `it`, between logical elements.

  // Wraps `base`; consecutive logical indices are `step` elements apart.
  __host__ __device__ StridedIterator(T base, int64_t step)
      : it(base), stride(step) {}

  // Returns the element at logical index `i`. The multiplication is carried
  // out in 64 bits because `stride` is an int64_t. The return type is
  // deduced with plain `auto`, so the element is handed back by value.
  __host__ __device__ auto operator[](int i) const {
    const int64_t offset = i * stride;
    return it[offset];
  }
};

///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
Expand Down