1414
1515#pragma  once
1616
17+ #include  " paddle/phi/common/amp_type_traits.h" 
1718#include  " paddle/phi/kernels/adadelta_kernel.h" 
1819#include  " paddle/phi/kernels/funcs/eigen/common.h" 
1920#include  " paddle/phi/kernels/funcs/eigen/eigen_function.h" 
@@ -26,40 +27,58 @@ void AdadeltaKernel(const Context& dev_ctx,
2627                    const  DenseTensor& grad,
2728                    const  DenseTensor& avg_squared_grad,
2829                    const  DenseTensor& avg_squared_update,
30+                     const  paddle::optional<DenseTensor>& master_param,
2931                    float  rho,
3032                    float  epsilon,
33+                     bool  multi_precision,
3134                    DenseTensor* param_out,
3235                    DenseTensor* avg_squared_grad_out,
33-                     DenseTensor* avg_squared_update_out) {
36+                     DenseTensor* avg_squared_update_out,
37+                     DenseTensor* master_param_outs) {
38+   using  MPDType = typename  phi::dtype::template  MPTypeTrait<T>::Type;
3439  dev_ctx.template  Alloc <T>(param_out);
35-   dev_ctx.template  Alloc <T >(avg_squared_grad_out);
36-   dev_ctx.template  Alloc <T >(avg_squared_update_out);
40+   dev_ctx.template  Alloc <MPDType >(avg_squared_grad_out);
41+   dev_ctx.template  Alloc <MPDType >(avg_squared_update_out);
3742
38-   T  rho_ = static_cast <T >(rho);
39-   T  epsilon_ = static_cast <T >(epsilon);
43+   MPDType  rho_ = static_cast <MPDType >(rho);
44+   MPDType  epsilon_ = static_cast <MPDType >(epsilon);
4045
4146  auto  eigen_param = EigenVector<T>::Flatten (param);
4247  auto  eigen_grad = EigenVector<T>::Flatten (grad);
4348  //  Squared gradient accumulator
44-   auto  eigen_avg_squared_grad = EigenVector<T >::Flatten (avg_squared_grad);
49+   auto  eigen_avg_squared_grad = EigenVector<MPDType >::Flatten (avg_squared_grad);
4550  //  Squared updates accumulator
46-   auto  eigen_avg_squared_update = EigenVector<T>::Flatten (avg_squared_update);
51+   auto  eigen_avg_squared_update =
52+       EigenVector<MPDType>::Flatten (avg_squared_update);
4753  auto  eigen_param_out = EigenVector<T>::Flatten (*param_out);
4854  auto  eigen_avg_squared_grad_out =
49-       EigenVector<T >::Flatten (*avg_squared_grad_out);
55+       EigenVector<MPDType >::Flatten (*avg_squared_grad_out);
5056  auto  eigen_avg_squared_update_out =
51-       EigenVector<T >::Flatten (*avg_squared_update_out);
57+       EigenVector<MPDType >::Flatten (*avg_squared_update_out);
5258  auto & place = *dev_ctx.eigen_device ();
5359
60+   auto  eigen_grad_cast = eigen_grad.template  cast <MPDType>();
61+ 
5462  eigen_avg_squared_grad_out.device (place) =
55-       rho_ * eigen_avg_squared_grad + (1  - rho_) * eigen_grad .square ();
63+       rho_ * eigen_avg_squared_grad + (1  - rho_) * eigen_grad_cast .square ();
5664  auto  update = -((eigen_avg_squared_update + epsilon_) /
5765                  (eigen_avg_squared_grad_out + epsilon_))
5866                     .sqrt () *
59-                 eigen_grad ;
67+                 eigen_grad_cast ;
6068  eigen_avg_squared_update_out.device (place) =
6169      rho_ * eigen_avg_squared_update + (1  - rho_) * update.square ();
62-   eigen_param_out.device (place) = eigen_param + update;
70+ 
71+   if  (multi_precision) {
72+     auto  eigen_master_param_out =
73+         EigenVector<MPDType>::Flatten (*master_param_outs);
74+     auto  eigen_master_param = EigenVector<MPDType>::Flatten (*master_param);
75+ 
76+     eigen_master_param_out.device (place) = eigen_master_param + update;
77+     eigen_param_out.device (place) =
78+         (eigen_param.template  cast <MPDType>() + update).template  cast <T>();
79+   } else  {
80+     eigen_param_out.device (place) = eigen_param + update.template  cast <T>();
81+   }
6382}
6483
6584}  //  namespace phi
0 commit comments