@@ -15,12 +15,13 @@ limitations under the License. */
1515#pragma once
1616
1717#include " paddle/fluid/framework/op_registry.h"
18+ #include " paddle/fluid/framework/phi_utils.h"
1819#include " paddle/fluid/operators/set_value_op.h"
1920#include " paddle/fluid/operators/svd_helper.h"
20- #include " paddle/fluid/operators/triangular_solve_op.h"
2121#include " paddle/fluid/operators/tril_triu_op.h"
2222#include " paddle/phi/kernels/funcs/lapack/lapack_function.h"
2323#include " paddle/phi/kernels/math_kernel.h"
24+ #include " paddle/phi/kernels/triangular_solve_kernel.h"
2425
2526namespace paddle {
2627namespace operators {
@@ -555,6 +556,11 @@ class LUGradKernel : public framework::OpKernel<T> {
555556
556557 framework::Tensor Pmat;
557558 Unpack_Pivot<DeviceContext, T>(dev_ctx, *P, &Pmat, m, k);
559+
560+ using Context =
561+ typename framework::ConvertToPhiContext<DeviceContext>::TYPE;
562+ auto & phi_dev_ctx = static_cast <const Context&>(dev_ctx);
563+
558564 if (m <= n) {
559565 if (k < n) {
560566 framework::Tensor U_complement, U_grad_complement, phi_complement,
@@ -605,8 +611,9 @@ class LUGradKernel : public framework::OpKernel<T> {
605611 framework::Tensor psi_principal, phi_mH, psi_tmp;
606612 Tensor_Conj<DeviceContext, T>(dev_ctx, phi, &phi_mH);
607613 phi_mH = helper.Transpose (phi_mH);
608- triangular_solve<DeviceContext, T>(dev_ctx, U_narrow, phi_mH,
609- &psi_principal, true , false , false );
614+
615+ phi::TriangularSolveKernel<T, Context>(
616+ phi_dev_ctx, U_narrow, phi_mH, true , false , false , &psi_principal);
610617
611618 Tensor_Conj<DeviceContext, T>(dev_ctx, psi_principal, &psi_principal);
612619 psi_principal = helper.Transpose (psi_principal);
@@ -620,8 +627,9 @@ class LUGradKernel : public framework::OpKernel<T> {
620627 SetValueCompute_dispatch<DeviceContext, T>(ctx, &psi, &psi_principal,
621628 &psi, axes, &slice_starts,
622629 &slice_ends, valuedims, xrank);
623- triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, psi, &psi_tmp,
624- true , false , true );
630+
631+ phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, L_narrow_mH, psi,
632+ true , false , true , &psi_tmp);
625633
626634 auto mat_dim_p =
627635 phi::funcs::CreateMatrixDescriptor (Pmat.dims (), 0 , false );
@@ -672,8 +680,10 @@ class LUGradKernel : public framework::OpKernel<T> {
672680 &psi, axes, &slice_starts,
673681 &slice_ends, valuedims, xrank);
674682 framework::Tensor psi_principal, phi_mH, psi_tmp, U_narrow_mH;
675- triangular_solve<DeviceContext, T>(dev_ctx, L_narrow_mH, phi,
676- &psi_principal, true , false , true );
683+
684+ phi::TriangularSolveKernel<T, Context>(phi_dev_ctx, L_narrow_mH, phi,
685+ true , false , true , &psi_principal);
686+
677687 slice_starts[0 ] = 0 ;
678688 slice_starts[1 ] = 0 ;
679689 slice_ends[0 ] = k;
@@ -695,8 +705,8 @@ class LUGradKernel : public framework::OpKernel<T> {
695705 psi_tmp = helper.Transpose (psi_tmp);
696706
697707 Tensor_Conj<DeviceContext, T>(dev_ctx, U_narrow, &U_narrow_mH);
698- triangular_solve<DeviceContext, T>(dev_ctx , U_narrow_mH, psi_tmp, &psi ,
699- true , false , false );
708+ phi::TriangularSolveKernel<T, Context>(phi_dev_ctx , U_narrow_mH, psi_tmp,
709+ true , false , false , &psi );
700710 *dx = helper.Transpose (psi);
701711 }
702712 }
0 commit comments