
Commit b9d4285

【phi】migrate matrix_rank to phi (#40074)
* migrate matrix_rank to phi
* migrate eigh and matrix_rank to phi
* fix matrix_rank
* optimize code
* move matrix_rank to phi
* add max functor
* migrate matrix_rank to phi
* optimize code
1 parent edd97f9 commit b9d4285

File tree: 10 files changed, +828 −462 lines
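For orientation: kernels migrated to phi are written as free template functions that take a device context plus DenseTensor arguments and are registered with PD_REGISTER_KERNEL rather than REGISTER_OP_CPU_KERNEL. The sketch below only illustrates that general pattern; the parameter list shown for MatrixRankKernel is an assumption made for the example, not necessarily the signature this commit adds.

// Illustrative phi-style kernel declaration and registration (sketch only;
// the MatrixRankKernel parameter list here is assumed, see the kernel files
// added by this commit for the real signature).
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void MatrixRankKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      float tol,
                      bool use_default_tol,
                      bool hermitian,
                      DenseTensor* out);

}  // namespace phi

// Registers the CPU kernel for float and double, taking over from the
// REGISTER_OP_CPU_KERNEL call removed from matrix_rank_op.cc below.
PD_REGISTER_KERNEL(
    matrix_rank, CPU, ALL_LAYOUT, phi::MatrixRankKernel, float, double) {}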

paddle/fluid/operators/matrix_rank_op.cc (3 additions, 136 deletions)
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/matrix_rank_op.h"
 #include <memory>
 #include <string>
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
@@ -70,9 +69,9 @@ class MatrixRankeOp : public framework::OperatorWithKernel {
       std::vector<int> x_batch_dims_array(max_dim);
       std::vector<int> tol_dims_array(max_dim);
       std::vector<int> out_dims_array(max_dim);
-      GetBroadcastDimsArrays(dim_x_batch, dim_tol, x_batch_dims_array.data(),
-                             tol_dims_array.data(), out_dims_array.data(),
-                             max_dim, axis);
+      phi::funcs::GetBroadcastDimsArrays(
+          dim_x_batch, dim_tol, x_batch_dims_array.data(),
+          tol_dims_array.data(), out_dims_array.data(), max_dim, axis);
       ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
     }
   } else {
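The only functional change in this hunk is that the broadcast-shape helper is now taken from phi::funcs. As a rough, standalone illustration of what computing a broadcast output shape involves (a sketch with made-up names, not Paddle's actual helper, which also fills the two aligned input-shape arrays), the snippet below applies the usual NumPy-style rule: align shapes from the right and let size-1 dimensions stretch.

// Standalone sketch of NumPy-style shape broadcasting, for orientation only.
#include <algorithm>
#include <stdexcept>
#include <vector>

std::vector<int> BroadcastShapes(std::vector<int> a, std::vector<int> b) {
  // Right-align the shorter shape by padding it with 1s on the left.
  size_t n = std::max(a.size(), b.size());
  a.insert(a.begin(), n - a.size(), 1);
  b.insert(b.begin(), n - b.size(), 1);
  std::vector<int> out(n);
  for (size_t i = 0; i < n; ++i) {
    if (a[i] == b[i] || a[i] == 1 || b[i] == 1) {
      out[i] = std::max(a[i], b[i]);  // a size-1 dimension broadcasts
    } else {
      throw std::invalid_argument("shapes are not broadcast-compatible");
    }
  }
  return out;
}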
@@ -115,141 +114,9 @@ class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-template <typename T>
-void BatchEigenvalues(const T* x_data, T* eigenvalues_data, int batches,
-                      int rows, int cols, int k) {
-  // Eigen::Matrix API need non-const pointer.
-  T* input = const_cast<T*>(x_data);
-  int stride = rows * cols;
-  for (int i = 0; i < batches; i++) {
-    auto m = Eigen::Map<
-        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
-        input + i * stride, rows, rows);
-    Eigen::SelfAdjointEigenSolver<
-        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
-        eigen_solver(m);
-    auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs();
-    for (int j = 0; j < k; j++) {
-      *(eigenvalues_data + i * k + j) = eigenvalues[j];
-    }
-  }
-}
-
-template <typename T>
-void BatchSVD(const T* x_data, T* eigenvalues_data, int batches, int rows,
-              int cols, int k) {
-  // Eigen::Matrix API need non-const pointer.
-  T* input = const_cast<T*>(x_data);
-  int stride = rows * cols;
-  Eigen::BDCSVD<
-      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
-      svd;
-  for (int i = 0; i < batches; i++) {
-    auto m = Eigen::Map<
-        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
-        input + i * stride, rows, cols);
-    svd.compute(m);
-    auto res_s = svd.singularValues();
-    for (int j = 0; j < k; j++) {
-      eigenvalues_data[i * k + j] = res_s[j];
-    }
-  }
-}
-
-template <typename T>
-class MatrixRankCPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* x = context.Input<Tensor>("X");
-    auto* x_data = x->data<T>();
-    auto* out = context.Output<Tensor>("Out");
-    out->mutable_data<int64_t>(context.GetPlace());
-    bool hermitian = context.Attr<bool>("hermitian");
-
-    auto dim_x = x->dims();
-    auto dim_out = out->dims();
-    int rows = dim_x[dim_x.size() - 2];
-    int cols = dim_x[dim_x.size() - 1];
-    int k = std::min(rows, cols);
-    auto numel = x->numel();
-    int batches = numel / (rows * cols);
-
-    bool use_default_tol = context.Attr<bool>("use_default_tol");
-    const Tensor* atol_tensor = nullptr;
-    Tensor temp_tensor;
-    T rtol_T = 0;
-    if (use_default_tol) {
-      framework::TensorFromVector<T>(std::vector<T>{0},
-                                     context.device_context(), &temp_tensor);
-      atol_tensor = &temp_tensor;
-      rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
-    } else if (context.HasInput("TolTensor")) {
-      atol_tensor = context.Input<Tensor>("TolTensor");
-    } else {
-      framework::TensorFromVector<T>(std::vector<T>{context.Attr<float>("tol")},
-                                     context.device_context(), &temp_tensor);
-      atol_tensor = &temp_tensor;
-    }
-
-    Tensor eigenvalue_tensor;
-    auto* eigenvalue_data = eigenvalue_tensor.mutable_data<T>(
-        detail::GetEigenvalueDim(dim_x, k), context.GetPlace());
-    if (hermitian) {
-      BatchEigenvalues<T>(x_data, eigenvalue_data, batches, rows, cols, k);
-    } else {
-      BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols, k);
-    }
-
-    auto dito_T =
-        math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext, T>(
-            context);
-    std::vector<int> max_eigenvalue_shape =
-        phi::vectorize<int>(detail::RemoveLastDim(eigenvalue_tensor.dims()));
-    Tensor max_eigenvalue_tensor =
-        dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape);
-
-    Tensor temp_rtol_tensor;
-    framework::TensorFromVector<T>(std::vector<T>{rtol_T}, &temp_rtol_tensor);
-    Tensor rtol_tensor = dito_T.Mul(temp_rtol_tensor, max_eigenvalue_tensor);
-    Tensor tol_tensor;
-    tol_tensor.mutable_data<T>(dim_out, context.GetPlace());
-    ElementwiseComputeEx<GreaterElementFunctor<T>, platform::CPUDeviceContext,
-                         T, T>(context, atol_tensor, &rtol_tensor, -1,
-                               GreaterElementFunctor<T>(), &tol_tensor);
-
-    tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));
-
-    Tensor compare_result;
-    compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
-                                         context.GetPlace());
-
-    int axis = -1;
-    if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
-      ElementwiseComputeEx<phi::funcs::GreaterThanFunctor<T, int64_t>,
-                           platform::CPUDeviceContext, T, int>(
-          context, &eigenvalue_tensor, &tol_tensor, axis,
-          phi::funcs::GreaterThanFunctor<T, int64_t>(), &compare_result);
-    } else {
-      ElementwiseComputeEx<phi::funcs::LessThanFunctor<T, int64_t>,
-                           platform::CPUDeviceContext, T, int>(
-          context, &eigenvalue_tensor, &tol_tensor, axis,
-          phi::funcs::LessThanFunctor<T, int64_t>(), &compare_result);
-    }
-    auto dito_int =
-        math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext,
-                                                 int64_t>(context);
-    std::vector<int> result_shape = phi::vectorize<int>(dim_out);
-    Tensor result = dito_int.ReduceSum(compare_result, result_shape);
-    out->ShareDataWith(result);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 
 REGISTER_OPERATOR(matrix_rank, ops::MatrixRankeOp, ops::MatrixRankeOpMaker);
-
-REGISTER_OP_CPU_KERNEL(matrix_rank, ops::MatrixRankCPUKernel<float>,
-                       ops::MatrixRankCPUKernel<double>);
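Condensed from the deleted kernel above, the single-matrix sketch below shows the rank computation it performed: take singular values (or absolute eigenvalues in the Hermitian case), form the tolerance max(atol, rtol * largest value) with the default rtol = eps * max(rows, cols) used when no tolerance is supplied, and count how many values exceed it. The function name and the non-batched setup are illustrative only.

// Minimal single-matrix sketch of the rank logic the deleted CPU kernel
// implemented with Eigen; not the phi kernel this commit adds.
#include <Eigen/Dense>
#include <algorithm>
#include <limits>

template <typename T>
int MatrixRank(const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>& m,
               bool hermitian, T atol = T(0)) {
  Eigen::Matrix<T, Eigen::Dynamic, 1> vals;
  if (hermitian) {
    // Hermitian path: absolute eigenvalues stand in for singular values.
    Eigen::SelfAdjointEigenSolver<
        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> solver(m);
    vals = solver.eigenvalues().cwiseAbs();
  } else {
    // General path: singular values from a BDCSVD decomposition.
    Eigen::BDCSVD<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> svd(m);
    vals = svd.singularValues();
  }
  // Default rtol mirrors the deleted kernel's use_default_tol path:
  // eps * max(rows, cols).
  T rtol = std::numeric_limits<T>::epsilon() *
           static_cast<T>(std::max(m.rows(), m.cols()));
  T tol = std::max(atol, rtol * vals.maxCoeff());
  // Rank = number of values strictly above the tolerance.
  return static_cast<int>((vals.array() > tol).count());
}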
