@@ -12,7 +12,6 @@
12 | 12 | // See the License for the specific language governing permissions and |
13 | 13 | // limitations under the License. |
14 | 14 |
15 | | -#include "paddle/fluid/operators/matrix_rank_op.h" |
16 | 15 | #include <memory> |
17 | 16 | #include <string> |
18 | 17 | #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" |
@@ -70,9 +69,9 @@ class MatrixRankeOp : public framework::OperatorWithKernel { |
70 | 69 | std::vector<int> x_batch_dims_array(max_dim); |
71 | 70 | std::vector<int> tol_dims_array(max_dim); |
72 | 71 | std::vector<int> out_dims_array(max_dim); |
73 | | - GetBroadcastDimsArrays(dim_x_batch, dim_tol, x_batch_dims_array.data(), |
74 | | - tol_dims_array.data(), out_dims_array.data(), |
75 | | - max_dim, axis); |
| 72 | + phi::funcs::GetBroadcastDimsArrays( |
| 73 | + dim_x_batch, dim_tol, x_batch_dims_array.data(), |
| 74 | + tol_dims_array.data(), out_dims_array.data(), max_dim, axis); |
76 | 75 | ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array)); |
77 | 76 | } |
78 | 77 | } else { |
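
The hunk above is the shape-inference change: it switches to phi::funcs::GetBroadcastDimsArrays, which broadcasts X's batch dimensions (dim_x_batch) against the shape of TolTensor (dim_tol) to derive Out's dimensions. As a rough standalone sketch of the broadcasting rule being applied, assuming NumPy-style semantics (BroadcastShape is a hypothetical name; the real helper also takes an axis argument and fills the three preallocated arrays rather than returning a shape):

#include <algorithm>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Hypothetical sketch of the broadcasting rule: right-align both shapes,
// pad the shorter one with 1s, and require each aligned pair of extents to
// match or contain a 1; the output takes the larger extent of each pair.
std::vector<int> BroadcastShape(std::vector<int> a, std::vector<int> b) {
  size_t max_dim = std::max(a.size(), b.size());
  a.insert(a.begin(), max_dim - a.size(), 1);
  b.insert(b.begin(), max_dim - b.size(), 1);
  std::vector<int> out(max_dim);
  for (size_t i = 0; i < max_dim; ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1)
      throw std::invalid_argument("shapes are not broadcastable");
    out[i] = std::max(a[i], b[i]);
  }
  return out;
}

int main() {
  // X batch dims {2, 1, 4} broadcast against tol dims {3, 4} -> {2, 3, 4}.
  for (int d : BroadcastShape({2, 1, 4}, {3, 4})) std::printf("%d ", d);
}
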
@@ -115,141 +114,9 @@ class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker { |
115 | 114 | } |
116 | 115 | }; |
117 | 116 |
118 | | -template <typename T> |
119 | | -void BatchEigenvalues(const T* x_data, T* eigenvalues_data, int batches, |
120 | | - int rows, int cols, int k) { |
121 | | - // Eigen::Matrix API need non-const pointer. |
122 | | - T* input = const_cast<T*>(x_data); |
123 | | - int stride = rows * cols; |
124 | | - for (int i = 0; i < batches; i++) { |
125 | | - auto m = Eigen::Map< |
126 | | - Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>( |
127 | | - input + i * stride, rows, rows); |
128 | | - Eigen::SelfAdjointEigenSolver< |
129 | | - Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> |
130 | | - eigen_solver(m); |
131 | | - auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs(); |
132 | | - for (int j = 0; j < k; j++) { |
133 | | - *(eigenvalues_data + i * k + j) = eigenvalues[j]; |
134 | | - } |
135 | | - } |
136 | | -} |
137 | | - |
138 | | -template <typename T> |
139 | | -void BatchSVD(const T* x_data, T* eigenvalues_data, int batches, int rows, |
140 | | - int cols, int k) { |
141 | | - // Eigen::Matrix API need non-const pointer. |
142 | | - T* input = const_cast<T*>(x_data); |
143 | | - int stride = rows * cols; |
144 | | - Eigen::BDCSVD< |
145 | | - Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> |
146 | | - svd; |
147 | | - for (int i = 0; i < batches; i++) { |
148 | | - auto m = Eigen::Map< |
149 | | - Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>( |
150 | | - input + i * stride, rows, cols); |
151 | | - svd.compute(m); |
152 | | - auto res_s = svd.singularValues(); |
153 | | - for (int j = 0; j < k; j++) { |
154 | | - eigenvalues_data[i * k + j] = res_s[j]; |
155 | | - } |
156 | | - } |
157 | | -} |
158 | | - |
159 | | -template <typename T> |
160 | | -class MatrixRankCPUKernel : public framework::OpKernel<T> { |
161 | | - public: |
162 | | - void Compute(const framework::ExecutionContext& context) const override { |
163 | | - const Tensor* x = context.Input<Tensor>("X"); |
164 | | - auto* x_data = x->data<T>(); |
165 | | - auto* out = context.Output<Tensor>("Out"); |
166 | | - out->mutable_data<int64_t>(context.GetPlace()); |
167 | | - bool hermitian = context.Attr<bool>("hermitian"); |
168 | | - |
169 | | - auto dim_x = x->dims(); |
170 | | - auto dim_out = out->dims(); |
171 | | - int rows = dim_x[dim_x.size() - 2]; |
172 | | - int cols = dim_x[dim_x.size() - 1]; |
173 | | - int k = std::min(rows, cols); |
174 | | - auto numel = x->numel(); |
175 | | - int batches = numel / (rows * cols); |
176 | | - |
177 | | - bool use_default_tol = context.Attr<bool>("use_default_tol"); |
178 | | - const Tensor* atol_tensor = nullptr; |
179 | | - Tensor temp_tensor; |
180 | | - T rtol_T = 0; |
181 | | - if (use_default_tol) { |
182 | | - framework::TensorFromVector<T>(std::vector<T>{0}, |
183 | | - context.device_context(), &temp_tensor); |
184 | | - atol_tensor = &temp_tensor; |
185 | | - rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols); |
186 | | - } else if (context.HasInput("TolTensor")) { |
187 | | - atol_tensor = context.Input<Tensor>("TolTensor"); |
188 | | - } else { |
189 | | - framework::TensorFromVector<T>(std::vector<T>{context.Attr<float>("tol")}, |
190 | | - context.device_context(), &temp_tensor); |
191 | | - atol_tensor = &temp_tensor; |
192 | | - } |
193 | | - |
194 | | - Tensor eigenvalue_tensor; |
195 | | - auto* eigenvalue_data = eigenvalue_tensor.mutable_data<T>( |
196 | | - detail::GetEigenvalueDim(dim_x, k), context.GetPlace()); |
197 | | - if (hermitian) { |
198 | | - BatchEigenvalues<T>(x_data, eigenvalue_data, batches, rows, cols, k); |
199 | | - } else { |
200 | | - BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols, k); |
201 | | - } |
202 | | - |
203 | | - auto dito_T = |
204 | | - math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext, T>( |
205 | | - context); |
206 | | - std::vector<int> max_eigenvalue_shape = |
207 | | - phi::vectorize<int>(detail::RemoveLastDim(eigenvalue_tensor.dims())); |
208 | | - Tensor max_eigenvalue_tensor = |
209 | | - dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape); |
210 | | - |
211 | | - Tensor temp_rtol_tensor; |
212 | | - framework::TensorFromVector<T>(std::vector<T>{rtol_T}, &temp_rtol_tensor); |
213 | | - Tensor rtol_tensor = dito_T.Mul(temp_rtol_tensor, max_eigenvalue_tensor); |
214 | | - Tensor tol_tensor; |
215 | | - tol_tensor.mutable_data<T>(dim_out, context.GetPlace()); |
216 | | - ElementwiseComputeEx<GreaterElementFunctor<T>, platform::CPUDeviceContext, |
217 | | - T, T>(context, atol_tensor, &rtol_tensor, -1, |
218 | | - GreaterElementFunctor<T>(), &tol_tensor); |
219 | | - |
220 | | - tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1)); |
221 | | - |
222 | | - Tensor compare_result; |
223 | | - compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k), |
224 | | - context.GetPlace()); |
225 | | - |
226 | | - int axis = -1; |
227 | | - if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) { |
228 | | - ElementwiseComputeEx<phi::funcs::GreaterThanFunctor<T, int64_t>, |
229 | | - platform::CPUDeviceContext, T, int>( |
230 | | - context, &eigenvalue_tensor, &tol_tensor, axis, |
231 | | - phi::funcs::GreaterThanFunctor<T, int64_t>(), &compare_result); |
232 | | - } else { |
233 | | - ElementwiseComputeEx<phi::funcs::LessThanFunctor<T, int64_t>, |
234 | | - platform::CPUDeviceContext, T, int>( |
235 | | - context, &eigenvalue_tensor, &tol_tensor, axis, |
236 | | - phi::funcs::LessThanFunctor<T, int64_t>(), &compare_result); |
237 | | - } |
238 | | - auto dito_int = |
239 | | - math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext, |
240 | | - int64_t>(context); |
241 | | - std::vector<int> result_shape = phi::vectorize<int>(dim_out); |
242 | | - Tensor result = dito_int.ReduceSum(compare_result, result_shape); |
243 | | - out->ShareDataWith(result); |
244 | | - } |
245 | | -}; |
246 | | - |
247 | 117 | } // namespace operators |
248 | 118 | } // namespace paddle |
249 | 119 |
250 | 120 | namespace ops = paddle::operators; |
251 | 121 |
252 | 122 | REGISTER_OPERATOR(matrix_rank, ops::MatrixRankeOp, ops::MatrixRankeOpMaker); |
253 | | - |
254 | | -REGISTER_OP_CPU_KERNEL(matrix_rank, ops::MatrixRankCPUKernel<float>, |
255 | | - ops::MatrixRankCPUKernel<double>); |
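
For reference, the CPU kernel deleted above (presumably migrated into the phi kernel library) computed the rank of each matrix in the batch by counting eigenvalues or singular values strictly greater than tol = max(atol, rtol * largest_value), where the default rtol is eps * max(rows, cols). Below is a minimal single-matrix sketch of that logic written directly against Eigen (MatrixRank and CountAboveTol are illustrative names, not Paddle or phi APIs; the batching, tensor plumbing, and broadcasted tol comparison of the real kernel are omitted):

#include <Eigen/Dense>
#include <algorithm>
#include <cstdio>
#include <limits>

using RowMajorMatd =
    Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// Count values strictly above tol = max(atol, rtol * largest value); this is
// the comparison the deleted kernel built from GreaterElementFunctor,
// GreaterThanFunctor, and ReduceSum.
int CountAboveTol(const Eigen::VectorXd& vals, double atol, double rtol) {
  double tol = std::max(atol, rtol * vals.maxCoeff());
  return static_cast<int>((vals.array() > tol).count());
}

// hermitian == true follows the BatchEigenvalues path; otherwise BatchSVD.
int MatrixRank(const RowMajorMatd& m, bool hermitian, double atol,
               double rtol) {
  if (hermitian) {
    Eigen::SelfAdjointEigenSolver<RowMajorMatd> solver(m);
    return CountAboveTol(solver.eigenvalues().cwiseAbs(), atol, rtol);
  }
  Eigen::BDCSVD<RowMajorMatd> svd(m);  // singular values only, no U or V
  return CountAboveTol(svd.singularValues(), atol, rtol);
}

int main() {
  RowMajorMatd m(3, 3);
  m << 1, 2, 3,
       2, 4, 6,  // 2 * row 0, so the rank drops to 2
       0, 1, 1;
  // Default rtol used by the deleted kernel: eps * max(rows, cols).
  double rtol = std::numeric_limits<double>::epsilon() * 3;
  std::printf("rank = %d\n", MatrixRank(m, /*hermitian=*/false, 0.0, rtol));
}

The hermitian branch mirrors the deleted BatchEigenvalues path (for a self-adjoint matrix the absolute eigenvalues equal the singular values), and the general branch mirrors BatchSVD; both feed the same thresholding that the kernel expressed with GreaterElementFunctor, GreaterThanFunctor, and a final ReduceSum.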