Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
hooks:
- id: clang-formater
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 16398aeccf263adaf53b2495eed0406347d76281
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks:
- id: go-fmt
types: [go]
- id: gometalinter
types: [go]
- id: go-fmt
types:
- go
- id: gometalinter
types:
- go
17 changes: 12 additions & 5 deletions go/master/c/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,8 @@ package main
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1

typedef int paddle_master_client;
*/
import "C"
Expand Down Expand Up @@ -102,13 +100,19 @@ func paddle_release_master_client(client C.paddle_master_client) {
//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
// call PassStart to init a new dataset iteration
err := c.PassStart()
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
}
var paths []string
for i := 0; i < int(size); i++ {
ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
str := C.GoString(*ptr)
paths = append(paths, str)
}
err := c.SetDataset(paths)
err = c.SetDataset(paths)
if err != nil {
log.Errorln(err)
return C.PADDLE_MASTER_ERROR
Expand All @@ -120,13 +124,16 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
// return value:
// 0:ok
// -1:error
// -2:pass end
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r, err := c.NextRecord()
if err != nil {
// Error
// TODO: return the type of error?
// NOTE: use errors to indicate pass ends
if err.Error() == master.ErrAllTaskFinishError.Error() || err.Error() == master.ErrNoMoreAvailableError.Error() {
return -2
}
*record = (*C.uchar)(nil)
return -1
}
Expand Down
57 changes: 45 additions & 12 deletions go/master/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,31 @@ func NewClient(addrCh <-chan string, bufSize int) *Client {
c.conn = connection.New()
c.ch = make(chan record, bufSize)
go c.monitorMaster(addrCh)
go c.getRecords()
// FIXME: async connection creation
time.Sleep(time.Second)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: why do we need to sleep for one second?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment and Sleep still necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. The connection is created in the monitorMaster goroutine, so using the conn immediately may result in an error.

err := c.addClient()
if err != nil {
log.Errorln("init client(addClient) error:", err)
}
return c
}

func (c *Client) getRecords() {
for {
t, err := c.getTask()
if err != nil {
// getTask call.
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
time.Sleep(3 * time.Second)
continue
if err.Error() == ErrAllTaskFinishError.Error() || err.Error() == ErrNoMoreAvailableError.Error() {
log.Infof("Got %v, stopping getRecords routine.", err)
c.ch <- record{nil, err}
return
}
log.Errorf("getTask error: %s", err)
}

for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Errorln(err)
f, e := os.Open(chunk.Path)
if e != nil {
log.Errorln(e)
continue
}

Expand Down Expand Up @@ -116,12 +123,17 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
}
}

// SetDataset set dataset for the master server to dispatch.
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be called multiple times in one pass, but only the first
// call will be honored.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
// After all tasks are done, another call of SetDataset will start another pass.
func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
err := c.conn.Call("Service.SetDataset", globPaths, nil)
// start to getRecords go-routine before each pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment still necessary?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, Done.

go c.getRecords()
Copy link
Contributor

@helinwang helinwang Jul 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议用

c.gotRecordsOnce.Do(func() {
  go c.getRecords()
})

reference: https://golang.org/pkg/sync/#Once.Do

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果c.gotRecordsOnce作为client的成员,需要在更新pass的时候重新初始化一个once对象,否则c.getRecords()不能再次启动了。在下一个PR中Refine下额。我把这个放到一个issue里。

return err
}

// getTask gets a new task from the master server.
Expand All @@ -147,5 +159,26 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// thread-safe.
func (c *Client) NextRecord() ([]byte, error) {
	// Block until a record (or an error) arrives on c.ch.
	rec := <-c.ch
	if rec.err == nil {
		return rec.r, rec.err
	}

	// Pass-end conditions are recognised by comparing error message text,
	// not error identity.
	switch rec.err.Error() {
	case ErrAllTaskFinishError.Error(), ErrNoMoreAvailableError.Error():
		// The pass is over: report it via PassFinish. If that call fails,
		// its error is returned in place of the pass-end error.
		if err := c.PassFinish(); err != nil {
			return nil, err
		}
	}
	return rec.r, rec.err
}

// addClient invokes the "Service.AddClient" RPC through c.conn, passing 0 as
// the argument and nil as the reply, and returns the error from that call.
// NOTE(review): presumably this registers the client with the master server;
// confirm against the Service implementation.
func (c *Client) addClient() error {
	const method = "Service.AddClient"
	return c.conn.Call(method, 0, nil)
}

// PassFinish marks the current pass as finished by invoking the
// "Service.PassFinish" RPC through c.conn (argument 0, nil reply). It returns
// the error reported by that call, or nil on success.
func (c *Client) PassFinish() error {
	return c.conn.Call("Service.PassFinish", 0, nil)
}

// PassStart resets the pass counter by invoking the "Service.PassStart" RPC
// through c.conn (argument 0, nil reply). It returns the error reported by
// that call, or nil on success.
func (c *Client) PassStart() error {
	const method = "Service.PassStart"
	return c.conn.Call(method, 0, nil)
}
66 changes: 34 additions & 32 deletions go/master/client_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic(err)
}
go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
panic(err)
s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if sErr != nil {
panic(sErr)
}

server := rpc.NewServer()
err = server.Register(s)
if err != nil {
panic(err)
sErr = server.Register(s)
if sErr != nil {
panic(sErr)
}

mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
sErr = http.Serve(l, mux)
if sErr != nil {
panic(sErr)
}
}(l)

Expand Down Expand Up @@ -103,52 +103,54 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
}

checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask()
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
task, cErr := c.getTask()
if cErr != nil && cErr.Error() != ErrNoMoreAvailableError.Error() {
t.Fatalf("error: %v, pass: %d\n", cErr, i)
}
tasks = append(tasks, task)
}

_, err = c.getTask()
if err == nil {
// getting task before task finishes should return error
_, cErr := c.getTask()
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i)
}

err = c.taskFinished(tasks[0].Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(tasks[0].Meta.ID)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}

err = c.taskFailed(tasks[0].Meta)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
// Calling taskFailed once won't put the task into the failed queue; this
// only checks that the call itself succeeds.
cErr = c.taskFailed(tasks[0].Meta)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}

tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
t.Fatal(err)
_, cErr = c.getTask()
if cErr != nil && cErr.Error() != ErrNoMoreAvailableError.Error() && cErr.Error() != ErrAllTaskFinishError.Error() {
t.Fatalf("Should be ErrNoMoreAvailableError or ErrAllTaskFinishError: %s", cErr)
}
tasks = append(tasks, task)

for _, task := range tasks {
err = c.taskFinished(task.Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(task.Meta.ID)
if cErr != nil && cErr.Error() != ErrAllTaskFinishError.Error() {
t.Fatalf("Non-ErrAllTaskFinishError: %v, pass: %d\n", cErr, i)
}
}
}

for i := 0; i < 10; i++ {
// init pass data
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
}
checkOnePass(i)
}
}
Loading