Skip to content
16 changes: 14 additions & 2 deletions go/pserver/optimizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,23 @@ func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer {
o.elementType = paramWithConfigs.Param.ElementType
p := paramWithConfigs.Param
c := paramWithConfigs.Config
s := paramWithConfigs.State
log.WithFields(log.Fields{
"ElementType": p.ElementType,
"ParamSize": len(p.Content),
"ConfigSize": len(c),
"StateSize": len(s),
}).Info("New Optimizer Created with config:")
var cbuffer unsafe.Pointer
cbuffer = C.malloc(C.size_t(len(p.Content)))
C.memcpy(cbuffer, unsafe.Pointer(&p.Content[0]), C.size_t(len(p.Content)))
var cstate unsafe.Pointer
if len(s) != 0 {
cstate = unsafe.Pointer(&s[0])
}

o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float),
(*C.char)(nullPtr), 0)
C.paddle_element_type(p.ElementType), cbuffer, C.int(len(p.Content)/C.sizeof_float), (*C.char)(cstate), C.int(len(s)))
return o
}

Expand All @@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte {
return cArrayToSlice(buffer, int(buffer_len)*C.sizeof_float)
}

func (o *optimizer) GetStates() []byte {
var cbuffer *C.char
cbuffer_len := C.paddle_optimizer_get_state(o.opt, &cbuffer)
return cArrayToSlice(unsafe.Pointer(cbuffer), int(cbuffer_len))
}

func (o *optimizer) UpdateParameter(g Gradient) error {
if o.elementType != g.ElementType {
return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", g.Name, o.elementType, g.ElementType)
Expand Down
84 changes: 81 additions & 3 deletions go/pserver/service.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
package pserver

import (
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"errors"
"fmt"
"os"
"strconv"
"sync"
"time"

log "github.com/sirupsen/logrus"
)

// ElementType is the type of elements of a Parameter.
Expand All @@ -14,6 +24,10 @@ const (
Uninitialized = "pserver not fully initialized"
)

const (
checkpoint_path = "./checkpoints/"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Go's naming convention is camelCase, not snake_case.

checkpointPath need to be an argument (flag.String) passed to go/cmd/pserver program. Since the k8s will configure the path.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

)

// Supported element types
const (
Int32 ElementType = iota
Expand All @@ -38,6 +52,7 @@ type Parameter struct {
type ParameterWithConfig struct {
Param Parameter
Config []byte // parameter configuration in Proto Buffer format
State []byte // parameter training state
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ParameterWithConfig is the data sent from the trainer to the pserver. But State is saved by pserver, loaded by pserver, which is not related to trainer.
So State should not be part of this struct.

Maybe:

type checkpoint struct {
  Uuid      string
  Md5sum    string
  Timestamp string
  ParameterWithConfig // this is called embedded field
  State  []byte
}

embedded field: https://golang.org/ref/spec#Struct_types

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

split into info and data part. fix done.

}

// Gradient is the gradient of the parameter.
Expand All @@ -52,14 +67,34 @@ type Service struct {
optMap map[string]*optimizer
}

type checkpoint struct {
Uuid string
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding json: "uuid" at the end of the line, so we can use Json.marshal to a format JSON.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree with another PR. fix Done.

Md5sum string
Timestamp string
}

//serialize ParameterWithConfig to byte stream
func GetBytes(content ...interface{}) ([]byte, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making content ...interface{} adds more complexity to the code: since it's interface type that we need to encode, we have to call gob.Register. It's harder to understand the code (people need to search for what does gob.Register do. And it's harder to maintain the code (whenever adds a new type for GetBytes to use, maintainer need to add gob.Register as well, it's hard to track.

Since here we only need to call GetBytes twice, and this function does not have much code. Maybe just put it inline? (and remove gob.Register)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix Done.


var buf bytes.Buffer
encoder := gob.NewEncoder(&buf)
err := encoder.Encode(content)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}

// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(idx int) (*Service, error) {
s := &Service{
idx: idx,
}
s.optMap = make(map[string]*optimizer)
s.optMap = make(map[string]*optimizer)
s.initialized = make(chan struct{})
gob.Register(ParameterWithConfig{})
gob.Register(checkpoint{})
return s, nil
}

Expand Down Expand Up @@ -142,8 +177,51 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {

// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Save was intended for saving model. But now we no longer use pservers to save model. Can you rename save to checkpoint? Also, I think at least for the first implementation, checkpoint should not be exposed as a RPC method to the trainer, instead, pservers periodically checkpoints, so can you make this a private function: func (s *Service) checkpoint(path string) error? (note that we don't need parameter dummy *int anymore if it's not used for RPC).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix Done.

//FIXME: checkpoint is only used by pserver
// and has a constant path of */checkpoints/{pserver_idx}*
<-s.initialized

// TODO
s.mu.Lock()
defer s.mu.Unlock()
var paramWithConfig ParameterWithConfig
for name, opt := range s.optMap {
paramWithConfig.Param.Name = name
paramWithConfig.Param.ElementType = opt.elementType
paramWithConfig.Param.Content = opt.GetWeights()
paramWithConfig.State = opt.GetStates()
content, err := GetBytes(paramWithConfig)
if err != nil {
log.Errorln(err)
}
ck := checkpoint{}
h := md5.New()
ck.Md5sum = hex.EncodeToString(h.Sum(content))
ck.Timestamp = time.Now().String()
ck.Uuid = checkpoint_path + strconv.Itoa(s.idx)
ckbytes, err := GetBytes(ck)
if err != nil {
log.Errorln(err)
}
// TODO: according design doc, need to save Uuid to etcd in json format
// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the design doc mentioned using etcd to save checkpoint information as well. Maybe add a TODO?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add etcd saving logic. fix done.

log.Infof("parameter checkpoint %s", ckbytes)

if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) {
log.Info("checkpoint not exists.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checkpoint not exists. -> checkpoint does not exist.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix done.

} else {
err = os.Remove(ck.Uuid)
log.Infof("remove %s", ck.Uuid)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove %s -> checkpoint %s already exists, removing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix done.

}
f, err := os.Create(ck.Uuid)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will create so many files for each paramter. Following the design doc, we will only have one checkpoint file named UUID?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix done.

defer f.Close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

defer f.Close() will close when this function returns, not when the for loop goes to the next loop. And the for loop may be very long. So perhaps call f.Close() at the end of for loop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix done.

if err != nil {
log.Errorln(err)
}
writer := bufio.NewWriter(f)
_, err = writer.Write(content)
writer.Flush()
if err != nil {
log.Errorln(err)
}
}
return nil
}
6 changes: 6 additions & 0 deletions go/pserver/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ func TestServiceFull(t *testing.T) {
if !reflect.DeepEqual(param1, p) {
t.FailNow()
}
var dummy int
s.Save("", &dummy)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pass in nil is fine: s.Save("", nil). I used s.Save("", &dummy) before but later realized that it's fine to pass in nil :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}

func TestMultipleInit(t *testing.T) {
Expand Down Expand Up @@ -166,3 +168,7 @@ func TestBlockUntilInitialized(t *testing.T) {

wg.Wait()
}

func TestCheckpointSpeed(t *testing.T) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Speed can be tested with benchmark. Here is an example: https://dave.cheney.net/2013/06/30/how-to-write-benchmarks-in-go

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leave a TODO here, will be tested after reaching an agreement with @Yancey1989 's recover logic.

//TODO: test speed
}