Skip to content

Commit b622684

Browse files
committed
optimize code
1 parent 7f6052d commit b622684

File tree

7 files changed

+40
-45
lines changed

7 files changed

+40
-45
lines changed

paddle/phi/kernels/cpu/matrix_rank_kernel.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/fluid/framework/tensor_util.h"
16-
#include "paddle/phi/core/kernel_registry.h"
15+
#include "paddle/phi/kernels/matrix_rank_kernel.h"
1716
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
1817

18+
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/kernels/full_kernel.h"
20+
1921
namespace phi {
2022

2123
template <typename T, typename Context>
@@ -27,11 +29,9 @@ void MatrixRankKernel(const Context& dev_ctx,
2729
DenseTensor* out) {
2830
DenseTensor atol_tensor;
2931
if (use_default_tol) {
30-
paddle::framework::TensorFromVector(
31-
std::vector<T>{0}, dev_ctx, &atol_tensor);
32+
atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(0));
3233
} else {
33-
paddle::framework::TensorFromVector(
34-
std::vector<T>{tol}, dev_ctx, &atol_tensor);
34+
atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(tol));
3535
}
3636
MatrixRankTolKernel<T, Context>(
3737
dev_ctx, x, atol_tensor, use_default_tol, hermitian, out);

paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
16+
1517
#include <Eigen/Dense>
1618
#include <Eigen/SVD>
17-
#include <memory>
18-
#include <string>
19-
#include <vector>
2019
#include "paddle/phi/core/ddim.h"
2120
#include "paddle/phi/core/kernel_registry.h"
2221
#include "paddle/phi/kernels/cpu/reduce.h"
@@ -124,8 +123,9 @@ void MatrixRankTolKernel(const Context& dev_ctx,
124123
&max_eigenvalue_tensor);
125124

126125
DenseTensor temp_rtol_tensor;
127-
paddle::framework::TensorFromVector<T>(std::vector<T>{rtol_T},
128-
&temp_rtol_tensor);
126+
temp_rtol_tensor =
127+
phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(rtol_T));
128+
129129
DenseTensor rtol_tensor =
130130
phi::Multiply<T>(dev_ctx, temp_rtol_tensor, max_eigenvalue_tensor);
131131

@@ -163,12 +163,13 @@ void MatrixRankTolKernel(const Context& dev_ctx,
163163
funcs::LessThanFunctor<T, int64_t>(),
164164
&compare_result);
165165
}
166-
DenseTensor result = phi::Sum<int64_t>(dev_ctx,
167-
compare_result,
168-
std::vector<int64_t>{-1},
169-
compare_result.dtype(),
170-
false);
171-
out->ShareDataWith(result);
166+
167+
phi::SumKernel<int64_t>(dev_ctx,
168+
compare_result,
169+
std::vector<int64_t>{-1},
170+
compare_result.dtype(),
171+
false,
172+
out);
172173
}
173174
} // namespace phi
174175

paddle/phi/kernels/gpu/matrix_rank_kernel.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
#ifndef PADDLE_WITH_HIP
1616
// HIP not support cusolver
1717

18-
#include "paddle/fluid/framework/tensor_util.h"
19-
#include "paddle/phi/core/kernel_registry.h"
18+
#include "paddle/phi/kernels/matrix_rank_kernel.h"
2019
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
2120

21+
#include "paddle/phi/core/kernel_registry.h"
22+
#include "paddle/phi/kernels/full_kernel.h"
23+
2224
namespace phi {
2325

2426
template <typename T, typename Context>
@@ -30,11 +32,9 @@ void MatrixRankKernel(const Context& dev_ctx,
3032
DenseTensor* out) {
3133
DenseTensor atol_tensor;
3234
if (use_default_tol) {
33-
paddle::framework::TensorFromVector(
34-
std::vector<T>{0}, dev_ctx, &atol_tensor);
35+
atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(0));
3536
} else {
36-
paddle::framework::TensorFromVector(
37-
std::vector<T>{tol}, dev_ctx, &atol_tensor);
37+
atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(tol));
3838
}
3939
MatrixRankTolKernel<T, Context>(
4040
dev_ctx, x, atol_tensor, use_default_tol, hermitian, out);

paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
#ifndef PADDLE_WITH_HIP
1616
// HIP not support cusolver
1717

18+
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
19+
1820
#include <algorithm>
1921
#include <vector>
2022
#include "paddle/fluid/memory/memory.h"
2123
#include "paddle/phi/backends/dynload/cusolver.h"
2224
#include "paddle/phi/backends/gpu/gpu_context.h"
2325
#include "paddle/phi/core/kernel_registry.h"
26+
#include "paddle/phi/kernels/abs_kernel.h"
2427
#include "paddle/phi/kernels/full_kernel.h"
2528
#include "paddle/phi/kernels/funcs/broadcast_function.h"
2629
#include "paddle/phi/kernels/funcs/compare_functors.h"
@@ -350,10 +353,9 @@ void MatrixRankTolKernel(const Context& dev_ctx,
350353
if (hermitian) {
351354
SyevjBatched<T>(
352355
dev_ctx, batches, rows, x_tmp.data<T>(), eigenvalue_data, info_ptr);
353-
phi::funcs::ForRange<Context> for_range(dev_ctx, eigenvalue_tensor.numel());
354-
phi::funcs::AbsFunctor<T> functor(
355-
eigenvalue_data, eigenvalue_data, eigenvalue_tensor.numel());
356-
for_range(functor);
356+
357+
phi::AbsKernel<T, Context>(dev_ctx, eigenvalue_tensor, &eigenvalue_tensor);
358+
357359
} else {
358360
DenseTensor U, VH;
359361
U.Resize(detail::GetUDDim(dim_x, k));
@@ -384,8 +386,8 @@ void MatrixRankTolKernel(const Context& dev_ctx,
384386
&max_eigenvalue_tensor);
385387

386388
DenseTensor temp_rtol_tensor;
387-
paddle::framework::TensorFromVector<T>(
388-
std::vector<T>{rtol_T}, dev_ctx, &temp_rtol_tensor);
389+
temp_rtol_tensor =
390+
phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(rtol_T));
389391

390392
DenseTensor rtol_tensor =
391393
phi::Multiply<T>(dev_ctx, temp_rtol_tensor, max_eigenvalue_tensor);
@@ -416,13 +418,12 @@ void MatrixRankTolKernel(const Context& dev_ctx,
416418
funcs::GreaterThanFunctor<T, int64_t>(),
417419
&compare_result);
418420

419-
DenseTensor result = phi::Sum<int64_t>(dev_ctx,
420-
compare_result,
421-
std::vector<int64_t>{-1},
422-
compare_result.type(),
423-
false);
424-
425-
out->ShareDataWith(result);
421+
phi::SumKernel<int64_t>(dev_ctx,
422+
compare_result,
423+
std::vector<int64_t>{-1},
424+
compare_result.dtype(),
425+
false,
426+
out);
426427
}
427428

428429
} // namespace phi

paddle/phi/kernels/impl/matrix_rank_kernel_impl.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
namespace phi {
2121

22-
using DDim = phi::DDim;
2322
namespace detail {
2423
static DDim GetEigenvalueDim(const DDim& dim, int k) {
2524
auto vec = phi::vectorize(dim);

paddle/phi/kernels/matrix_rank_kernel.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,8 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include "paddle/phi/common/scalar.h"
1817
#include "paddle/phi/core/dense_tensor.h"
19-
#include "paddle/phi/core/selected_rows.h"
20-
#include "paddle/phi/infermeta/unary.h"
21-
#include "paddle/phi/kernels/empty_kernel.h"
18+
2219
namespace phi {
2320

2421
template <typename T, typename Context>

paddle/phi/kernels/matrix_rank_tol_kernel.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,8 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include "paddle/phi/common/scalar.h"
1817
#include "paddle/phi/core/dense_tensor.h"
19-
#include "paddle/phi/core/selected_rows.h"
20-
#include "paddle/phi/infermeta/unary.h"
21-
#include "paddle/phi/kernels/empty_kernel.h"
18+
2219
namespace phi {
2320

2421
template <typename T, typename Context>

0 commit comments

Comments
 (0)