Skip to content

Commit f6d3543

Browse files
committed
feat: implement label_initial_prompt; better error handling
1 parent 4d4a24a commit f6d3543

File tree

4 files changed

+114
-13
lines changed

4 files changed

+114
-13
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@ The configuration file is saved in `config.json`.
2929

3030
## Plans
3131

32-
- [ ] Support task `label_initial_prompt`
33-
34-
This task is a bit hard to find for my language configuration, I'll keep finding it.
32+
- [x] ~~Support task `label_initial_prompt`~~
3533
- [ ] Support other language models (maybe)
3634

3735
There have been many excellent LLMs coming out recently, like GPT4All. I may try to support them in my free time(but no guarantees). if you have ideas, PRs are welcome!

README_zh.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929

3030
## 计划
3131

32-
- [ ] 支持任务 `label_initial_prompt`
33-
34-
目前中文的任务还是比较少的,尤其这个类别很少出现。等下次我遇到了再做支持,也欢迎大家 PR。
32+
- [x] ~~支持任务 `label_initial_prompt`~~
3533
- [ ] 支持其他语言模型(或许)
3634

3735
最近出现了许多优秀的 LLM,比如 GPT4All,我有空的话可能会尝试支持它们(但不能保证)。如果您有想法,欢迎提 PR!

model/model.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,20 @@ type OALabelPrompterReplyTask struct {
7878
UserId string `json:"userId"`
7979
}
8080

81+
type OALabelInitialPromptTask struct {
82+
Id string `json:"id"`
83+
Mode string `json:"mode"`
84+
Type string `json:"type"`
85+
Labels []OARandomTaskLabel `json:"labels"`
86+
Prompt string `json:"prompt"`
87+
MessageId string `json:"message_id"`
88+
Disposition string `json:"disposition"`
89+
Conversation OAConversation `json:"conversation"`
90+
ValidLabels []OALabel `json:"valid_labels"`
91+
MandatoryLabels []OALabel `json:"mandatory_labels"`
92+
UserId string `json:"userId"`
93+
}
94+
8195
type OAAssistantReplyTask struct {
8296
Id string `json:"id"`
8397
Type string `json:"type"`
@@ -162,3 +176,15 @@ type PostBodyRankAssistant struct {
162176
UpdateType string `json:"update_type"`
163177
Content PostBodyRankAssistantContent `json:"content"`
164178
}
179+
180+
type PostBodyLabelInitialPromptContent struct {
181+
Labels map[OALabel]float32 `json:"labels"`
182+
MessageId string `json:"message_id"`
183+
Text string `json:"text"`
184+
}
185+
type PostBodyLabelInitialPrompt struct {
186+
Id string `json:"id"`
187+
Lang string `json:"lang"`
188+
UpdateType string `json:"update_type"`
189+
Content PostBodyLabelInitialPromptContent `json:"content"`
190+
}

open-assistant.go

Lines changed: 86 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ func RefreshCookie() error {
4040
return err
4141
}
4242
cs := resp.Header()["Set-Cookie"]
43-
if len(cs) == 0 {
44-
return fmt.Errorf("no cookie")
43+
if len(cs) <= 1 {
44+
return fmt.Errorf("no cookie. Please consider going to https://open-assistant.io , login, and update your cookie in config.json")
4545
}
4646
c := strings.Join(cs, "")
4747
return model.UpdateCookie(c)
@@ -150,7 +150,7 @@ func LabelPrompterReply(id string, task model.OALabelPrompterReplyTask) error {
150150
text += fmt.Sprintf("User: %s", task.Reply)
151151
}
152152

153-
t, err := Complete(text, `You are a powerful fine-tuner of Open Assistant, an open-source LLM. You will be given conversations between a user and the model, and you need to label the user's last reply and return a JSON string.You should evaluate the conversations based on the following criteria:
153+
t, err := Complete(text, `You are a powerful fine-tuner of Open Assistant, an open-source LLM. You will be given conversations between a user and the model, and you need to label the user's last reply and return a JSON string. You should evaluate the conversations based on the following criteria:
154154
- Spam: 0/1, whether the conversation contains spam / ads / porn / politics / etc.
155155
- Not Appropriate: 0/1, whether the response is reasonable for the user's question
156156
- pii: 0/1
@@ -206,6 +206,67 @@ You must return a JSON string, DO NOT include any other characters, DO NOT expla
206206
return nil
207207
}
208208

209+
func LabelInitialPrompt(id string, task model.OALabelInitialPromptTask) error {
210+
logx.Infof("LabelInitialPrompt")
211+
text := fmt.Sprintf("Prompt: %s\n\nUser's language code: %s", task.Prompt, model.Conf.Language)
212+
213+
t, err := Complete(text, `You are a powerful fine-tuner of Open Assistant, an open-source LLM. You will be given a prompt from the user, and you need to label it and return a JSON string. You should evaluate the prompt based on the following criteria:
214+
- Spam: 0/1, whether the message contains spam / ads / porn / politics / etc.
215+
- Not Appropriate: 0/1, whether the message is offensive / not respectful
216+
- pii: 0/1
217+
- Hate Speech: 0/1, whether the prompt is aggressive / not respectful
218+
- Sexual Content: 0/1
219+
- Quality: 0-1, step 0.25, how well the response is written respecting grammar, spelling, use of words, etc.
220+
- Lang Mismatch: 0-1, step 0.25, whether the prompt is in the same language as the user's language
221+
- Creativity: 0-1, step 0.25, how less is the prompt
222+
- Humor: 0-1, step 0.25
223+
- Toxicity: 0-1, step 0.25, how aggressive is the prompt
224+
- Violence: 0-1, step 0.25
225+
226+
You must return a JSON string, DO NOT include any other characters, DO NOT explain. Use snake_case for the keys.`)
227+
228+
if err != nil {
229+
return err
230+
}
231+
232+
logx.Infof("LabelInitialPrompt: %s", t)
233+
labels, err := GetLabelsFromChatGPT(t)
234+
if err != nil {
235+
return err
236+
}
237+
238+
body := model.PostBodyLabelInitialPrompt{
239+
Id: id,
240+
Lang: model.Conf.Language,
241+
UpdateType: "text_labels",
242+
Content: model.PostBodyLabelInitialPromptContent{
243+
Labels: labels,
244+
Text: "unused?",
245+
MessageId: task.MessageId,
246+
},
247+
}
248+
249+
resp, err := rty.R().
250+
SetHeaders(map[string]string{
251+
"Cookie": model.Conf.OaCookie,
252+
"Content-Type": "application/json",
253+
}).
254+
SetBody(body).
255+
Post("https://open-assistant.io/api/update_task")
256+
257+
if err != nil {
258+
return err
259+
}
260+
261+
respStr := string(resp.Body())
262+
if respStr == "" {
263+
logx.Infof("LabelInitialPrompt: OK!")
264+
} else {
265+
logx.Errorf("LabelInitialPrompt: %s", respStr)
266+
}
267+
return nil
268+
}
269+
209270
func AssistantReply(id string, task model.OAAssistantReplyTask) error {
210271
logx.Infof("AssistantReply")
211272
text := ""
@@ -374,9 +435,20 @@ func StartTask() error {
374435
}
375436

376437
var task model.OARandomTaskResponse
377-
err = jsonx.Unmarshal(resp.Body(), &task)
378-
if task.Task == nil {
379-
logx.Infof("GetTasks: get task failed: %s", string(resp.Body()))
438+
body := resp.Body()
439+
err = jsonx.Unmarshal(body, &task)
440+
if resp.StatusCode() == 403 {
441+
logx.Errorf("GetTasks: cookie may have expired (403 Forbidden)")
442+
logx.Errorf("Please login to https://open-assistant.io/dashboard and update your cookie in config.json")
443+
return nil
444+
} else if task.Task == nil {
445+
var _json map[string]interface{}
446+
_ = jsonx.Unmarshal(body, &_json)
447+
if strings.Contains(_json["message"].(string), "No tasks") {
448+
logx.Infof("GetTasks: no tasks at this time")
449+
}
450+
logx.Errorf("GetTasks: get task failed: %s", _json["message"])
451+
logx.Errorf("Please check your network, or login to https://open-assistant.io/dashboard and update your cookie in config.json")
380452
return nil
381453
}
382454

@@ -403,9 +475,16 @@ func StartTask() error {
403475
var ch model.OARankAssistantRepliesTask
404476
_ = jsonx.Unmarshal(j, &ch)
405477
return RankAssistantReplies(task.Id, ch)
478+
} else if t["type"] == "label_initial_prompt" {
479+
var ch model.OALabelInitialPromptTask
480+
_ = jsonx.Unmarshal(j, &ch)
481+
return LabelInitialPrompt(task.Id, ch)
406482
} else {
407483
logx.Infof("GetTasks: unknown task type: %s", t["type"])
408-
CancelTask(task.Id)
484+
err = CancelTask(task.Id)
485+
if err != nil {
486+
return fmt.Errorf("GetTasks: cancel task failed: %s", err)
487+
}
409488
}
410489

411490
return nil

0 commit comments

Comments
 (0)