-
Notifications
You must be signed in to change notification settings - Fork 1.3k
[CUDA] Fix reductions #2314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
[CUDA] Fix reductions #2314
Changes from 15 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
ab7c310
Adapt the torch benchmark to run in CUDA
angeloskath 9cf7ef1
Add all reduce and atomic updates
angeloskath b70a964
Optimize all reduce a bit
angeloskath 4d2b682
Simple row reduce
angeloskath cd523ff
Working row reduce looped
angeloskath 880751a
Remove segmented reduce and fix row reduce
angeloskath abdb21f
Add helpers and atomic kernel
angeloskath 664d8e4
Add comments and clean up
angeloskath cc4b995
Working col reduce
angeloskath 818e8e6
Add an init reduce
angeloskath fd1d082
Make sure softmax doesn't change the actual max
angeloskath 8bd4bf2
Fixes for transpositions and expands
angeloskath a57a75b
More fixes for all reductions
angeloskath a7faa04
Add a special case when not keeping the dims
angeloskath d999675
Make check more general
angeloskath bc60a31
Comments
angeloskath File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
// Copyright © 2025 Apple Inc. | ||
|
||
#include "mlx/backend/cuda/device.h" | ||
#include "mlx/backend/cuda/reduce/reduce.cuh" | ||
|
||
#include <cooperative_groups.h> | ||
#include <cooperative_groups/reduce.h> | ||
#include <cub/block/block_load.cuh> | ||
|
||
namespace mlx::core { | ||
|
||
namespace cu { | ||
|
||
namespace cg = cooperative_groups; | ||
|
||
template <typename T, typename U, typename ReduceOp, int N = 4> | ||
__global__ void all_reduce(T* in, U* out, size_t block_step, size_t size) { | ||
// TODO: Process multiple "rows" in each thread | ||
constexpr int M = 1; | ||
|
||
auto grid = cg::this_grid(); | ||
auto block = cg::this_thread_block(); | ||
auto warp = cg::tiled_partition<WARP_SIZE>(block); | ||
|
||
const U init = cu::ReduceInit<ReduceOp, T>::value(); | ||
ReduceOp op; | ||
|
||
T vals[N]; | ||
U accs[M]; | ||
accs[0] = init; | ||
|
||
size_t start = grid.block_rank() * block_step; | ||
size_t end = start + block_step; | ||
size_t check = min(end, size); | ||
|
||
size_t i = start; | ||
for (; i + block.size() * N <= check; i += block.size() * N) { | ||
cub::LoadDirectBlockedVectorized<T, N>(block.thread_rank(), in + i, vals); | ||
for (int j = 0; j < N; j++) { | ||
accs[0] = op(accs[0], __cast<U, T>(vals[j])); | ||
} | ||
} | ||
|
||
if (i < check) { | ||
cub::LoadDirectBlocked( | ||
block.thread_rank(), in + i, vals, check - i, __cast<T, U>(init)); | ||
for (int i = 0; i < N; i++) { | ||
accs[0] = op(accs[0], __cast<U, T>(vals[i])); | ||
} | ||
} | ||
|
||
__shared__ U shared_accumulators[32]; | ||
block_reduce(block, warp, accs, shared_accumulators, op, init); | ||
|
||
if (block.thread_rank() == 0) { | ||
out[grid.block_rank()] = accs[0]; | ||
} | ||
} | ||
|
||
} // namespace cu | ||
|
||
void all_reduce(
    cu::CommandEncoder& encoder,
    const array& in,
    array& out,
    Reduce::ReduceType reduce_type) {
  // Reduce all elements of `in` into the single-element array `out`.
  // Strategy: one kernel launch produces per-block partial results; if more
  // than one block was needed, the partials are written to a temporary and
  // reduced again with a single-block launch.
  constexpr int N_READS = 8;

  out.set_data(allocator::malloc(out.nbytes()));

  // Choose (blocks, threads, block_step) for reducing `size` elements with
  // N reads per thread. `block_step` is the number of input elements each
  // block is responsible for.
  auto get_args = [](size_t size, int N) {
    // Use a size_t literal so both std::min arguments have the same type;
    // 512UL is not size_t on LLP64 platforms (e.g. 64-bit Windows).
    int threads = std::min(static_cast<size_t>(512), (size + N - 1) / N);
    // Round threads up to a whole number of warps.
    threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
    int reductions_per_step = threads * N;
    size_t steps_needed =
        (size + reductions_per_step - 1) / reductions_per_step;

    // Heuristic block counts: small inputs use one block (no second pass);
    // larger inputs cap at 1024 blocks.
    int blocks;
    if (steps_needed < 32) {
      blocks = 1;
    } else if (steps_needed < 128) {
      blocks = 32;
    } else if (steps_needed < 512) {
      blocks = 128;
    } else if (steps_needed < 1024) {
      blocks = 512;
    } else {
      blocks = 1024;
    }

    size_t steps_per_block = (steps_needed + blocks - 1) / blocks;
    size_t block_step = steps_per_block * reductions_per_step;

    return std::make_tuple(blocks, threads, block_step);
  };

  // std::tie instead of structured bindings: the launch lambdas below
  // capture these by reference, and capturing structured bindings in a
  // lambda is not available before C++26.
  int blocks, threads;
  size_t block_step;
  array x = in;

  // Large array so allocate an intermediate and accumulate there
  std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
  if (blocks > 1) {
    array intermediate({blocks}, out.dtype(), nullptr, {});
    intermediate.set_data(allocator::malloc(intermediate.nbytes()));
    encoder.add_temporary(intermediate);
    encoder.set_input_array(x);
    encoder.set_output_array(intermediate);
    encoder.launch_kernel([&](cudaStream_t stream) {
      MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
        MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
          using T = cuda_type_t<CTYPE>;
          using U = cu::ReduceResult<OP, T>::type;
          auto kernel = cu::all_reduce<T, U, OP, N_READS>;
          kernel<<<blocks, threads, 0, stream>>>(
              x.data<T>(), intermediate.data<U>(), block_step, x.size());
        });
      });
    });

    // Set the input for the next step and recalculate the blocks
    x = intermediate;
    std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
  }

  // Final (or only) pass: reduces x directly into out.
  encoder.set_input_array(x);
  encoder.set_output_array(out);
  encoder.launch_kernel([&](cudaStream_t stream) {
    MLX_SWITCH_ALL_TYPES(x.dtype(), CTYPE, {
      MLX_SWITCH_REDUCE_OPS(reduce_type, OP, {
        using T = cuda_type_t<CTYPE>;
        using U = cu::ReduceResult<OP, T>::type;
        auto kernel = cu::all_reduce<T, U, OP, N_READS>;
        kernel<<<blocks, threads, 0, stream>>>(
            x.data<T>(), out.data<U>(), block_step, x.size());
      });
    });
  });
}
|
||
} // namespace mlx::core |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The thing is that capturing structured bindings is a C++26 feature. I also changed the one in
col_reduce
to remove the warnings. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ahh 😞 I forgot about that.