Skip to content

Commit a46df40

Browse files
compilation optimization for dirichlet_kernel (#57815)
1 parent e497279 commit a46df40

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

paddle/phi/kernels/gpu/dirichlet_kernel.cu

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
#include "paddle/phi/backends/gpu/gpu_context.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/elementwise_divide_kernel.h"
1920
#include "paddle/phi/kernels/funcs/broadcast_function.h"
2021
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
2122
#include "paddle/phi/kernels/funcs/for_range.h"
2223
#include "paddle/phi/kernels/funcs/reduce_function.h"
2324
#include "paddle/phi/kernels/funcs/reduce_functor.h"
2425
#include "paddle/phi/kernels/impl/dirichlet_kernel_impl.h"
26+
#include "paddle/phi/kernels/reduce_sum_kernel.h"
2527

2628
#ifdef PADDLE_WITH_CUDA
2729
#include <curand_kernel.h>
@@ -99,15 +101,14 @@ struct DirichletSampler<GPUContext, T> {
99101
gamma_sum.Resize(new_shape);
100102
dev_ctx.template Alloc<T>(&gamma_sum);
101103

102-
funcs::ReduceKernelImpl<GPUContext, T, T, funcs::SumFunctor>(
103-
dev_ctx,
104-
gamma_samples,
105-
&gamma_sum,
106-
{new_shape.size() - 1},
107-
true,
108-
false);
109-
funcs::ElementwiseCompute<funcs::DivideFunctor<T>, T>(
110-
dev_ctx, gamma_samples, gamma_sum, funcs::DivideFunctor<T>(), out);
104+
phi::SumRawKernel<T, GPUContext>(dev_ctx,
105+
gamma_samples,
106+
{new_shape.size() - 1},
107+
true,
108+
false,
109+
gamma_sum.dtype(),
110+
&gamma_sum);
111+
phi::DivideKernel<T, GPUContext>(dev_ctx, gamma_samples, gamma_sum, out);
111112
}
112113
};
113114
} // namespace phi

0 commit comments

Comments
 (0)