-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Pserver Save state #2716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pserver Save state #2716
Changes from 4 commits
5ef1425
f1330e2
e6c98e4
65afbe1
bfc3b43
40295b9
8426beb
774604c
2f2ffd9
87e7924
0ad7053
e8296ff
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,19 @@ | ||
| package pserver | ||
|
|
||
| import ( | ||
| "bufio" | ||
| "bytes" | ||
| "crypto/md5" | ||
| "encoding/gob" | ||
| "encoding/hex" | ||
| "errors" | ||
| "fmt" | ||
| "os" | ||
| "strconv" | ||
| "sync" | ||
| "time" | ||
|
|
||
| log "github.com/sirupsen/logrus" | ||
| ) | ||
|
|
||
| // ElementType is the type of elements of a Parameter. | ||
|
|
@@ -14,6 +24,10 @@ const ( | |
| Uninitialized = "pserver not fully initialized" | ||
| ) | ||
|
|
||
const (
	// checkpoint_path is the directory prefix for checkpoint files. Save
	// appends the pserver index to it to form the checkpoint file path.
	// NOTE(review): Go naming convention is MixedCaps (checkpointPath), and
	// review feedback asks for this to be configurable via a flag; renaming
	// requires updating every use in this file.
	checkpoint_path = "./checkpoints/"
)
|
|
||
| // Supported element types | ||
| const ( | ||
| Int32 ElementType = iota | ||
|
|
@@ -38,6 +52,7 @@ type Parameter struct { | |
| type ParameterWithConfig struct { | ||
| Param Parameter | ||
| Config []byte // parameter configuration in Proto Buffer format | ||
| State []byte // parameter training state | ||
|
||
| } | ||
|
|
||
| // Gradient is the gradient of the parameter. | ||
|
|
@@ -52,14 +67,34 @@ type Service struct { | |
| optMap map[string]*optimizer | ||
| } | ||
|
|
||
// checkpoint is the metadata record that Save builds for each parameter it
// writes to disk. It is registered with gob in NewService.
type checkpoint struct {
	// Uuid identifies the checkpoint; Save sets it to the checkpoint file
	// path (checkpoint_path followed by the pserver index).
	Uuid string
	// Md5sum is a hex-encoded string filled in by Save, intended to be the
	// MD5 sum of the serialized parameter content.
	Md5sum string
	// Timestamp is the string form of the time at which Save built the record.
	Timestamp string
}
|
|
||
// GetBytes gob-encodes its arguments into a byte slice.
//
// The variadic arguments are encoded together as a single []interface{}
// value, so any non-basic concrete type passed in must have been registered
// with gob.Register beforehand. On failure it returns a nil slice and the
// encoder's error.
func GetBytes(content ...interface{}) ([]byte, error) {
	out := new(bytes.Buffer)
	if encodeErr := gob.NewEncoder(out).Encode(content); encodeErr != nil {
		return nil, encodeErr
	}
	return out.Bytes(), nil
}
|
|
||
| // NewService creates a new service, will bypass etcd registration if no | ||
| // endpoints specified. | ||
| func NewService(idx int) (*Service, error) { | ||
| s := &Service{ | ||
| idx: idx, | ||
| } | ||
| s.optMap = make(map[string]*optimizer) | ||
| s.optMap = make(map[string]*optimizer) | ||
| s.initialized = make(chan struct{}) | ||
| gob.Register(ParameterWithConfig{}) | ||
| gob.Register(checkpoint{}) | ||
| return s, nil | ||
| } | ||
|
|
||
|
|
@@ -142,8 +177,51 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { | |
|
|
||
| // Save tells the parameter server to save parameters. | ||
| func (s *Service) Save(path string, dummy *int) error { | ||
|
||
| //FIXME: checkpoint is only used by pserver | ||
| // and has a constant path of */checkpoints/{pserver_idx}* | ||
| <-s.initialized | ||
|
|
||
| // TODO | ||
| s.mu.Lock() | ||
| defer s.mu.Unlock() | ||
| var paramWithConfig ParameterWithConfig | ||
| for name, opt := range s.optMap { | ||
| paramWithConfig.Param.Name = name | ||
| paramWithConfig.Param.ElementType = opt.elementType | ||
| paramWithConfig.Param.Content = opt.GetWeights() | ||
| paramWithConfig.State = opt.GetStates() | ||
| content, err := GetBytes(paramWithConfig) | ||
| if err != nil { | ||
| log.Errorln(err) | ||
| } | ||
| ck := checkpoint{} | ||
| h := md5.New() | ||
| ck.Md5sum = hex.EncodeToString(h.Sum(content)) | ||
| ck.Timestamp = time.Now().String() | ||
| ck.Uuid = checkpoint_path + strconv.Itoa(s.idx) | ||
| ckbytes, err := GetBytes(ck) | ||
| if err != nil { | ||
| log.Errorln(err) | ||
| } | ||
| // TODO: according design doc, need to save Uuid to etcd in json format | ||
| // {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx} | ||
|
||
| log.Infof("parameter checkpoint %s", ckbytes) | ||
|
|
||
| if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) { | ||
| log.Info("checkpoint not exists.") | ||
|
||
| } else { | ||
| err = os.Remove(ck.Uuid) | ||
| log.Infof("remove %s", ck.Uuid) | ||
|
||
| } | ||
| f, err := os.Create(ck.Uuid) | ||
|
||
| defer f.Close() | ||
|
||
| if err != nil { | ||
| log.Errorln(err) | ||
| } | ||
| writer := bufio.NewWriter(f) | ||
| _, err = writer.Write(content) | ||
| writer.Flush() | ||
| if err != nil { | ||
| log.Errorln(err) | ||
| } | ||
| } | ||
| return nil | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -79,6 +79,8 @@ func TestServiceFull(t *testing.T) { | |
| if !reflect.DeepEqual(param1, p) { | ||
| t.FailNow() | ||
| } | ||
| var dummy int | ||
| s.Save("", &dummy) | ||
|
||
| } | ||
|
|
||
| func TestMultipleInit(t *testing.T) { | ||
|
|
@@ -166,3 +168,7 @@ func TestBlockUntilInitialized(t *testing.T) { | |
|
|
||
| wg.Wait() | ||
| } | ||
|
|
||
| func TestCheckpointSpeed(t *testing.T) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Speed can be tested with a benchmark. Here is an example: https://dave.cheney.net/2013/06/30/how-to-write-benchmarks-in-go There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Leaving a TODO here; this will be tested after reaching an agreement on @Yancey1989's recovery logic. |
||
| //TODO: test speed | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Go's naming convention is camelCase, not snake_case.
`checkpointPath` needs to be an argument (`flag.String`) passed to the go/cmd/pserver program, since k8s will configure the path.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.