|
15 | 15 | #ifndef PADDLE_WITH_HIP |
16 | 16 | // HIP does not support cusolver |
17 | 17 |
|
| 18 | +#include "paddle/phi/kernels/matrix_rank_tol_kernel.h" |
| 19 | + |
18 | 20 | #include <algorithm> |
19 | 21 | #include <vector> |
20 | 22 | #include "paddle/fluid/memory/memory.h" |
21 | 23 | #include "paddle/phi/backends/dynload/cusolver.h" |
22 | 24 | #include "paddle/phi/backends/gpu/gpu_context.h" |
23 | 25 | #include "paddle/phi/core/kernel_registry.h" |
| 26 | +#include "paddle/phi/kernels/abs_kernel.h" |
24 | 27 | #include "paddle/phi/kernels/full_kernel.h" |
25 | 28 | #include "paddle/phi/kernels/funcs/broadcast_function.h" |
26 | 29 | #include "paddle/phi/kernels/funcs/compare_functors.h" |
@@ -350,10 +353,9 @@ void MatrixRankTolKernel(const Context& dev_ctx, |
350 | 353 | if (hermitian) { |
351 | 354 | SyevjBatched<T>( |
352 | 355 | dev_ctx, batches, rows, x_tmp.data<T>(), eigenvalue_data, info_ptr); |
353 | | - phi::funcs::ForRange<Context> for_range(dev_ctx, eigenvalue_tensor.numel()); |
354 | | - phi::funcs::AbsFunctor<T> functor( |
355 | | - eigenvalue_data, eigenvalue_data, eigenvalue_tensor.numel()); |
356 | | - for_range(functor); |
| 356 | + |
| 357 | + phi::AbsKernel<T, Context>(dev_ctx, eigenvalue_tensor, &eigenvalue_tensor); |
| 358 | + |
357 | 359 | } else { |
358 | 360 | DenseTensor U, VH; |
359 | 361 | U.Resize(detail::GetUDDim(dim_x, k)); |
@@ -384,8 +386,8 @@ void MatrixRankTolKernel(const Context& dev_ctx, |
384 | 386 | &max_eigenvalue_tensor); |
385 | 387 |
|
386 | 388 | DenseTensor temp_rtol_tensor; |
387 | | - paddle::framework::TensorFromVector<T>( |
388 | | - std::vector<T>{rtol_T}, dev_ctx, &temp_rtol_tensor); |
| 389 | + temp_rtol_tensor = |
| 390 | + phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(rtol_T)); |
389 | 391 |
|
390 | 392 | DenseTensor rtol_tensor = |
391 | 393 | phi::Multiply<T>(dev_ctx, temp_rtol_tensor, max_eigenvalue_tensor); |
@@ -416,13 +418,12 @@ void MatrixRankTolKernel(const Context& dev_ctx, |
416 | 418 | funcs::GreaterThanFunctor<T, int64_t>(), |
417 | 419 | &compare_result); |
418 | 420 |
|
419 | | - DenseTensor result = phi::Sum<int64_t>(dev_ctx, |
420 | | - compare_result, |
421 | | - std::vector<int64_t>{-1}, |
422 | | - compare_result.type(), |
423 | | - false); |
424 | | - |
425 | | - out->ShareDataWith(result); |
| 421 | + phi::SumKernel<int64_t>(dev_ctx, |
| 422 | + compare_result, |
| 423 | + std::vector<int64_t>{-1}, |
| 424 | + compare_result.dtype(), |
| 425 | + false, |
| 426 | + out); |
426 | 427 | } |
427 | 428 |
|
428 | 429 | } // namespace phi |
|
0 commit comments