Commit a8a2b7f

Add multiprecision for adadelta op (#50131)
1 parent a1006b2 commit a8a2b7f

13 files changed: 607 additions & 45 deletions

paddle/fluid/operators/optimizers/adadelta_op.cc

Lines changed: 9 additions & 0 deletions
@@ -39,12 +39,17 @@ class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient");
     AddInput("AvgSquaredUpdate",
              "(Tensor) Input average of squared parameter updates");
+    AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
 
     AddOutput("ParamOut", "(Tensor) Output parameter");
     AddOutput("AvgSquaredGradOut",
               "(Tensor) Output average of squared gradient");
     AddOutput("AvgSquaredUpdateOut",
               "(Tensor) Output average of squared parameter updates");
+    AddOutput("MasterParamOut",
+              "The updated FP32 master weight for AMP. "
+              "It shared memory with Input(MasterParam).")
+        .AsDispensable();
 
     AddAttr<float>("rho",
                    "(float, default 0.95) Exponential decay rate "
@@ -54,6 +59,10 @@ class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(float, default 1.0e-6) Constant for "
                   "numerical stability")
         .SetDefault(1.0e-6f);
+    AddAttr<bool>("multi_precision",
+                  "(bool, default false) "
+                  "Whether to use multi-precision during weight updating.")
+        .SetDefault(false);
     AddComment(R"DOC(
 Adadelta Optimizer.
 

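The new MasterParam input and MasterParamOut output are both marked AsDispensable(), so existing single-precision graphs are unaffected; together with the multi_precision attribute (default false) they follow the same AMP master-weight pattern already used by the sgd and adagrad ops in the maps below, where the optimizer keeps an FP32 copy of a low-precision parameter.
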
paddle/fluid/pybind/eager_generator.h

Lines changed: 12 additions & 1 deletion
@@ -206,6 +206,8 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
      {"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
     {"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
     {"adagrad", {"Param", "Grad", "Moment", "LearningRate", "MasterParam"}},
+    {"adadelta",
+     {"Param", "Grad", "AvgSquaredGrad", "AvgSquaredUpdate", "MasterParam"}},
     {"graph_khop_sampler", {"Row", "Eids", "Col_Ptr", "X"}},
     {"nce",
      {"Input",
@@ -311,6 +313,11 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
      "SavedMean",
      "SavedVariance",
      "ReserveSpace"}},
+    {"adadelta",
+     {"ParamOut",
+      "AvgSquaredGradOut",
+      "AvgSquaredUpdateOut",
+      "MasterParamOut"}},
     {"unique", {"Out", "Index", "Indices", "Counts"}},
     {"unique_consecutive", {"Out", "Index", "Counts"}},
     {"generate_proposals", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}},
@@ -400,7 +407,11 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
      "MeanGradOut",
      "MasterParamOut"}},
     {"ftrl", {"ParamOut", "SquaredAccumOut", "LinearAccumOut"}},
-    {"adadelta", {"ParamOut", "AvgSquaredGradOut", "AvgSquaredUpdateOut"}},
+    {"adadelta",
+     {"ParamOut",
+      "AvgSquaredGradOut",
+      "AvgSquaredUpdateOut",
+      "MasterParamOut"}},
     {"adagrad", {"ParamOut", "MomentOut", "MasterParamOut"}},
     {"adamax", {"ParamOut", "MomentOut", "InfNormOut"}},
     {"dpsgd", {"ParamOut"}},

paddle/phi/api/yaml/legacy_ops.yaml

Lines changed: 5 additions & 3 deletions
@@ -20,13 +20,15 @@
     data_type : x
 
 - op : adadelta_
-  args : (Tensor param, Tensor grad, Tensor avg_squared_grad, Tensor avg_squared_update, float rho, float epsilon)
-  output : Tensor(param_out), Tensor(moment_out), Tensor(inf_norm_out)
+  args : (Tensor param, Tensor grad, Tensor avg_squared_grad, Tensor avg_squared_update, Tensor master_param, float rho, float epsilon, bool multi_precision)
+  output : Tensor(param_out), Tensor(moment_out), Tensor(inf_norm_out), Tensor(master_param_out)
   infer_meta :
     func : AdadeltaInferMeta
   kernel :
     func : adadelta
-  inplace : (param -> param_out), (avg_squared_grad -> moment_out), (avg_squared_update -> inf_norm_out)
+    data_type : param
+  optional : master_param
+  inplace : (param -> param_out), (avg_squared_grad -> moment_out), (avg_squared_update -> inf_norm_out), (master_param -> master_param_out)
 
 - op : adagrad_
   args : (Tensor param, Tensor grad, Tensor moment, Tensor learning_rate, Tensor master_param, float epsilon, bool multi_precision)

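Here `optional : master_param` is what surfaces the master weight to the kernel as a paddle::optional<DenseTensor> (see the kernel signatures below), and `data_type : param` picks the registered kernel dtype from the parameter tensor, so an fp16 parameter dispatches to the float16 kernel added in adadelta_kernel.cu.
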
paddle/phi/infermeta/multiary.cc

Lines changed: 4 additions & 1 deletion
@@ -38,11 +38,14 @@ void AdadeltaInferMeta(const MetaTensor& param,
                        const MetaTensor& grad,
                        const MetaTensor& avg_squared_grad,
                        const MetaTensor& avg_squared_update,
+                       const MetaTensor& master_param,
                        float rho,
                        float epsilon,
+                       bool multi_precision,
                        MetaTensor* param_out,
                        MetaTensor* avg_squared_grad_out,
-                       MetaTensor* avg_squared_update_out) {
+                       MetaTensor* avg_squared_update_out,
+                       MetaTensor* master_param_out) {
   auto param_dims = param.dims();
   PADDLE_ENFORCE_EQ(
       param_dims,

paddle/phi/infermeta/multiary.h

Lines changed: 4 additions & 1 deletion
@@ -43,11 +43,14 @@ void AdadeltaInferMeta(const MetaTensor& param,
                        const MetaTensor& grad,
                        const MetaTensor& avg_squared_grad,
                        const MetaTensor& avg_squared_update,
+                       const MetaTensor& master_param,
                        float rho,
                        float epsilon,
+                       bool multi_precision,
                        MetaTensor* param_out,
                        MetaTensor* avg_squared_grad_out,
-                       MetaTensor* avg_squared_update_out);
+                       MetaTensor* avg_squared_update_out,
+                       MetaTensor* master_param_outs);
 
 void AdagradInferMeta(const MetaTensor& param,
                       const MetaTensor& grad,

paddle/phi/kernels/adadelta_kernel.h

Lines changed: 4 additions & 1 deletion
@@ -24,10 +24,13 @@ void AdadeltaKernel(const Context& dev_ctx,
                     const DenseTensor& grad,
                     const DenseTensor& avg_squared_grad,
                     const DenseTensor& avg_squared_update,
+                    const paddle::optional<DenseTensor>& master_param,
                     float rho,
                     float epsilon,
+                    bool multi_precision,
                     DenseTensor* param_out,
                     DenseTensor* avg_squared_grad_out,
-                    DenseTensor* avg_squared_update_out);
+                    DenseTensor* avg_squared_update_out,
+                    DenseTensor* master_param_outs);
 
 } // namespace phi

paddle/phi/kernels/gpu/adadelta_kernel.cu

Lines changed: 7 additions & 2 deletions
@@ -18,5 +18,10 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/adadelta_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    adadelta, GPU, ALL_LAYOUT, phi::AdadeltaKernel, float, double) {}
+PD_REGISTER_KERNEL(adadelta,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AdadeltaKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}

paddle/phi/kernels/impl/adadelta_kernel_impl.h

Lines changed: 31 additions & 12 deletions
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/adadelta_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -26,40 +27,58 @@ void AdadeltaKernel(const Context& dev_ctx,
                     const DenseTensor& grad,
                     const DenseTensor& avg_squared_grad,
                     const DenseTensor& avg_squared_update,
+                    const paddle::optional<DenseTensor>& master_param,
                     float rho,
                     float epsilon,
+                    bool multi_precision,
                     DenseTensor* param_out,
                     DenseTensor* avg_squared_grad_out,
-                    DenseTensor* avg_squared_update_out) {
+                    DenseTensor* avg_squared_update_out,
+                    DenseTensor* master_param_outs) {
+  using MPDType = typename phi::dtype::template MPTypeTrait<T>::Type;
   dev_ctx.template Alloc<T>(param_out);
-  dev_ctx.template Alloc<T>(avg_squared_grad_out);
-  dev_ctx.template Alloc<T>(avg_squared_update_out);
+  dev_ctx.template Alloc<MPDType>(avg_squared_grad_out);
+  dev_ctx.template Alloc<MPDType>(avg_squared_update_out);
 
-  T rho_ = static_cast<T>(rho);
-  T epsilon_ = static_cast<T>(epsilon);
+  MPDType rho_ = static_cast<MPDType>(rho);
+  MPDType epsilon_ = static_cast<MPDType>(epsilon);
 
   auto eigen_param = EigenVector<T>::Flatten(param);
   auto eigen_grad = EigenVector<T>::Flatten(grad);
   // Squared gradient accumulator
-  auto eigen_avg_squared_grad = EigenVector<T>::Flatten(avg_squared_grad);
+  auto eigen_avg_squared_grad = EigenVector<MPDType>::Flatten(avg_squared_grad);
   // Squared updates accumulator
-  auto eigen_avg_squared_update = EigenVector<T>::Flatten(avg_squared_update);
+  auto eigen_avg_squared_update =
+      EigenVector<MPDType>::Flatten(avg_squared_update);
   auto eigen_param_out = EigenVector<T>::Flatten(*param_out);
   auto eigen_avg_squared_grad_out =
-      EigenVector<T>::Flatten(*avg_squared_grad_out);
+      EigenVector<MPDType>::Flatten(*avg_squared_grad_out);
   auto eigen_avg_squared_update_out =
-      EigenVector<T>::Flatten(*avg_squared_update_out);
+      EigenVector<MPDType>::Flatten(*avg_squared_update_out);
   auto& place = *dev_ctx.eigen_device();
 
+  auto eigen_grad_cast = eigen_grad.template cast<MPDType>();
+
   eigen_avg_squared_grad_out.device(place) =
-      rho_ * eigen_avg_squared_grad + (1 - rho_) * eigen_grad.square();
+      rho_ * eigen_avg_squared_grad + (1 - rho_) * eigen_grad_cast.square();
   auto update = -((eigen_avg_squared_update + epsilon_) /
                   (eigen_avg_squared_grad_out + epsilon_))
                      .sqrt() *
-                eigen_grad;
+                eigen_grad_cast;
   eigen_avg_squared_update_out.device(place) =
       rho_ * eigen_avg_squared_update + (1 - rho_) * update.square();
-  eigen_param_out.device(place) = eigen_param + update;
+
+  if (multi_precision) {
+    auto eigen_master_param_out =
+        EigenVector<MPDType>::Flatten(*master_param_outs);
+    auto eigen_master_param = EigenVector<MPDType>::Flatten(*master_param);
+
+    eigen_master_param_out.device(place) = eigen_master_param + update;
+    eigen_param_out.device(place) =
+        (eigen_param.template cast<MPDType>() + update).template cast<T>();
+  } else {
+    eigen_param_out.device(place) = eigen_param + update.template cast<T>();
+  }
 }
 
 } // namespace phi

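The Eigen expressions above keep both accumulators and the update in MPDType (fp32 when T is float16) and, when multi_precision is set, advance the FP32 master weight alongside the low-precision parameter. As a rough illustration only (a minimal, standalone scalar sketch of that branch, not Paddle code), one step looks like this:

#include <cmath>
#include <cstdio>

// Hypothetical scalar walk-through of one multi-precision Adadelta step.
// All optimizer state stays in float here, playing the role of MPDType.
int main() {
  float master_param = 1.0f;        // MasterParam (FP32 master copy)
  float param = 1.0f;               // Param (would be float16 under AMP)
  float avg_squared_grad = 0.0f;    // AvgSquaredGrad accumulator
  float avg_squared_update = 0.0f;  // AvgSquaredUpdate accumulator
  const float rho = 0.95f, epsilon = 1.0e-6f;
  const float grad = 0.1f;          // gradient, cast up from T before use

  // Same algebra as the kernel: E[g^2], the update, then E[dx^2].
  avg_squared_grad = rho * avg_squared_grad + (1 - rho) * grad * grad;
  float update = -std::sqrt((avg_squared_update + epsilon) /
                            (avg_squared_grad + epsilon)) *
                 grad;
  avg_squared_update = rho * avg_squared_update + (1 - rho) * update * update;

  // multi_precision == true: MasterParamOut = MasterParam + update,
  // ParamOut = cast<T>(cast<MPDType>(Param) + update).
  master_param += update;
  param = param + update;  // the real kernel casts this back to T

  std::printf("update=%g param=%g master=%g\n", update, param, master_param);
  return 0;
}

Keeping the master copy in FP32 is what preserves small updates that would otherwise be lost to float16 rounding as many steps accumulate.
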
paddle/phi/kernels/xpu/adadelta_kernel.cc

Lines changed: 4 additions & 1 deletion
@@ -25,11 +25,14 @@ void AdadeltaKernel(const Context& dev_ctx,
                     const DenseTensor& grad,
                     const DenseTensor& avg_squared_grad,
                     const DenseTensor& avg_squared_update,
+                    const paddle::optional<DenseTensor>& master_param,
                     float rho,
                     float epsilon,
+                    bool multi_precision,
                     DenseTensor* param_out,
                     DenseTensor* avg_squared_grad_out,
-                    DenseTensor* avg_squared_update_out) {
+                    DenseTensor* avg_squared_update_out,
+                    DenseTensor* master_param_outs) {
   dev_ctx.template Alloc<T>(param_out);
   dev_ctx.template Alloc<T>(avg_squared_grad_out);
   dev_ctx.template Alloc<T>(avg_squared_update_out);
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature AdadeltaOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  if (ctx.IsDenseTensorInput("Grad")) {
+    return KernelSignature(
+        "adadelta",
+        {"Param", "Grad", "AvgSquaredGrad", "AvgSquaredUpdate", "MasterParam"},
+        {"rho", "epsilon", "multi_precision"},
+        {"ParamOut",
+         "AvgSquaredGradOut",
+         "AvgSquaredUpdateOut",
+         "MasterParamOut"});
+  }
+
+  return KernelSignature("unregistered", {}, {}, {});
+}
+
+} // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(adadelta, phi::AdadeltaOpArgumentMapping);
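
This new file registers the Fluid-to-phi argument mapping via PD_REGISTER_ARG_MAPPING_FN, so the legacy adadelta operator reaches the extended kernel signature: the Param/Grad/AvgSquaredGrad/AvgSquaredUpdate/MasterParam inputs, the rho/epsilon/multi_precision attributes, and the four outputs (including MasterParamOut) line up with the arguments declared in adadelta_kernel.h.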
