12 changes: 7 additions & 5 deletions .pre-commit-config.yaml
@@ -22,9 +22,11 @@
hooks:
- id: clang-formater
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 16398aeccf263adaf53b2495eed0406347d76281
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks:
- id: go-fmt
types: [go]
- id: gometalinter
types: [go]
- id: go-fmt
types:
- go
- id: gometalinter
types:
- go
17 changes: 13 additions & 4 deletions go/master/c/client.go
@@ -18,7 +18,6 @@ package main
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1

@@ -101,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}

//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
helinwang (Contributor), Jul 26, 2017

How does the trainer know which pass the training is currently in? We need some way to get this information; maybe paddle_set_dataset could return it, or creating a new function seems fine as well.

Btw, this function is not called from anywhere in this PR.

Contributor Author

The trainer doesn't need to know the exact master-side pass count; a local pass count will do. The trainer just needs to know when to break and when to wait.

helinwang (Contributor), Jul 27, 2017

Say the job is training on pass 100 and one trainer gets killed and restarted: should it set its local pass to 0 and start from there (call get record, get an error, increment the pass number by 1, and call again until reaching pass 100)?

Contributor Author

Yep, that's right; this just works. I'll refine it in another PR later.

Contributor

I see, sure.

c := get(client)
c.StartGetRecords(int(pass))
}

//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
@@ -121,15 +126,19 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int

// paddle_next_record gets the next training record.
//
// returns number of bytes of the records if success, -1 if failed.
// returns the number of bytes of the record if success, -1 if failed, -2 if the pass has ended.
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r, err := c.NextRecord()
if err != nil {
// Error
// TODO: return the type of error?
// NOTE: use errors to indicate pass ends
if err.Error() == master.ErrAllTaskFailed.Error() ||
err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil)
return -1
}
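
For reference, here is a minimal trainer-side sketch of the pass handling discussed in the review thread above. It is not part of this PR: it assumes the Go master.Client API added in this change (StartGetRecords, NextRecord, and the exported pass-end errors), plus a hypothetical per-record train step, and it keeps only a local pass counter, as the author describes.

```go
// Package trainer is a hypothetical consumer of the master client,
// sketching the local-pass-counter loop discussed in the review thread.
package trainer

import (
	"log"

	"github.com/PaddlePaddle/Paddle/go/master"
)

// train is a placeholder for the per-record training step.
func train(record []byte) {}

// trainLoop fetches records pass by pass. The trainer never asks the master
// which pass it is in; it keeps a local counter and moves on whenever the
// master reports that the current pass has ended (the same errors the C
// wrapper above maps to -2).
func trainLoop(c *master.Client, totalPasses int) {
	for pass := 0; pass < totalPasses; pass++ {
		c.StartGetRecords(pass) // must be called at the beginning of each pass
		for {
			r, err := c.NextRecord()
			if err != nil {
				if err.Error() == master.ErrAllTaskFailed.Error() ||
					err.Error() == master.ErrNoMoreAvailable.Error() ||
					err.Error() == master.ErrPassBefore.Error() {
					break // pass ended, start the next one
				}
				log.Printf("NextRecord failed in pass %d: %v", pass, err)
				continue
			}
			train(r)
		}
	}
}
```

Per the discussion, a trainer that is killed and restarted would simply begin again at local pass 0; each stale pass ends immediately with a pass-end error, so this loop catches up to the master's current pass.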
70 changes: 39 additions & 31 deletions go/master/client.go
@@ -16,7 +16,6 @@ package master

import (
"os"
"sync"
"time"

"github.com/PaddlePaddle/Paddle/go/connection"
@@ -27,9 +26,9 @@ import (

// Client is the client of the master server.
type Client struct {
conn *connection.Conn
ch chan record
initChOnce sync.Once
conn *connection.Conn
ch chan record
bufSize int
}

type record struct {
@@ -46,11 +45,7 @@ func WithBuffer(bufSize int) func(*Client) error {
if bufSize <= 0 {
return nil
}

c.initChOnce.Do(func() {
c.ch = make(chan record, bufSize)
go c.getRecords()
})
c.bufSize = bufSize
return nil
}
}
@@ -104,25 +99,41 @@ func NewClient(opts ...func(*Client) error) (*Client, error) {
if err != nil {
return nil, err
}

}

c.ch = make(chan record, c.bufSize)
// FIXME: the connection is created asynchronously in the monitorMaster goroutine;
// ensure the connection is ready for use before calling c.addClient.
time.Sleep(time.Second)
return c, nil
}

func (c *Client) getRecords() {
// StartGetRecords must be called at the beginning of each pass.
func (c *Client) StartGetRecords(passID int) {
go c.getRecords(passID)
}

func (c *Client) getRecords(passID int) {
for {
t, err := c.getTask()
t, err := c.getTask(passID)
if err != nil {
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
time.Sleep(3 * time.Second)
continue
if err.Error() == ErrPassBefore.Error() ||
err.Error() == ErrNoMoreAvailable.Error() ||
err.Error() == ErrAllTaskFailed.Error() {
c.ch <- record{nil, err}
break
}
if err.Error() == ErrPassAfter.Error() {
// wait until the last pass finishes
time.Sleep(time.Second * 3)
continue
}
log.Errorf("getTask error: %s", err)
}

for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Errorln(err)
f, e := os.Open(chunk.Path)
if e != nil {
log.Errorln(e)
continue
}

@@ -178,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
}
}

// SetDataset set dataset for the master server to dispatch.
// SetDataset sets the dataset for the master server to dispatch.
//
// SetDataset can be called multiple times in one pass, but only the first call
// will be honored.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
// After all tasks are done, another call of SetDataset will start another pass.
func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
err := c.conn.Call("Service.SetDataset", globPaths, nil)
return err
}

// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
func (c *Client) getTask(passID int) (Task, error) {
var t Task
err := c.conn.Call("Service.GetTask", 0, &t)
err := c.conn.Call("Service.GetTask", passID, &t)
return t, err
}

@@ -208,12 +222,6 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() ([]byte, error) {
c.initChOnce.Do(func() {
// initialize with in case WithBuffer is not used.
c.ch = make(chan record, 0)
go c.getRecords()
})

r := <-c.ch
return r.r, r.err
}
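
A small usage sketch for the SetDataset behavior documented above (again not part of the PR; the glob path and surrounding setup are hypothetical): every node may call SetDataset with the same globs, and only the first call in a pass takes effect.

```go
package trainer

import (
	"log"

	"github.com/PaddlePaddle/Paddle/go/master"
)

// setDataset is safe to call from every trainer node: the master honors only
// the first SetDataset call in a pass, so duplicate calls are no-ops. Once all
// tasks of a pass are done, a later SetDataset call starts another pass.
func setDataset(c *master.Client) error {
	globs := []string{"/data/train-*.recordio"} // hypothetical dataset location
	if err := c.SetDataset(globs); err != nil {
		log.Printf("SetDataset failed: %v", err)
		return err
	}
	return nil
}
```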
60 changes: 32 additions & 28 deletions go/master/client_internal_test.go
@@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic(err)
}
go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
panic(err)
s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if sErr != nil {
panic(sErr)
}

server := rpc.NewServer()
err = server.Register(s)
if err != nil {
panic(err)
sErr = server.Register(s)
if sErr != nil {
panic(sErr)
}

mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
sErr = http.Serve(l, mux)
if sErr != nil {
panic(sErr)
}
}(l)

@@ -103,6 +103,7 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)

err = c.SetDataset([]string{path})
if err != nil {
panic(err)
@@ -111,44 +112,47 @@
checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask()
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
task, cErr := c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("error: %v, pass: %d\n", cErr, i)
}
tasks = append(tasks, task)
}

_, err = c.getTask()
if err == nil {
// getting a task before any task finishes should return an error
_, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i)
}

err = c.taskFinished(tasks[0].Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(tasks[0].Meta.ID)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}

err = c.taskFailed(tasks[0].Meta)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
// calling taskFailed once won't put the task into the failed queue; this just
// exercises the call
cErr = c.taskFailed(tasks[0].Meta)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}

tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
t.Fatal(err)
_, cErr = c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
}
tasks = append(tasks, task)

for _, task := range tasks {
err = c.taskFinished(task.Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(task.Meta.ID)
if cErr != nil {
t.Fatal(cErr)
}
}
}

for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i)
}
}