We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5b818c0 commit b694b24Copy full SHA for b694b24
csrc/moe/topk_softmax_kernels.cu
@@ -20,6 +20,7 @@
20
#include <ATen/cuda/CUDAContext.h>
21
#include <c10/cuda/CUDAGuard.h>
22
#include "../cuda_compat.h"
23
+#include <cuda/std/functional>
24
25
#ifndef USE_ROCM
26
#include <cub/util_type.cuh>
@@ -62,7 +63,7 @@ __launch_bounds__(TPB) __global__
62
63
64
const int thread_row_offset = blockIdx.x * num_cols;
65
- cub::Sum sum;
66
+ cuda::std::plus<float> sum;
67
float threadData(-FLT_MAX);
68
69
// Don't touch finished rows.
0 commit comments