from sklearn.cluster import KMeans, AgglomerativeClustering
import torch
from tqdm import tqdm
+ from openai import OpenAI

from fastchat.utils import detect_language

@@ -46,6 +47,8 @@ def read_texts(input_file, min_length, max_length, english_only):
            line_texts = [
                x["content"] for x in l["conversation"] if x["role"] == "user"
            ]
+         elif "turns" in l:
+             line_texts = l["turns"]

        for text in line_texts:
            text = text.strip()
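For context, the new "turns" branch lets read_texts() ingest question-style JSONL records whose user prompts live in a flat "turns" list rather than a full conversation. A minimal illustrative record is sketched below; only the "turns" key is actually required by this code path, the other fields are assumptions.

# Hypothetical input line for the new "turns" branch (illustrative schema).
example_line = {
    "question_id": 81,
    "category": "writing",
    "turns": [
        "Compose an engaging travel blog post about a recent trip to Hawaii.",
        "Rewrite your previous response as a limerick.",
    ],
}
# read_texts() would treat each string in example_line["turns"] as one text.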
@@ -77,14 +80,26 @@ def read_texts(input_file, min_length, max_length, english_only):


def get_embeddings(texts, model_name, batch_size):
-     model = SentenceTransformer(model_name)
-     embeddings = model.encode(
-         texts,
-         batch_size=batch_size,
-         show_progress_bar=True,
-         device="cuda",
-         convert_to_tensor=True,
-     )
+     if model_name == "text-embedding-ada-002":
+         client = OpenAI()
+         texts = texts.tolist()
+
+         embeddings = []
+         for i in tqdm(range(0, len(texts), batch_size)):
+             text = texts[i : i + batch_size]
+             responses = client.embeddings.create(input=text, model=model_name).data
+             embeddings.extend([data.embedding for data in responses])
+         embeddings = torch.tensor(embeddings)
+     else:
+         model = SentenceTransformer(model_name)
+         embeddings = model.encode(
+             texts,
+             batch_size=batch_size,
+             show_progress_bar=True,
+             device="cuda",
+             convert_to_tensor=True,
+         )
+
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.cpu()

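Standalone, the new OpenAI path boils down to the batching pattern below. This is a minimal sketch, assuming the openai v1 Python SDK and an OPENAI_API_KEY set in the environment; the helper name embed_with_openai and its defaults are illustrative, not part of the patch.

import torch
from openai import OpenAI
from tqdm import tqdm

def embed_with_openai(texts, model_name="text-embedding-ada-002", batch_size=256):
    # Illustrative helper mirroring the patched get_embeddings() branch.
    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    embeddings = []
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i : i + batch_size]
        data = client.embeddings.create(input=batch, model=model_name).data
        embeddings.extend([d.embedding for d in data])
    # L2-normalize so downstream k-means / cosine similarity behave as before.
    embeddings = torch.tensor(embeddings)
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)

Note that the API expects a plain Python list of strings, which is presumably why the patch calls texts.tolist() before batching.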
@@ -218,6 +233,8 @@ def get_cluster_info(texts, labels, topk_indices):
    )
    parser.add_argument("--show-top-k", type=int, default=200)
    parser.add_argument("--show-cut-off", type=int, default=512)
+     parser.add_argument("--save-embeddings", action="store_true")
+     parser.add_argument("--embeddings-file", type=str, default=None)
    args = parser.parse_args()

    num_clusters = args.num_clusters
@@ -229,7 +246,15 @@ def get_cluster_info(texts, labels, topk_indices):
    )
    print(f"#text: {len(texts)}")

-     embeddings = get_embeddings(texts, args.model, args.batch_size)
+     if args.embeddings_file is None:
+         embeddings = get_embeddings(texts, args.model, args.batch_size)
+         if args.save_embeddings:
+             # allow saving embeddings to save time and money on repeated runs
+             torch.save(embeddings, "embeddings.pt")
+     else:
+         embeddings = torch.load(args.embeddings_file)
+     print(f"embeddings shape: {embeddings.shape}")
+
    if args.cluster_alg == "kmeans":
        centers, labels = run_k_means(embeddings, num_clusters)
    elif args.cluster_alg == "aggcls":
@@ -249,7 +274,7 @@ def get_cluster_info(texts, labels, topk_indices):
    with open(filename_prefix + "_topk.txt", "w") as fout:
        fout.write(topk_str)

-     with open(filename_prefix + "_all.txt", "w") as fout:
+     with open(filename_prefix + "_all.jsonl", "w") as fout:
        for i in range(len(centers)):
            tmp_indices = labels == i
            tmp_embeddings = embeddings[tmp_indices]
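With the new flags, one plausible workflow is to compute embeddings once with --save-embeddings (the patch always writes them to "embeddings.pt" in the working directory) and then reuse them on later runs via --embeddings-file, skipping get_embeddings() and any repeated API cost. A small sketch of the reuse step, assuming a prior run saved the file:

import torch

# Embeddings saved by a previous run launched with --save-embeddings.
embeddings = torch.load("embeddings.pt")
print(embeddings.shape)  # (num_texts, embedding_dim), rows already L2-normalized

# A subsequent invocation with --embeddings-file embeddings.pt would load this
# tensor directly instead of recomputing the embeddings.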