@@ -39,14 +39,35 @@ def calc_similarity_scores(args, sents1, sents2, labels):
3939 return spearman
4040
4141
def load_en_stsb_dataset(stsb_file):
    """Load the English STS-benchmark from a gzipped TSV file.

    Rows whose ``split`` column is ``'dev'`` or ``'test'`` keep their raw
    float similarity score; every other row is routed to the training set
    with the score binarized (1 if score > 2.5, else 0), since training
    here is framed as a classification task.

    :param stsb_file: path to ``stsbenchmark.tsv.gz``
    :return: tuple ``(train_samples, valid_samples, test_samples)``, each a
        list of ``(sentence1, sentence2, score)`` tuples
    """
    logger.info("Read STSbenchmark dataset")
    buckets = {'train': [], 'dev': [], 'test': []}
    with gzip.open(stsb_file, 'rt', encoding='utf8') as f:
        for row in csv.DictReader(f, delimiter='\t', quoting=csv.QUOTE_NONE):
            raw_score = float(row['score'])
            pair = (row['sentence1'], row['sentence2'])
            split = row['split']
            if split == 'dev':
                buckets['dev'].append(pair + (raw_score,))
            elif split == 'test':
                buckets['test'].append(pair + (raw_score,))
            else:
                # Binarize only for training rows.
                buckets['train'].append(pair + (int(raw_score > 2.5),))
    return buckets['train'], buckets['dev'], buckets['test']
60+
61+
4262def main ():
4363 parser = argparse .ArgumentParser ('SentenceBERT Text Matching task' )
4464 parser .add_argument ('--model_name' , default = 'bert-base-uncased' , type = str , help = 'name of transformers model' )
4565 parser .add_argument ('--stsb_file' , default = 'data/English-STS-B/stsbenchmark.tsv.gz' , type = str ,
4666 help = 'Train data path' )
4767 parser .add_argument ("--do_train" , action = "store_true" , help = "Whether to run training." )
4868 parser .add_argument ("--do_predict" , action = "store_true" , help = "Whether to run predict." )
49- parser .add_argument ('--output_dir' , default = './outputs/STS-B-en-sentencebert' , type = str , help = 'Model output directory' )
69+ parser .add_argument ('--output_dir' , default = './outputs/STS-B-en-sentencebert' , type = str ,
70+ help = 'Model output directory' )
5071 parser .add_argument ('--max_seq_length' , default = 64 , type = int , help = 'Max sequence length' )
5172 parser .add_argument ('--num_epochs' , default = 10 , type = int , help = 'Number of training epochs' )
5273 parser .add_argument ('--batch_size' , default = 64 , type = int , help = 'Batch size' )
@@ -56,26 +77,11 @@ def main():
5677 args = parser .parse_args ()
5778 logger .info (args )
5879
59- test_samples = []
80+ train_samples , valid_samples , test_samples = load_en_stsb_dataset (args .stsb_file )
81+
6082 if args .do_train :
6183 model = SentenceBertModel (model_name_or_path = args .model_name , encoder_type = args .encoder_type ,
6284 max_seq_length = args .max_seq_length )
63-
64- # Convert the dataset to a DataLoader ready for training
65- logger .info ("Read STSbenchmark dataset" )
66- train_samples = []
67- valid_samples = []
68- with gzip .open (args .stsb_file , 'rt' , encoding = 'utf8' ) as f :
69- reader = csv .DictReader (f , delimiter = '\t ' , quoting = csv .QUOTE_NONE )
70- for row in reader :
71- score = float (row ['score' ])
72- if row ['split' ] == 'dev' :
73- valid_samples .append ((row ['sentence1' ], row ['sentence2' ], score ))
74- elif row ['split' ] == 'test' :
75- test_samples .append ((row ['sentence1' ], row ['sentence2' ], score ))
76- else :
77- score = int (score > 2.5 )
78- train_samples .append ((row ['sentence1' ], row ['sentence2' ], score ))
7985 train_dataset = SentenceBertTrainDataset (model .tokenizer , train_samples , args .max_seq_length )
8086 valid_dataset = SentenceBertTestDataset (model .tokenizer , valid_samples , args .max_seq_length )
8187 model .train (train_dataset ,
@@ -85,21 +91,20 @@ def main():
8591 batch_size = args .batch_size ,
8692 lr = args .learning_rate )
8793 logger .info (f"Model saved to { args .output_dir } " )
94+
8895 if args .do_predict :
8996 model = SentenceBertModel (model_name_or_path = args .output_dir , encoder_type = args .encoder_type ,
9097 max_seq_length = args .max_seq_length )
91- test_data = test_samples
92-
9398 # Predict embeddings
9499 srcs = []
95100 trgs = []
96101 labels = []
97- for terms in test_data :
102+ for terms in test_samples :
98103 src , trg , label = terms [0 ], terms [1 ], terms [2 ]
99104 srcs .append (src )
100105 trgs .append (trg )
101106 labels .append (label )
102- logger .debug (f'{ test_data [0 ]} ' )
107+ logger .debug (f'{ test_samples [0 ]} ' )
103108 sentence_embeddings = model .encode (srcs )
104109 logger .debug (f"{ type (sentence_embeddings )} , { sentence_embeddings .shape } , { sentence_embeddings [0 ].shape } " )
105110 # Predict similarity scores
0 commit comments