[CUDA] Fix reductions #2314
Conversation
```cpp
out.set_data(allocator::malloc(out.nbytes()));
}

encoder.set_input_array(in);
```
I don't think you need to add the input here as the following kernel does not depend on it.
In fact maybe it makes sense to remove the input from the function signature entirely?
Unfortunately `ReduceInit` takes `T` not `U` for its type, so we need `x`, but yeah, it isn't an input to the kernel... silly copy-paste.

Btw I haven't decided if `ReduceInit` should simply take `U`; I can't think of a place where it makes sense to look at `T`.
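For context, here is a minimal standalone sketch of the shape under discussion (the names `SumOp`, `MaxOp`, and `reduce_init` are illustrative, not the actual mlx definitions; `T` stands in for the input element type and `U` for the accumulator type):

```cpp
#include <cuda/std/limits>
#include <cuda/std/type_traits>

// Illustrative sketch only. Keyed on the input type T, so picking the
// initial value requires knowing the input dtype -- which is why the input
// array x still has to be threaded through even though the kernel never
// reads its data.
struct SumOp {};
struct MaxOp {};

template <typename Op, typename T>
__device__ constexpr T reduce_init() {
  if constexpr (cuda::std::is_same_v<Op, MaxOp>) {
    return cuda::std::numeric_limits<T>::lowest();
  } else {
    return T{0};
  }
}
```

Keying the functor on `U` instead would let the call site drop `x` entirely, which is the alternative being weighed above.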
```diff
 };

 struct Max {
   template <typename T>
-  __device__ T operator()(T a, T b) {
+  __device__ __forceinline__ T operator()(T a, T b) {
```
What motivated the `__forceinline__`? Just curious if we should be adding it in our other functors.
Probably not. I didn't see any perf difference. CUB uses it, so I used it to see if it makes a difference 🤷‍♂️.
```cpp
int blocks, threads;
size_t block_step;
array x = in;

// Large array so allocate an intermediate and accumulate there
std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
```
Suggested change:

```diff
-int blocks, threads;
-size_t block_step;
 array x = in;
 // Large array so allocate an intermediate and accumulate there
-std::tie(blocks, threads, block_step) = get_args(x.size(), N_READS);
+auto [blocks, threads, block_step] = get_args(x.size(), N_READS);
```
The thing is that capturing structured bindings is a C++26 feature. I also changed the one in `col_reduce` to remove the warnings.
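For context, a small host-side sketch of the trade-off (the `get_args` body here is a placeholder with the signature implied by the snippet above):

```cpp
#include <cstddef>
#include <tuple>

// Placeholder with the signature implied by the snippet above; the real
// logic lives in the PR.
std::tuple<int, int, size_t> get_args(size_t size, int n_reads) {
  return {1, 256, size / n_reads};
}

void launch(size_t size, int n_reads) {
  // A structured binding reads nicely, but capturing the bound names in a
  // lambda (e.g. one handed to a kernel launcher) is the portability issue
  // mentioned above:
  //
  //   auto [blocks, threads, block_step] = get_args(size, n_reads);
  //   auto body = [=] { /* capturing blocks/threads/block_step warns */ };
  //
  // Pre-declared variables assigned through std::tie sidestep it, since
  // ordinary variables can always be captured by value.
  int blocks, threads;
  size_t block_step;
  std::tie(blocks, threads, block_step) = get_args(size, n_reads);
  auto body = [=] { (void)blocks; (void)threads; (void)block_step; };
  body();
}
```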
> The thing is that capturing structured bindings is a C++26 feature

Ahh 😞 I forgot about that.
Looks awesome! Left a few minor comments.

Lots of green! More passing tests! 🚀
I am not too happy with this, but I think it is a start.

I attempted an (almost certainly premature) "optimization" that I think complicated the code a bit and is probably not worth it: when the input is transposed, instead of reading out of order and writing in order, I opted to transpose the output the same way so that we can both read and write in order.
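To illustrate the access-pattern motivation, here is a hypothetical CPU-side analogue (not the PR's kernel): summing each column of a row-major `M x N` matrix by walking in storage order keeps both reads and writes sequential, whereas walking down each column would read out of order.

```cpp
#include <cstddef>
#include <vector>

// Illustrative only. Row-by-row traversal touches the input contiguously
// and accumulates into a small output that is also written in order.
std::vector<float> column_sums(const std::vector<float>& a, size_t M, size_t N) {
  std::vector<float> out(N, 0.0f);
  for (size_t i = 0; i < M; ++i) {
    for (size_t j = 0; j < N; ++j) {
      out[j] += a[i * N + j]; // contiguous reads and contiguous writes
    }
  }
  return out;
}
```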
Here are some speed comparisons with PT on `MxN` matrices of type bfloat16. It is obvious that our column reductions need work, and we need to make our performance more consistent across sizes as well.
There are more weird issues to clear up, like for instance the maximum complex number being `-inf - inf j`, which is weird, but there is also no real ordering for complex numbers, so 🤷‍♂️.
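For reference, a hedged sketch of how an init value like that can arise: applying the scalar max identity (`-inf`) component-wise to a complex type (illustrative, not the PR's code):

```cpp
#include <cuComplex.h>
#include <cuda/std/limits>

// Illustrative only: using the scalar max identity (-inf) for both the real
// and imaginary components yields the (-inf - inf j) value noted above.
__device__ cuFloatComplex complex_max_init() {
  float neg_inf = -cuda::std::numeric_limits<float>::infinity();
  return make_cuFloatComplex(neg_inf, neg_inf);
}
```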