8 changes: 5 additions & 3 deletions paddle/optimizer/adadelta_optimizer.cc
@@ -27,22 +27,24 @@ void AdadeltaOptimizer::Update(const Tensor* gradient) {

 const char* AdadeltaOptimizer::SerializeState(int* state_len) {
   AdadeltaOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);

   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
   TensorToProto(*accum_delta_, state.mutable_accum_delta());
   TensorToProto(*update_delta_, state.mutable_update_delta());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
Contributor:

Not sure why we need += here?

@dzhwinter (Contributor, Author) replied on Jul 5, 2017:

Because state_len already holds the length of the serialized lr_state at this point:

std::string lr_str = this->lr_policy_->SerializeState(state_len);

Using = instead would overwrite the lr_state length.
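To make that contract concrete, here is a minimal sketch with stand-in byte strings (illustrative helpers, not the actual Paddle classes):

#include <string>

// Stand-in for LrPolicy::SerializeState: it reports its own length
// through *state_len before the optimizer adds its share.
std::string SerializeLrState(int* state_len) {
  std::string lr_bytes = "lr-policy-bytes";  // pretend-serialized LrPolicyState
  *state_len = lr_bytes.size();              // first write: lr state length
  return lr_bytes;
}

// Stand-in for an optimizer's SerializeState.
std::string SerializeOptimizerState(int* state_len) {
  std::string lr_bytes = SerializeLrState(state_len);  // *state_len = lr length
  std::string str = "optimizer-bytes";                 // pretend optimizer state
  *state_len += str.size();  // accumulate; '=' would drop the lr length above
  return str;
}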

Contributor:

I see.

   return str.c_str();
 }

 void AdadeltaOptimizer::DeserializeState(const std::string& str) {
   AdadeltaOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();

   ProtoToTensor(state.parameter(), parameter_);
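Note how the deserialize path round-trips the nested message through bytes so that each LrPolicy keeps its own parsing logic. In isolation the pattern looks like this (a sketch built from the calls in the diff above; surrounding plumbing omitted):

AdadeltaOptimizerState state;
state.ParseFromString(str);  // str holds the full optimizer snapshot
// Re-serialize only the nested LrPolicyState so the policy object can
// parse its slice of the snapshot independently.
std::string lr_bytes = state.lr_state().SerializeAsString();
lr_policy_->DeserializeState(lr_bytes);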
9 changes: 6 additions & 3 deletions paddle/optimizer/adagrad_optimizer.cc
@@ -19,20 +19,23 @@ void AdagradOptimizer::Update(const Tensor* gradient) {
 }
 const char* AdagradOptimizer::SerializeState(int* state_len) {
   AdagradOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
   state.set_num_sample_passed(num_sample_passed_);
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
+
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*accum_gradient_, state.mutable_accum_gradient());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }

 void AdagradOptimizer::DeserializeState(const std::string& str) {
   AdagradOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());

   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
   ProtoToTensor(state.accum_gradient(), accum_gradient_);
9 changes: 6 additions & 3 deletions paddle/optimizer/adam_optimizer.cc
@@ -24,20 +24,23 @@ void AdamOptimizer::Update(const Tensor *gradient) {

 const char *AdamOptimizer::SerializeState(int *state_len) {
   AdamOptimizerState state;
-  // TODO(zhihong) : add lr_policy serialization
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
   state.set_num_sample_passed(num_sample_passed_);
+
   TensorToProto(*parameter_, state.mutable_parameter());
   TensorToProto(*momentums_, state.mutable_momentums());
   TensorToProto(*velocitys_, state.mutable_velocitys());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }

 void AdamOptimizer::DeserializeState(const std::string &str) {
   AdamOptimizerState state;
   state.ParseFromString(str);
-  // TODO(zhihong) : add lr_policy DeserializeState
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();

   ProtoToTensor(state.parameter(), parameter_);
48 changes: 34 additions & 14 deletions paddle/optimizer/lr_policy.h
@@ -17,36 +17,56 @@ class LrPolicy {
 // constant learning rate policy
 class ConstLr final : public LrPolicy {
  public:
-  ConstLr(double lr) : learning_rate(lr){};
+  ConstLr(double lr) : learning_rate_(lr){};
   double LearningRate(const uint64_t num_sample_passed) {
-    return learning_rate;
+    return learning_rate_;
   }
+  const char *SerializeState(int *state_len) {
+    LrPolicyState state;
+    state.set_learning_rate(learning_rate_);
+    auto str = state.SerializeAsString();
+    *state_len = str.size();
+    return str.c_str();
+  }
+  void DeserializeState(const std::string &str) {
+    LrPolicyState state;
+    state.ParseFromString(str);
+    learning_rate_ = state.learning_rate();
+  }
-  const char *SerializeState(int *state_len) { return nullptr; }
-  void DeserializeState(const std::string &state) {}

  private:
-  double learning_rate;
+  double learning_rate_;
 };

 class LinearLr final : public LrPolicy {
  public:
   LinearLr(double lr, double lr_decay_a, double lr_decay_b)
-      : learning_rate(lr), lr_decay_a(lr_decay_a), lr_decay_b(lr_decay_b) {}
+      : learning_rate_(lr), lr_decay_a_(lr_decay_a), lr_decay_b_(lr_decay_b) {}
   double LearningRate(const uint64_t num_sample_passed) {
-    return std::max(learning_rate - lr_decay_a * num_sample_passed, lr_decay_b);
+    return std::max(learning_rate_ - lr_decay_a_ * num_sample_passed,
+                    lr_decay_b_);
   }
   const char *SerializeState(int *state_len) {
-    // TODO(zhihong) : add lr_policy serialization
-    return nullptr;
+    LrPolicyState state;
+    state.set_learning_rate(learning_rate_);
+    state.set_lr_decay_a(lr_decay_a_);
+    state.set_lr_decay_b(lr_decay_b_);
+    auto str = state.SerializeAsString();
+    *state_len = str.size();
+    return str.c_str();
   }
-  void DeserializeState(const std::string &state) {
-    // TODO(zhihong) : add lr_policy serialization
+  void DeserializeState(const std::string &str) {
+    LrPolicyState state;
+    state.ParseFromString(str);
+    learning_rate_ = state.learning_rate();
+    lr_decay_a_ = state.lr_decay_a();
+    lr_decay_b_ = state.lr_decay_b();
   }

  private:
-  double learning_rate;
-  double lr_decay_a;
-  double lr_decay_b;
+  double learning_rate_;
+  double lr_decay_a_;
+  double lr_decay_b_;
 };

 } // namespace optimizer
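For reference, the state these new methods carry round-trips through the LrPolicyState message added to proto/OptimizerConfig.proto below. A minimal sketch of that round trip (the generated header name is assumed; it is not shown in this PR):

#include <string>
#include "OptimizerConfig.pb.h"  // assumed name of the generated proto header

// Save the linear-decay policy's coefficients into an LrPolicyState blob.
std::string SaveLinearLr(double lr, double decay_a, double decay_b) {
  LrPolicyState state;
  state.set_learning_rate(lr);
  state.set_lr_decay_a(decay_a);
  state.set_lr_decay_b(decay_b);
  return state.SerializeAsString();  // an owned std::string, safe to store
}

// Restore them from the blob produced above.
void RestoreLinearLr(const std::string& bytes) {
  LrPolicyState state;
  state.ParseFromString(bytes);
  // state.learning_rate(), state.lr_decay_a(), state.lr_decay_b() now
  // hold the saved coefficients.
}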
6 changes: 5 additions & 1 deletion paddle/optimizer/sgd_optimizer.cc
@@ -30,16 +30,20 @@ void SGDOptimizer::Update(const Tensor *gradient) {
 const char *SGDOptimizer::SerializeState(int *state_len) {
   SGDOptimizerState state;
   state.set_num_sample_passed(num_sample_passed_);
+  std::string lr_str = this->lr_policy_->SerializeState(state_len);
+  state.mutable_lr_state()->ParseFromString(lr_str);
   TensorToProto(*parameter_, state.mutable_parameter());
   if (momentum_ != 0.0) TensorToProto(*momentums_, state.mutable_momentums());
   auto str = state.SerializeAsString();
-  *state_len = str.size();
+  *state_len += str.size();
   return str.c_str();
 }

 void SGDOptimizer::DeserializeState(const std::string &str) {
   SGDOptimizerState state;
   state.ParseFromString(str);
+  auto lr_state = state.lr_state();
+  this->lr_policy_->DeserializeState(lr_state.SerializeAsString());
   num_sample_passed_ = state.num_sample_passed();
   ProtoToTensor(state.parameter(), parameter_);
   if (momentum_ != 0.0) ProtoToTensor(state.parameter(), momentums_);
27 changes: 12 additions & 15 deletions proto/OptimizerConfig.proto
@@ -78,11 +78,15 @@ enum DataType {
   repeated bytes content = 2;
 }

+message LrPolicyState {
+  // learning rate policy
+  optional double learning_rate = 1 [default = 1.0];
+  optional double lr_decay_a = 2;
+  optional double lr_decay_b = 3;
+}
+
 message SGDOptimizerState {
   // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
@@ -91,9 +95,7 @@ message SGDOptimizerState {

 message AdadeltaOptimizerState {
   // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
@@ -102,22 +104,17 @@ message AdadeltaOptimizerState {
   optional TensorProto accum_delta = 3;
   optional TensorProto update_delta = 4;
 }

-
 message AdagradOptimizerState {
   // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
Contributor:
Do not re-use protobuf field ids when changing fields. Since this is still under development, it's fine for now, but once people are using this code we need to consider backward compatibility.

@dzhwinter (Contributor, Author):

Thanks for the reminder; that's true once we're maintaining a released project. But what is the best practice for choosing field numbers? Simply auto-incrementing them by one doesn't seem smart. As an aside, the field number is only used by the protoc compiler to identify an entry on the wire; even though protobuf supports looking up values by field number, that shouldn't be used in user code. Have you ever run into such a code base?

Contributor:

Once a protobuf field number has been used in a message, it should never change, for the sake of old clients or anything else still built against code generated from the old .proto:

  1. Do not change existing field numbers.
  2. Only append new fields; do not modify existing ones.
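To make this concrete, proto2 lets retired numbers be marked reserved so they cannot be reused by accident (a hypothetical message, not part of this PR, and assuming a protoc new enough to support reserved):

// Hypothetical: retire the old scalar field numbers instead of reusing one.
message ExampleOptimizerState {
  reserved 101, 102, 103;  // old learning_rate, lr_decay_a, lr_decay_b
  optional LrPolicyState lr_state = 105;  // the new field takes a fresh number
  optional double num_sample_passed = 104;
}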

   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;
   optional TensorProto accum_gradient = 2;
 }

 message AdamOptimizerState {
   // learning rate policy
-  optional double learning_rate = 101;
-  optional double lr_decay_a = 102;
-  optional double lr_decay_b = 103;
+  optional LrPolicyState lr_state = 101;
   optional double num_sample_passed = 104;
   // state
   optional TensorProto parameter = 1;