Skip to content

Commit 930a513

Browse files
authored
[Phi] Migrate triangular_solve dependence to phi (#40417)
1 parent 89a70c7 commit 930a513

File tree

9 files changed: 31 additions (+31) and 188 deletions (-188)

9 files changed: 31 additions (+31) and 188 deletions (-188)

paddle/fluid/operators/lstsq_op.cu

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
#include <string>
1919
#include <vector>
20+
#include "paddle/fluid/framework/phi_utils.h"
2021
#include "paddle/fluid/operators/lstsq_op.h"
2122
#include "paddle/fluid/operators/qr_op.h"
2223
#include "paddle/fluid/platform/dynload/cusolver.h"
24+
#include "paddle/phi/kernels/triangular_solve_kernel.h"
2325

2426
namespace paddle {
2527
namespace operators {
@@ -70,6 +72,10 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
7072
Tensor tau = dito.Fill(tau_dims_vec, 0);
7173
auto tau_data = tau.mutable_data<T>(context.GetPlace());
7274

75+
using Context =
76+
typename framework::ConvertToPhiContext<DeviceContext>::TYPE;
77+
auto& phi_dev_ctx = static_cast<const Context&>(dev_ctx);
78+
7379
if (m >= n) {
7480
Tensor tmp_x = dito.Transpose(new_x);
7581
Tensor tmp_y = dito.Transpose(new_y);
@@ -93,8 +99,9 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
9399
Tensor slice_y = dito.Slice(trans_y, {-2}, {0}, {min_mn});
94100

95101
// Step 3, solve R X = Y
96-
triangular_solve<DeviceContext, T>(dev_ctx, res_r, slice_y, solution,
97-
true, false, false);
102+
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, res_r, slice_y, true,
103+
false, false, solution);
104+
98105
} else {
99106
auto x_data = new_x.mutable_data<T>(context.GetPlace());
100107
auto y_data = new_y.mutable_data<T>(context.GetPlace());
@@ -105,8 +112,8 @@ class LstsqCUDAKernel : public framework::OpKernel<T> {
105112

106113
// Step 2, solve R^H Z = Y
107114
Tensor trans_r = dito.Transpose(new_x);
108-
triangular_solve<DeviceContext, T>(dev_ctx, trans_r, new_y, solution,
109-
true, true, false);
115+
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, trans_r, new_y, true,
116+
true, false, solution);
110117

111118
// Step 3, X <- Q Z
112119
BatchedOrgqr<DeviceContext, T>(dev_ctx, batch_count, n, n, min_mn, x_data,

paddle/fluid/operators/lstsq_op.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
#include "paddle/fluid/operators/math/matrix_solve.h"
2323
#include "paddle/fluid/operators/svd_helper.h"
2424
#include "paddle/fluid/operators/transpose_op.h"
25-
#include "paddle/fluid/operators/triangular_solve_op.h"
2625
#include "paddle/fluid/platform/for_range.h"
2726
#include "paddle/phi/kernels/funcs/complex_functors.h"
2827
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"

paddle/fluid/operators/lu_op.h

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@ limitations under the License. */
1515
#pragma once
1616

1717
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/framework/phi_utils.h"
1819
#include "paddle/fluid/operators/set_value_op.h"
1920
#include "paddle/fluid/operators/svd_helper.h"
20-
#include "paddle/fluid/operators/triangular_solve_op.h"
2121
#include "paddle/fluid/operators/tril_triu_op.h"
2222
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
2323
#include "paddle/phi/kernels/math_kernel.h"
24+
#include "paddle/phi/kernels/triangular_solve_kernel.h"
2425

2526
namespace paddle {
2627
namespace operators {
@@ -555,6 +556,11 @@ class LUGradKernel : public framework::OpKernel<T> {
555556

556557
framework::Tensor Pmat;
557558
Unpack_Pivot<DeviceContext, T>(dev_ctx, *P, &Pmat, m, k);
559+
560+
using Context =
561+
typename framework::ConvertToPhiContext<DeviceContext>::TYPE;
562+
auto& phi_dev_ctx = static_cast<const Context&>(dev_ctx);
563+
558564
if (m <= n) {
559565
if (k < n) {
560566
framework::Tensor U_complement, U_grad_complement, phi_complement,
@@ -605,8 +611,9 @@ class LUGradKernel : public framework::OpKernel<T> {
605611
framework::Tensor psi_principal, phi_mH, psi_tmp;
606612
Tensor_Conj<DeviceContext, T>(dev_ctx, phi, &phi_mH);
607613
phi_mH = helper.Transpose(phi_mH);
608-
triangular_solve<DeviceContext, T>(dev_ctx, U_narrow, phi_mH,
609-
&psi_principal, true, false, false);
614+
615+
phi::TriangularSolveKernel<T, Context>(
616+
phi_dev_ctx, U_narrow, phi_mH, true, false, false, &psi_principal);
610617

611618
Tensor_Conj<DeviceContext, T>(dev_ctx, psi_principal, &psi_principal);
612619
psi_principal = helper.Transpose(psi_principal);
@@ -620,8 +627,9 @@ class LUGradKernel : public framework::OpKernel<T> {
620627
SetValueCompute_dispatch<DeviceContext, T>(ctx, &psi, &psi_principal,
621628
&psi, axes, &slice_starts,
622629
&slice_ends, valuedims, xrank);
623-
triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, psi, &psi_tmp,
624-
true, false, true);
630+
631+
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, L_narrow_mH, psi,
632+
true, false, true, &psi_tmp);
625633

626634
auto mat_dim_p =
627635
phi::funcs::CreateMatrixDescriptor(Pmat.dims(), 0, false);
@@ -672,8 +680,10 @@ class LUGradKernel : public framework::OpKernel<T> {
672680
&psi, axes, &slice_starts,
673681
&slice_ends, valuedims, xrank);
674682
framework::Tensor psi_principal, phi_mH, psi_tmp, U_narrow_mH;
675-
triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, phi,
676-
&psi_principal, true, false, true);
683+
684+
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, L_narrow_mH, phi,
685+
true, false, true, &psi_principal);
686+
677687
slice_starts[0] = 0;
678688
slice_starts[1] = 0;
679689
slice_ends[0] = k;
@@ -695,8 +705,8 @@ class LUGradKernel : public framework::OpKernel<T> {
695705
psi_tmp = helper.Transpose(psi_tmp);
696706

697707
Tensor_Conj<DeviceContext, T>(dev_ctx, U_narrow, &U_narrow_mH);
698-
triangular_solve<DeviceContext, T>(dev_ctx, U_narrow_mH, psi_tmp, &psi,
699-
true, false, false);
708+
phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, U_narrow_mH, psi_tmp,
709+
true, false, false, &psi);
700710
*dx = helper.Transpose(psi);
701711
}
702712
}

paddle/fluid/operators/math/matrix_solve.cc

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -34,45 +34,6 @@ class MatrixSolveFunctor<platform::CPUDeviceContext, T> {
3434
template class MatrixSolveFunctor<platform::CPUDeviceContext, float>;
3535
template class MatrixSolveFunctor<platform::CPUDeviceContext, double>;
3636

37-
template <typename T>
38-
class TriangularSolveFunctor<platform::CPUDeviceContext, T> {
39-
public:
40-
void operator()(const platform::CPUDeviceContext& context,
41-
const framework::Tensor* a, framework::Tensor* b, bool left,
42-
bool upper, bool transpose, bool unitriangular) {
43-
CBLAS_SIDE side = left ? CblasLeft : CblasRight;
44-
CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower;
45-
CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans;
46-
CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit;
47-
48-
const T* a_data = a->data<T>();
49-
T* b_data = b->mutable_data<T>(context.GetPlace());
50-
51-
int a_dim_size = a->dims().size();
52-
int b_dim_size = b->dims().size();
53-
54-
int M = static_cast<int>(b->dims()[b_dim_size - 2]);
55-
int N = static_cast<int>(b->dims()[b_dim_size - 1]);
56-
auto lda = left ? std::max(1, M) : std::max(1, N);
57-
auto ldb = std::max(1, N);
58-
59-
int batch_size = 1;
60-
auto& a_dim = a->dims();
61-
for (int i = 0; i < a_dim_size - 2; i++) {
62-
batch_size *= a_dim[i];
63-
}
64-
65-
auto blas = phi::funcs::GetBlas<platform::CPUDeviceContext, T>(context);
66-
for (int i = 0; i < batch_size; i++) {
67-
blas.TRSM(side, uplo, transA, diag, M, N, T(1), a_data + i * M * M, lda,
68-
b_data + i * N * M, ldb);
69-
}
70-
}
71-
};
72-
73-
template class TriangularSolveFunctor<platform::CPUDeviceContext, float>;
74-
template class TriangularSolveFunctor<platform::CPUDeviceContext, double>;
75-
7637
} // namespace math
7738
} // namespace operators
7839
} // namespace paddle

paddle/fluid/operators/math/matrix_solve.cu.cc

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -161,67 +161,6 @@ class MatrixSolveFunctor<platform::CUDADeviceContext, T> {
161161
template class MatrixSolveFunctor<platform::CUDADeviceContext, float>;
162162
template class MatrixSolveFunctor<platform::CUDADeviceContext, double>;
163163

164-
template <typename T>
165-
class TriangularSolveFunctor<platform::CUDADeviceContext, T> {
166-
public:
167-
void operator()(const platform::CUDADeviceContext& context, const Tensor* a,
168-
Tensor* b, bool left, bool upper, bool transpose,
169-
bool unitriangular) {
170-
CBLAS_SIDE side = left ? CblasLeft : CblasRight;
171-
CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower;
172-
CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans;
173-
CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit;
174-
175-
const T* a_data = a->data<T>();
176-
T* b_data = b->mutable_data<T>(context.GetPlace());
177-
178-
int a_dim_size = a->dims().size();
179-
int b_dim_size = b->dims().size();
180-
181-
int M = static_cast<int>(b->dims()[b_dim_size - 2]);
182-
int N = static_cast<int>(b->dims()[b_dim_size - 1]);
183-
auto lda = left ? std::max(1, M) : std::max(1, N);
184-
auto ldb = std::max(1, N);
185-
186-
int batch_size = 1;
187-
auto& a_dim = a->dims();
188-
for (int i = 0; i < a_dim_size - 2; i++) {
189-
batch_size *= a_dim[i];
190-
}
191-
192-
auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(context);
193-
if (batch_size <= 8 && M >= 64) {
194-
for (auto i = 0; i < batch_size; i++) {
195-
blas.TRSM(side, uplo, transA, diag, M, N, static_cast<T>(1.0),
196-
a_data + i * M * M, lda, b_data + i * N * M, ldb);
197-
}
198-
} else {
199-
std::vector<const T*> cpu_ptrs(batch_size * 2);
200-
for (int i = 0; i < batch_size; ++i) {
201-
cpu_ptrs[i] = a_data + i * M * M;
202-
cpu_ptrs[i + batch_size] = b_data + i * M * N;
203-
}
204-
205-
// Copy the addresses of A and tmp_b from host to device.
206-
memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
207-
memory::Alloc(context, cpu_ptrs.size() * sizeof(T*));
208-
memory::Copy(context.GetPlace(), tmp_gpu_ptrs_data->ptr(),
209-
platform::CPUPlace(), static_cast<void*>(cpu_ptrs.data()),
210-
cpu_ptrs.size() * sizeof(T*), context.stream());
211-
212-
const T** gpu_a_ptrs =
213-
reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr());
214-
T** gpu_b_ptrs =
215-
reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
216-
blas.BatchedTRSM(side, uplo, transA, diag, M, N, static_cast<T>(1.0),
217-
gpu_a_ptrs, lda, gpu_b_ptrs, ldb, batch_size);
218-
}
219-
}
220-
};
221-
222-
template class TriangularSolveFunctor<platform::CUDADeviceContext, float>;
223-
template class TriangularSolveFunctor<platform::CUDADeviceContext, double>;
224-
225164
} // namespace math
226165
} // namespace operators
227166
} // namespace paddle

paddle/fluid/operators/math/matrix_solve.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,6 @@ class MatrixSolveFunctor {
117117
const framework::Tensor& b, framework::Tensor* out);
118118
};
119119

120-
template <typename DeviceContext, typename T>
121-
class TriangularSolveFunctor {
122-
public:
123-
void operator()(const DeviceContext& context, const framework::Tensor* a,
124-
framework::Tensor* b, bool left, bool upper, bool transpose,
125-
bool unitriangular);
126-
};
127-
128120
} // namespace math
129121
} // namespace operators
130122
} // namespace paddle

paddle/fluid/operators/triangular_solve_op.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/operators/triangular_solve_op.h"
1615
#include "paddle/fluid/framework/infershape_utils.h"
1716
#include "paddle/fluid/framework/op_registry.h"
18-
#include "paddle/fluid/operators/solve_op.h"
1917
#include "paddle/phi/infermeta/binary.h"
2018

2119
namespace paddle {

paddle/fluid/operators/triangular_solve_op.h

Lines changed: 0 additions & 64 deletions
This file was deleted.

paddle/phi/kernels/cpu/triangular_solve_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/phi/kernels/triangular_solve_kernel.h"
16+
1617
#include "paddle/phi/backends/cpu/cpu_context.h"
1718
#include "paddle/phi/core/ddim.h"
1819
#include "paddle/phi/core/kernel_registry.h"

0 commit comments

Comments
 (0)