 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/elementwise_divide_kernel.h"
 #include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/funcs/reduce_functor.h"
 #include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
 
 #ifdef PADDLE_WITH_CUDA
 #include <curand_kernel.h>
@@ -99,15 +101,14 @@ struct DirichletSampler<GPUContext, T> {
     gamma_sum.Resize(new_shape);
     dev_ctx.template Alloc<T>(&gamma_sum);
 
-    funcs::ReduceKernelImpl<GPUContext, T, T, funcs::SumFunctor>(
-        dev_ctx,
-        gamma_samples,
-        &gamma_sum,
-        {new_shape.size() - 1},
-        true,
-        false);
-    funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
-        dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor<T>(), out);
+    phi::SumRawKernel<T, GPUContext>(dev_ctx,
+                                     gamma_samples,
+                                     {new_shape.size() - 1},
+                                     true,
+                                     false,
+                                     gamma_sum.dtype(),
+                                     &gamma_sum);
+    phi::DivideKernel<T, GPUContext>(dev_ctx, gamma_samples, gamma_sum, out);
   }
 };
 }  // namespace phi
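For context, the hunk above replaces the internal helpers `funcs::ReduceKernelImpl` and `funcs::ElementwiseCompute` with the public `phi::SumRawKernel` and `phi::DivideKernel`, whose declarations come from the two newly added headers. Below is a minimal annotated sketch of the normalization step those two calls perform; it reuses the names from the diff (`dev_ctx`, `gamma_samples`, `gamma_sum`, `new_shape`, `out`) and assumes the surrounding `DirichletSampler<GPUContext, T>` setup, so it is illustrative rather than a drop-in replacement.

```cpp
// Sketch only: names and call signatures are taken from the diff above;
// the enclosing DirichletSampler<GPUContext, T>::operator() setup is assumed.

// Sum the Gamma draws over the last axis, keeping the reduced dimension so
// the result broadcasts against gamma_samples.
phi::SumRawKernel<T, GPUContext>(dev_ctx,
                                 gamma_samples,
                                 {new_shape.size() - 1},  // reduce last axis
                                 /*keep_dim=*/true,
                                 /*reduce_all=*/false,
                                 gamma_sum.dtype(),
                                 &gamma_sum);

// Dirichlet sample = Gamma draws divided (with broadcasting) by their sum.
phi::DivideKernel<T, GPUContext>(dev_ctx, gamma_samples, gamma_sum, out);
```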
|