Commit 6398c15

Merge pull request #2718 from dzhwinter/lr_state
"lr state serialization"
2 parents bc36852 + 45adbfc commit 6398c15

File tree: 6 files changed, +68 -39 lines changed

paddle/optimizer/adadelta_optimizer.cc

Lines changed: 5 additions & 3 deletions
@@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {
 
 const char* AdadeltaOptimizer::SerializeState(int* state_len) {
   AdadeltaOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
 
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
   TensorToProto(*accum_delta_, state.mutable_accum_delta());
   TensorToProto(*update_delta_, state.mutable_update_delta());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }
 
 void AdadeltaOptimizer::DeserializeState(const std::string& str) {
   AdadeltaOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();
 
   ProtoToTensor(state.parameter(), parameter_);
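The same pattern repeats in the three optimizers that follow: SerializeState lets the learning-rate policy serialize itself, splices those bytes into the nested lr_state field by parsing them, and accumulates into *state_len instead of assigning; DeserializeState re-serializes the nested field and hands the bytes back to the policy. A hedged sketch of that embed/extract round trip in isolation (the function name is made up, and the include assumes the headers generated from proto/OptimizerConfig.proto in this tree):

#include <string>

#include "OptimizerConfig.pb.h"  // generated from proto/OptimizerConfig.proto

// EmbedExtractDemo is a hypothetical illustration, not commit code:
// child-message bytes are embedded by parsing them into a mutable
// nested field, and extracted by re-serializing that field.
void EmbedExtractDemo() {
  LrPolicyState lr;
  lr.set_learning_rate(0.01);
  std::string lr_bytes = lr.SerializeAsString();  // policy -> wire bytes

  AdadeltaOptimizerState state;
  state.mutable_lr_state()->ParseFromString(lr_bytes);  // embed

  // Extract: an equivalent encoding of the same fields comes back out.
  std::string round_trip = state.lr_state().SerializeAsString();
}

Two caveats about the const char* hand-off in the diff itself are worth flagging: constructing std::string from the returned pointer stops at the first NUL byte, although protobuf wire data may legitimately contain NULs, and the pointer comes from str.c_str() on a std::string local to SerializeState, so it dangles once that function returns.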

paddle/optimizer/adagrad_optimizer.cc

Lines changed: 6 additions & 3 deletions
@@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
 }
 const char* AdagradOptimizer::SerializeState(int* state_len) {
   AdagradOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
 
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }
 
 void AdagradOptimizer::DeserializeState(const std::string& str) {
   AdagradOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
+
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
   ProtoToTensor(state.accum_gradient(), accum_gradient_);

paddle/optimizer/adam_optimizer.cc

Lines changed: 6 additions & 3 deletions
@@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) {
 
 const char *AdamOptimizer::SerializeState(int *state_len) {
   AdamOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
   state.set_num_sample_passed(num_sample_passed_);
+
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*momentums_, state.mutable_momentums());
   TensorToProto(*velocitys_, state.mutable_velocitys());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }
 
 void AdamOptimizer::DeserializeState(const std::string &str) {
   AdamOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();
 
   ProtoToTensor(state.parameter(), parameter_);

paddle/optimizer/lr_policy.h

Lines changed: 34 additions & 14 deletions
@@ -17,36 +17,56 @@ class LrPolicy {
 // constant learning rate policy
 class ConstLr final : public LrPolicy {
  public:
-  ConstLr(double lr) : learning_rate(lr){};
+  ConstLr(double lr) : learning_rate_(lr){};
   double LearningRate(const uint64_t num_sample_passed) {
-    return learning_rate;
+    return learning_rate_;
+  }
+  const char *SerializeState(int *state_len) {
+    LrPolicyState state;
+    state.set_learning_rate(learning_rate_);
+    auto str = state.SerializeAsString();
+    *state_len = str.size();
+    return str.c_str();
+  }
+  void DeserializeState(const std::string &str) {
+    LrPolicyState state;
+    state.ParseFromString(str);
+    learning_rate_ = state.learning_rate();
   }
-  const char *SerializeState(int *state_len) { return nullptr; }
-  void DeserializeState(const std::string &state) {}
 
  private:
-  double learning_rate;
+  double learning_rate_;
 };
 
 class LinearLr final : public LrPolicy {
  public:
   LinearLr(double lr, double lr_decay_a, double lr_decay_b)
-      : learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {}
+      : learning_rate_(lr), lr_decay_a_(lr_decay_a), lr_decay_b_(lr_decay_b) {}
   double LearningRate(const uint64_t num_sample_passed) {
-    return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
+    return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed,
+                    lr_decay_b_);
   }
   const char *SerializeState(int *state_len) {
-    // TODO(zhihong) : add lr_policy serialization
-    return nullptr;
+    LrPolicyState state;
+    state.set_learning_rate(learning_rate_);
+    state.set_lr_decay_a(lr_decay_a_);
+    state.set_lr_decay_b(lr_decay_b_);
+    auto str = state.SerializeAsString();
+    *state_len = str.size();
+    return str.c_str();
   }
-  void DeserializeState(const std::string &state) {
-    // TODO(zhihong) : add lr_policy serialization
+  void DeserializeState(const std::string &str) {
+    LrPolicyState state;
+    state.ParseFromString(str);
+    learning_rate_ = state.learning_rate();
+    lr_decay_a_ = state.lr_decay_a();
+    lr_decay_b_ = state.lr_decay_b();
   }
 
  private:
-  double learning_rate;
-  double lr_decay_a;
-  double lr_decay_b;
+  double learning_rate_;
+  double lr_decay_a_;
+  double lr_decay_b_;
 };
 
 } // namespace optimizer
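Both policies now serialize through the new LrPolicyState message. Note that at this level *state_len is assigned (=), and the optimizers above then add (+=) their own blob's size on top. The base interface these two classes override sits outside the hunk's context lines; below is a hedged reconstruction of it, followed by a standalone run of LinearLr's schedule with made-up constants:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <string>

// Hedged reconstruction of the interface ConstLr and LinearLr override;
// the real declaration is above the lines shown in this hunk.
class LrPolicy {
 public:
  virtual ~LrPolicy() {}
  virtual double LearningRate(const uint64_t num_sample_passed) = 0;
  virtual const char *SerializeState(int *state_len) = 0;
  virtual void DeserializeState(const std::string &state) = 0;
};

// LinearLr's schedule in isolation: decay linearly per sample passed,
// clamped at the lr_decay_b floor. The constants are illustrative only.
int main() {
  const double lr = 0.1, lr_decay_a = 1e-7, lr_decay_b = 0.01;
  const uint64_t samples[] = {0, 100000, 2000000};
  for (uint64_t n : samples) {
    double eff = std::max(lr - lr_decay_a * n, lr_decay_b);
    std::printf("num_sample_passed=%llu -> lr=%g\n",
                (unsigned long long)n, eff);
  }
  return 0;  // prints 0.1, then 0.09, then the 0.01 floor
}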

paddle/optimizer/sgd_optimizer.cc

Lines changed: 5 additions & 1 deletion
@@ -30,16 +30,20 @@ void SGDOptimizer::Update(const Tensor *gradient) {
 const char *SGDOptimizer::SerializeState(int *state_len) {
   SGDOptimizerState state;
   state.set_num_sample_passed(num_sample_passed_);
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
   TensorToProto(*parameter_, state.mutable_parameter());
   if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }
 
 void SGDOptimizer::DeserializeState(const std::string &str) {
   SGDOptimizerState state;
   state.ParseFromString(str);
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
   if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
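One subtlety in the new length accounting: lr_policy_->SerializeState(state_len) assigns the policy blob's size into *state_len, and the optimizer then does *state_len += str.size(), so the reported total is the sum of both serializations rather than the length of the buffer actually returned. A small standalone model of the arithmetic (the sizes are invented):

#include <cstdio>

// Models the two-step accounting above with invented sizes.
int main() {
  int state_len = 0;
  const int lr_blob = 12;          // what the lr policy's SerializeState assigns
  state_len = lr_blob;
  const int optimizer_blob = 480;  // size of the optimizer state bytes
  state_len += optimizer_blob;     // the optimizer accumulates on top
  // Reported: 492. The returned const char* points only at the
  // 480-byte optimizer blob, so callers cannot treat state_len as
  // that buffer's length.
  std::printf("state_len = %d\n", state_len);
  return 0;
}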

proto/OptimizerConfig.proto

Lines changed: 12 additions & 15 deletions
@@ -78,11 +78,15 @@ enum DataType {
   repeated bytes content = 2;
 }
 
+message LrPolicyState {
+  // learninRate Policy
+  optional double learning_rate = 1 [default = 1.0];
+  optional double lr_decay_a = 2;
+  optional double lr_decay_b = 3;
+}
+
 message SGDOptimizerState {
-  // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
@@ -91,9 +95,7 @@ message SGDOptimizerState {
 
 message AdadeltaOptimizerState {
   // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
@@ -102,22 +104,17 @@ message AdadeltaOptimizerState {
   optional TensorProto update_delta = 4;
 }
 
+
 message AdagradOptimizerState {
-  // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
   optional TensorProto accum_gradient = 2;
 }
 
 message AdamOptimizerState {
-  // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
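The schema change replaces three duplicated scalar fields with one shared LrPolicyState message, reusing field number 101. Since the old schema stored a double (fixed64 wire type) under that number and the new one stores a length-delimited message, a proto2 parser reading a pre-commit checkpoint will typically shunt the old values into unknown fields rather than into lr_state, which then falls back to defaults (learning_rate reads as 1.0 when unset). A usage sketch against the generated C++ API (assuming the headers generated from this .proto, with message names unqualified):

#include <cstdio>
#include <string>

#include "OptimizerConfig.pb.h"  // generated from proto/OptimizerConfig.proto

int main() {
  SGDOptimizerState state;
  // proto2 default: an unset nested message yields default-valued fields.
  std::printf("%g\n", state.lr_state().learning_rate());  // 1.0

  state.mutable_lr_state()->set_learning_rate(0.05);
  state.mutable_lr_state()->set_lr_decay_a(1e-7);
  std::string blob = state.SerializeAsString();

  SGDOptimizerState restored;
  restored.ParseFromString(blob);
  std::printf("%g\n", restored.lr_state().learning_rate());  // 0.05
  return 0;
}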
