Skip to content

Commit 794fae0

Browse files
committed
fix conflict
2 parents 4a7e8ab + cf907fc commit 794fae0

File tree

73 files changed

+5153
-195
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+5153
-195
lines changed

.github/codecov.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ coverage:
1010
threshold: 1% # Allow the coverage to drop by 1%, and posting a success status.
1111
patch:
1212
default:
13-
target: 80% # lines adjusted Coverage < 80% CI will fail
13+
target: 80% # lines adjusted Coverage < 80% CI will fail

applications/neural_search/recall/in_batch_negative/evaluate.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,29 @@
1313
# limitations under the License.
1414

1515
import argparse
16+
import time
1617

1718
import numpy as np
1819

19-
import time
20-
2120
parser = argparse.ArgumentParser()
22-
parser.add_argument("--similar_text_pair", type=str,
23-
default='', help="The full path of similar pair file")
24-
parser.add_argument("--recall_result_file", type=str,
25-
default='', help="The full path of recall result file")
26-
parser.add_argument("--recall_num", type=int, default=10,
27-
help="Most similar number of doc recalled from corpus per query")
21+
parser.add_argument(
22+
"--similar_text_pair",
23+
type=str,
24+
default="",
25+
help="The full path of similar pair file",
26+
)
27+
parser.add_argument(
28+
"--recall_result_file",
29+
type=str,
30+
default="",
31+
help="The full path of recall result file",
32+
)
33+
parser.add_argument(
34+
"--recall_num",
35+
type=int,
36+
default=10,
37+
help="Most similar number of doc recalled from corpus per query",
38+
)
2839

2940

3041
args = parser.parse_args()
@@ -62,17 +73,16 @@ def recall(rs, N=10):
6273
with open(args.recall_result_file, "r", encoding="utf-8") as f:
6374
relevance_labels = []
6475
for index, line in enumerate(f):
65-
66-
if index % args.recall_num == 0 and index != 0:
67-
rs.append(relevance_labels)
68-
relevance_labels = []
69-
7076
text, recalled_text, cosine_sim = line.rstrip().split("\t")
7177
if text2similar[text] == recalled_text:
7278
relevance_labels.append(1)
7379
else:
7480
relevance_labels.append(0)
7581

82+
if (index + 1) % args.recall_num == 0:
83+
rs.append(relevance_labels)
84+
relevance_labels = []
85+
7686
recall_N = []
7787
recall_num = [1, 5, 10, 20, 50]
7888
for topN in recall_num:

applications/neural_search/recall/simcse/evaluate.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,16 @@ def recall(rs, N=10):
5757
with open(args.recall_result_file, "r", encoding="utf-8") as f:
5858
relevance_labels = []
5959
for index, line in enumerate(f):
60-
61-
if index % args.recall_num == 0 and index != 0:
62-
rs.append(relevance_labels)
63-
relevance_labels = []
64-
6560
text, recalled_text, cosine_sim = line.rstrip().split("\t")
6661
if text2similar[text] == recalled_text:
6762
relevance_labels.append(1)
6863
else:
6964
relevance_labels.append(0)
7065

66+
if (index + 1) % args.recall_num == 0:
67+
rs.append(relevance_labels)
68+
relevance_labels = []
69+
7170
recall_N = []
7271
recall_num = [1, 5, 10, 20, 50]
7372
result = open("result.tsv", "a")

applications/question_answering/supervised_qa/faq_finance/evaluate.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,16 @@ def recall(rs, N=10):
5959
with open(args.recall_result_file, "r", encoding="utf-8") as f:
6060
relevance_labels = []
6161
for index, line in enumerate(f):
62-
63-
if index % args.recall_num == 0 and index != 0:
64-
rs.append(relevance_labels)
65-
relevance_labels = []
66-
6762
text, recalled_text, cosine_sim = line.rstrip().split("\t")
6863
if text2similar[text] == recalled_text:
6964
relevance_labels.append(1)
7065
else:
7166
relevance_labels.append(0)
67+
68+
if (index + 1) % args.recall_num == 0:
69+
rs.append(relevance_labels)
70+
relevance_labels = []
71+
7272
recall_N = []
7373
recall_num = [1, 5, 10]
7474
result = open("result.tsv", "a")

applications/question_answering/supervised_qa/faq_system/evaluate.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,16 @@ def recall(rs, N=10):
5959
with open(args.recall_result_file, "r", encoding="utf-8") as f:
6060
relevance_labels = []
6161
for index, line in enumerate(f):
62-
63-
if index % args.recall_num == 0 and index != 0:
64-
rs.append(relevance_labels)
65-
relevance_labels = []
66-
6762
text, recalled_text, cosine_sim = line.rstrip().split("\t")
6863
if text2similar[text] == recalled_text:
6964
relevance_labels.append(1)
7065
else:
7166
relevance_labels.append(0)
67+
68+
if (index + 1) % args.recall_num == 0:
69+
rs.append(relevance_labels)
70+
relevance_labels = []
71+
7272
recall_N = []
7373
recall_num = [1, 5, 10]
7474
result = open("result.tsv", "a")

applications/text_classification/hierarchical/retrieval_based/evaluate.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,23 @@
1818
import numpy as np
1919

2020
parser = argparse.ArgumentParser()
21-
parser.add_argument("--similar_text_pair", type=str, default="", help="The full path of similar pair file")
22-
parser.add_argument("--recall_result_file", type=str, default="", help="The full path of recall result file")
2321
parser.add_argument(
24-
"--recall_num", type=int, default=10, help="Most similair number of doc recalled from corpus per query"
22+
"--similar_text_pair",
23+
type=str,
24+
default="",
25+
help="The full path of similar pair file",
26+
)
27+
parser.add_argument(
28+
"--recall_result_file",
29+
type=str,
30+
default="",
31+
help="The full path of recall result file",
32+
)
33+
parser.add_argument(
34+
"--recall_num",
35+
type=int,
36+
default=10,
37+
help="Most similair number of doc recalled from corpus per query",
2538
)
2639
args = parser.parse_args()
2740

@@ -57,17 +70,24 @@ def recall(rs, N=10):
5770
with open(args.recall_result_file, "r", encoding="utf-8") as f:
5871
relevance_labels = []
5972
for index, line in enumerate(f):
60-
61-
if index % args.recall_num == 0 and index != 0:
62-
rs.append(relevance_labels)
63-
relevance_labels = []
6473
text_arr = line.rstrip().split("\t")
65-
text_title, text_para, recalled_title, recalled_para, label, cosine_sim = text_arr
74+
(
75+
text_title,
76+
text_para,
77+
recalled_title,
78+
recalled_para,
79+
label,
80+
cosine_sim,
81+
) = text_arr
6682
if text2similar["\t".join([text_title, text_para])] == label:
6783
relevance_labels.append(1)
6884
else:
6985
relevance_labels.append(0)
7086

87+
if (index + 1) % args.recall_num == 0:
88+
rs.append(relevance_labels)
89+
relevance_labels = []
90+
7191
recall_N = []
7292
recall_num = [1, 5, 10, 20, 50]
7393
for topN in recall_num:

applications/text_classification/multi_class/retrieval_based/evaluate.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,23 @@
1818
import numpy as np
1919

2020
parser = argparse.ArgumentParser()
21-
parser.add_argument("--similar_text_pair", type=str, default="", help="The full path of similar pair file")
22-
parser.add_argument("--recall_result_file", type=str, default="", help="The full path of recall result file")
2321
parser.add_argument(
24-
"--recall_num", type=int, default=10, help="Most similar number of doc recalled from corpus per query"
22+
"--similar_text_pair",
23+
type=str,
24+
default="",
25+
help="The full path of similar pair file",
26+
)
27+
parser.add_argument(
28+
"--recall_result_file",
29+
type=str,
30+
default="",
31+
help="The full path of recall result file",
32+
)
33+
parser.add_argument(
34+
"--recall_num",
35+
type=int,
36+
default=10,
37+
help="Most similar number of doc recalled from corpus per query",
2538
)
2639
args = parser.parse_args()
2740

@@ -57,17 +70,24 @@ def recall(rs, N=10):
5770
with open(args.recall_result_file, "r", encoding="utf-8") as f:
5871
relevance_labels = []
5972
for index, line in enumerate(f):
60-
61-
if index % args.recall_num == 0 and index != 0:
62-
rs.append(relevance_labels)
63-
relevance_labels = []
6473
text_arr = line.rstrip().split("\t")
65-
text_title, text_para, recalled_title, recalled_para, label, cosine_sim = text_arr
74+
(
75+
text_title,
76+
text_para,
77+
recalled_title,
78+
recalled_para,
79+
label,
80+
cosine_sim,
81+
) = text_arr
6682
if text2similar["\t".join([text_title, text_para])] == label:
6783
relevance_labels.append(1)
6884
else:
6985
relevance_labels.append(0)
7086

87+
if (index + 1) % args.recall_num == 0:
88+
rs.append(relevance_labels)
89+
relevance_labels = []
90+
7191
recall_N = []
7292
recall_num = [1, 5, 10, 20, 50]
7393
for topN in recall_num:

csrc/generation/get_output.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <stdio.h>
16+
#include <string.h>
17+
#include <sys/ipc.h>
18+
#include <sys/msg.h>
19+
#include <sys/types.h>
20+
#include "paddle/extension.h"
21+
22+
#define MAX_BSZ 512
23+
24+
struct msgdata {
25+
long mtype;
26+
int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens
27+
};
28+
29+
void GetOutput(const paddle::Tensor& x,
30+
int64_t rank_id,
31+
bool wait_flag) {
32+
if (rank_id > 0) return;
33+
34+
static struct msgdata msg_rcv;
35+
36+
static key_t key = ftok("./", 1);
37+
38+
static int msgid = msgget(key, IPC_CREAT | 0666);
39+
40+
int64_t *out_data = const_cast<int64_t*>(x.data<int64_t>());
41+
int ret = -1;
42+
if (!wait_flag) {
43+
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
44+
} else {
45+
ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
46+
}
47+
if(ret == -1)
48+
{
49+
// read none
50+
out_data[0] = -2;
51+
out_data[1] = 0;
52+
return;
53+
}
54+
55+
int bsz = msg_rcv.mtext[1];
56+
57+
for (int64_t i = 0; i < bsz + 2; i++) {
58+
out_data[i] = (int64_t)msg_rcv.mtext[i];
59+
}
60+
return;
61+
}
62+
63+
PD_BUILD_OP(get_output)
64+
.Inputs({"x"})
65+
.Attrs({"rank_id: int64_t",
66+
"wait_flag: bool"})
67+
.Outputs({"x_out"})
68+
.SetInplaceMap({{"x", "x_out"}})
69+
.SetKernelFn(PD_KERNEL(GetOutput));

0 commit comments

Comments
 (0)