
Commit 83bb3d5

update test data use.
1 parent 273ebaf commit 83bb3d5

File tree

3 files changed: +78 -63 lines changed

examples/training_sup_cosent_en.py

Lines changed: 25 additions & 22 deletions

@@ -39,6 +39,26 @@ def calc_similarity_scores(args, sents1, sents2, labels):
     return spearman
 
 
+def load_en_stsb_dataset(stsb_file):
+    # Convert the dataset to a DataLoader ready for training
+    logger.info("Read STSbenchmark dataset")
+    train_samples = []
+    valid_samples = []
+    test_samples = []
+    with gzip.open(stsb_file, 'rt', encoding='utf8') as f:
+        reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
+        for row in reader:
+            score = float(row['score'])
+            if row['split'] == 'dev':
+                valid_samples.append((row['sentence1'], row['sentence2'], score))
+            elif row['split'] == 'test':
+                test_samples.append((row['sentence1'], row['sentence2'], score))
+            else:
+                score = int(score > 2.5)
+                train_samples.append((row['sentence1'], row['sentence2'], score))
+    return train_samples, valid_samples, test_samples
+
+
 def main():
     parser = argparse.ArgumentParser('CoSENT Text Matching task')
     parser.add_argument('--model_name', default='bert-base-uncased', type=str, help='name of transformers model')
@@ -56,27 +76,11 @@ def main():
     args = parser.parse_args()
     logger.info(args)
 
-    test_samples = []
+    train_samples, valid_samples, test_samples = load_en_stsb_dataset(args.stsb_file)
+
     if args.do_train:
         model = CosentModel(model_name_or_path=args.model_name, encoder_type=args.encoder_type,
                             max_seq_length=args.max_seq_length)
-
-        # Convert the dataset to a DataLoader ready for training
-        logger.info("Read STSbenchmark dataset")
-        train_samples = []
-        valid_samples = []
-        test_samples = []
-        with gzip.open(args.stsb_file, 'rt', encoding='utf8') as f:
-            reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
-            for row in reader:
-                score = float(row['score'])
-                if row['split'] == 'dev':
-                    valid_samples.append((row['sentence1'], row['sentence2'], score))
-                elif row['split'] == 'test':
-                    test_samples.append((row['sentence1'], row['sentence2'], score))
-                else:
-                    train_samples.append((row['sentence1'], score))
-                    train_samples.append((row['sentence2'], score))
         train_dataset = CosentTrainDataset(model.tokenizer, train_samples, args.max_seq_length)
         valid_dataset = CosentTestDataset(model.tokenizer, valid_samples, args.max_seq_length)
         model.train(train_dataset,
@@ -86,21 +90,20 @@ def main():
                     batch_size=args.batch_size,
                     lr=args.learning_rate)
         logger.info(f"Model saved to {args.output_dir}")
+
     if args.do_predict:
         model = CosentModel(model_name_or_path=args.output_dir, encoder_type=args.encoder_type,
                             max_seq_length=args.max_seq_length)
-        test_data = test_samples
-
         # Predict embeddings
         srcs = []
         trgs = []
         labels = []
-        for terms in test_data:
+        for terms in test_samples:
             src, trg, label = terms[0], terms[1], terms[2]
             srcs.append(src)
             trgs.append(trg)
             labels.append(label)
-        logger.debug(f'{test_data[0]}')
+        logger.debug(f'{test_samples[0]}')
         sentence_embeddings = model.encode(srcs)
         logger.debug(f"{type(sentence_embeddings)}, {sentence_embeddings.shape}, {sentence_embeddings[0].shape}")
         # Predict similarity scores
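
Taken together, these hunks replace the inline STS-B parsing with the shared load_en_stsb_dataset helper, which returns (sentence1, sentence2, score) triples for every split and binarizes only the train score; the old cosent code had appended single sentences to train_samples. A minimal usage sketch, assuming the archive sits at the script's default path (the unpacking below is illustrative, not from the diff):

    # Hypothetical usage of the helper added in this commit.
    train_samples, valid_samples, test_samples = load_en_stsb_dataset(
        'data/English-STS-B/stsbenchmark.tsv.gz')  # default --stsb_file value
    # Train labels are binarized to 0/1 via int(score > 2.5);
    # dev/test keep the raw 0-5 similarity score for Spearman evaluation.
    sentence1, sentence2, label = train_samples[0]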

examples/training_sup_sentencebert_en.py

Lines changed: 27 additions & 22 deletions

@@ -39,14 +39,35 @@ def calc_similarity_scores(args, sents1, sents2, labels):
     return spearman
 
 
+def load_en_stsb_dataset(stsb_file):
+    # Convert the dataset to a DataLoader ready for training
+    logger.info("Read STSbenchmark dataset")
+    train_samples = []
+    valid_samples = []
+    test_samples = []
+    with gzip.open(stsb_file, 'rt', encoding='utf8') as f:
+        reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
+        for row in reader:
+            score = float(row['score'])
+            if row['split'] == 'dev':
+                valid_samples.append((row['sentence1'], row['sentence2'], score))
+            elif row['split'] == 'test':
+                test_samples.append((row['sentence1'], row['sentence2'], score))
+            else:
+                score = int(score > 2.5)
+                train_samples.append((row['sentence1'], row['sentence2'], score))
+    return train_samples, valid_samples, test_samples
+
+
 def main():
     parser = argparse.ArgumentParser('SentenceBERT Text Matching task')
     parser.add_argument('--model_name', default='bert-base-uncased', type=str, help='name of transformers model')
     parser.add_argument('--stsb_file', default='data/English-STS-B/stsbenchmark.tsv.gz', type=str,
                         help='Train data path')
     parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
     parser.add_argument("--do_predict", action="store_true", help="Whether to run predict.")
-    parser.add_argument('--output_dir', default='./outputs/STS-B-en-sentencebert', type=str, help='Model output directory')
+    parser.add_argument('--output_dir', default='./outputs/STS-B-en-sentencebert', type=str,
+                        help='Model output directory')
     parser.add_argument('--max_seq_length', default=64, type=int, help='Max sequence length')
     parser.add_argument('--num_epochs', default=10, type=int, help='Number of training epochs')
     parser.add_argument('--batch_size', default=64, type=int, help='Batch size')
@@ -56,26 +77,11 @@ def main():
     args = parser.parse_args()
     logger.info(args)
 
-    test_samples = []
+    train_samples, valid_samples, test_samples = load_en_stsb_dataset(args.stsb_file)
+
     if args.do_train:
         model = SentenceBertModel(model_name_or_path=args.model_name, encoder_type=args.encoder_type,
                                   max_seq_length=args.max_seq_length)
-
-        # Convert the dataset to a DataLoader ready for training
-        logger.info("Read STSbenchmark dataset")
-        train_samples = []
-        valid_samples = []
-        with gzip.open(args.stsb_file, 'rt', encoding='utf8') as f:
-            reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
-            for row in reader:
-                score = float(row['score'])
-                if row['split'] == 'dev':
-                    valid_samples.append((row['sentence1'], row['sentence2'], score))
-                elif row['split'] == 'test':
-                    test_samples.append((row['sentence1'], row['sentence2'], score))
-                else:
-                    score = int(score > 2.5)
-                    train_samples.append((row['sentence1'], row['sentence2'], score))
         train_dataset = SentenceBertTrainDataset(model.tokenizer, train_samples, args.max_seq_length)
         valid_dataset = SentenceBertTestDataset(model.tokenizer, valid_samples, args.max_seq_length)
         model.train(train_dataset,
@@ -85,21 +91,20 @@ def main():
                     batch_size=args.batch_size,
                     lr=args.learning_rate)
         logger.info(f"Model saved to {args.output_dir}")
+
     if args.do_predict:
         model = SentenceBertModel(model_name_or_path=args.output_dir, encoder_type=args.encoder_type,
                                   max_seq_length=args.max_seq_length)
-        test_data = test_samples
-
         # Predict embeddings
         srcs = []
         trgs = []
        labels = []
-        for terms in test_data:
+        for terms in test_samples:
             src, trg, label = terms[0], terms[1], terms[2]
             srcs.append(src)
             trgs.append(trg)
             labels.append(label)
-        logger.debug(f'{test_data[0]}')
+        logger.debug(f'{test_samples[0]}')
         sentence_embeddings = model.encode(srcs)
         logger.debug(f"{type(sentence_embeddings)}, {sentence_embeddings.shape}, {sentence_embeddings[0].shape}")
         # Predict similarity scores
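
The same helper lands here; the only extra change is re-wrapping the long --output_dir add_argument call. To sanity-check the split/binarization rule in isolation, a self-contained sketch over a tiny in-memory TSV (the two data rows are made up):

    import csv
    import io

    tsv = ("split\tscore\tsentence1\tsentence2\n"
           "train\t3.8\tA man is singing.\tA person sings.\n"
           "dev\t1.0\tA dog runs.\tA cat sleeps.\n")
    reader = csv.DictReader(io.StringIO(tsv), delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        score = float(row['score'])
        if row['split'] not in ('dev', 'test'):
            score = int(score > 2.5)  # 3.8 -> 1: binary pair label for training
        print(row['split'], score)    # prints: train 1, then dev 1.0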

examples/training_unsup_cosent_en.py

Lines changed: 26 additions & 19 deletions

@@ -39,6 +39,26 @@ def calc_similarity_scores(args, sents1, sents2, labels):
     return spearman
 
 
+def load_en_stsb_dataset(stsb_file):
+    # Convert the dataset to a DataLoader ready for training
+    logger.info("Read STSbenchmark dataset")
+    train_samples = []
+    valid_samples = []
+    test_samples = []
+    with gzip.open(stsb_file, 'rt', encoding='utf8') as f:
+        reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
+        for row in reader:
+            score = float(row['score'])
+            if row['split'] == 'dev':
+                valid_samples.append((row['sentence1'], row['sentence2'], score))
+            elif row['split'] == 'test':
+                test_samples.append((row['sentence1'], row['sentence2'], score))
+            else:
+                score = int(score > 2.5)
+                train_samples.append((row['sentence1'], row['sentence2'], score))
+    return train_samples, valid_samples, test_samples
+
+
 def main():
     parser = argparse.ArgumentParser('CoSENT Text Matching task')
     parser.add_argument('--model_name', default='bert-base-uncased', type=str, help='name of transformers model')
@@ -58,7 +78,8 @@ def main():
     args = parser.parse_args()
     logger.info(args)
 
-    test_samples = []
+    _, valid_samples, test_samples = load_en_stsb_dataset(args.stsb_file)
+
     if args.do_train:
         model = CosentModel(model_name_or_path=args.model_name, encoder_type=args.encoder_type,
                             max_seq_length=args.max_seq_length)
@@ -82,43 +103,29 @@ def main():
                 break
 
         train_dataset = CosentTrainDataset(model.tokenizer, nli_train_samples, args.max_seq_length)
-
-        # Convert the dataset to a DataLoader ready for validation
-        logger.info("Read STSbenchmark dev and test dataset")
-        valid_samples = []
-        test_samples = []
-        with gzip.open(args.stsb_file, 'rt', encoding='utf8') as f:
-            reader = csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
-            for row in reader:
-                score = float(row['score'])
-                if row['split'] == 'dev':
-                    valid_samples.append((row['sentence1'], row['sentence2'], score))
-                elif row['split'] == 'test':
-                    test_samples.append((row['sentence1'], row['sentence2'], score))
-
         valid_dataset = CosentTestDataset(model.tokenizer, valid_samples, args.max_seq_length)
+
         model.train(train_dataset,
                     args.output_dir,
                     eval_dataset=valid_dataset,
                     num_epochs=args.num_epochs,
                     batch_size=args.batch_size,
                     lr=args.learning_rate)
         logger.info(f"Model saved to {args.output_dir}")
+
     if args.do_predict:
         model = CosentModel(model_name_or_path=args.output_dir, encoder_type=args.encoder_type,
                             max_seq_length=args.max_seq_length)
-        test_data = test_samples
-
         # Predict embeddings
         srcs = []
         trgs = []
         labels = []
-        for terms in test_data:
+        for terms in test_samples:
             src, trg, label = terms[0], terms[1], terms[2]
             srcs.append(src)
             trgs.append(trg)
             labels.append(label)
-        logger.debug(f'{test_data[0]}')
+        logger.debug(f'{test_samples[0]}')
         sentence_embeddings = model.encode(srcs)
         logger.debug(f"{type(sentence_embeddings)}, {sentence_embeddings.shape}, {sentence_embeddings[0].shape}")
         # Predict similarity scores
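
In the unsupervised script the STS-B train split is unused (training comes from nli_train_samples), so the helper's first return value is discarded with _. The shared predict branch then encodes both sentence lists and hands the embeddings to calc_similarity_scores; for reference, a standalone cosine-plus-Spearman evaluation of the kind that function presumably performs (the function below is an assumption for illustration, not code from this diff):

    import numpy as np
    from scipy.stats import spearmanr

    def cosine_spearman(src_emb, trg_emb, labels):
        # Row-wise cosine similarity between two (n, d) embedding matrices.
        a = src_emb / np.linalg.norm(src_emb, axis=1, keepdims=True)
        b = trg_emb / np.linalg.norm(trg_emb, axis=1, keepdims=True)
        sims = (a * b).sum(axis=1)
        # Rank correlation between gold scores and predicted similarities.
        return spearmanr(labels, sims).correlation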
