Commit 8163cb2

support openai embedding for topic clustering (#2729)
Parent: e86e70d

File tree

2 files changed: 44 insertions(+), 10 deletions(-)

fastchat/serve/monitor/summarize_cluster.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -6,6 +6,8 @@
 import argparse
 import pickle
 
+import pandas as pd
+
 from fastchat.llm_judge.common import (
     chat_completion_openai,
     chat_completion_openai_azure,
@@ -74,3 +76,10 @@ def truncate_string(s, l):
     print()
     print(f"topics: {topics}")
     print(f"percentages: {percentages}")
+
+    # save the information
+    df = pd.DataFrame()
+    df["topic"] = topics
+    df["percentage"] = percentages
+
+    df.to_json(f"cluster_summary_{len(df)}.jsonl", lines=True, orient="records")
```
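The summary now lands in a records-oriented JSON Lines file, one object per topic. A minimal sketch of reading it back with pandas (the file name follows the pattern above; the cluster count of 20 is a hypothetical example):

```python
import pandas as pd

# Read the records-oriented JSONL written by df.to_json(..., lines=True).
df = pd.read_json("cluster_summary_20.jsonl", lines=True)
print(df[["topic", "percentage"]])
```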

fastchat/serve/monitor/topic_clustering.py

Lines changed: 35 additions & 10 deletions
```diff
@@ -16,6 +16,7 @@
 from sklearn.cluster import KMeans, AgglomerativeClustering
 import torch
 from tqdm import tqdm
+from openai import OpenAI
 
 from fastchat.utils import detect_language
 
@@ -46,6 +47,8 @@ def read_texts(input_file, min_length, max_length, english_only):
             line_texts = [
                 x["content"] for x in l["conversation"] if x["role"] == "user"
             ]
+        elif "turns" in l:
+            line_texts = l["turns"]
 
         for text in line_texts:
             text = text.strip()
```
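The new `elif` branch accepts a second input schema next to the existing `conversation` one. Hypothetical examples of both record shapes, with field names taken from the diff and values invented for illustration:

```python
# "conversation" record: user texts are extracted from role-tagged messages.
conversation_record = {
    "conversation": [
        {"role": "user", "content": "What is topic clustering?"},
        {"role": "assistant", "content": "Grouping texts by theme."},
    ]
}

# "turns" record: already a flat list of user texts.
turns_record = {"turns": ["What is topic clustering?", "Give an example."]}
```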
```diff
@@ -77,14 +80,26 @@ def read_texts(input_file, min_length, max_length, english_only):
 
 
 def get_embeddings(texts, model_name, batch_size):
-    model = SentenceTransformer(model_name)
-    embeddings = model.encode(
-        texts,
-        batch_size=batch_size,
-        show_progress_bar=True,
-        device="cuda",
-        convert_to_tensor=True,
-    )
+    if model_name == "text-embedding-ada-002":
+        client = OpenAI()
+        texts = texts.tolist()
+
+        embeddings = []
+        for i in tqdm(range(0, len(texts), batch_size)):
+            text = texts[i : i + batch_size]
+            responses = client.embeddings.create(input=text, model=model_name).data
+            embeddings.extend([data.embedding for data in responses])
+        embeddings = torch.tensor(embeddings)
+    else:
+        model = SentenceTransformer(model_name)
+        embeddings = model.encode(
+            texts,
+            batch_size=batch_size,
+            show_progress_bar=True,
+            device="cuda",
+            convert_to_tensor=True,
+        )
+
 embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
 return embeddings.cpu()
```
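`get_embeddings` now branches on the model name: `text-embedding-ada-002` is embedded remotely through the OpenAI API in batches, every other name still goes through a local SentenceTransformer, and both paths end in L2 normalization. A self-contained sketch of the new OpenAI path, assuming the `openai>=1.0` client and an `OPENAI_API_KEY` in the environment:

```python
import torch
from openai import OpenAI
from tqdm import tqdm


def openai_embeddings(texts, batch_size=64, model="text-embedding-ada-002"):
    """Batch texts through the OpenAI embeddings endpoint, as in the diff."""
    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    vectors = []
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i : i + batch_size]
        data = client.embeddings.create(input=batch, model=model).data
        vectors.extend(d.embedding for d in data)
    embeddings = torch.tensor(vectors)
    # Unit-normalize rows so cosine similarity reduces to a dot product,
    # matching the normalize call that follows both branches in the diff.
    return torch.nn.functional.normalize(embeddings, p=2, dim=1)
```

The `texts.tolist()` call in the diff suggests `texts` arrives as a NumPy array; the OpenAI endpoint expects a plain list of strings.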

```diff
@@ -218,6 +233,8 @@ def get_cluster_info(texts, labels, topk_indices):
     )
     parser.add_argument("--show-top-k", type=int, default=200)
     parser.add_argument("--show-cut-off", type=int, default=512)
+    parser.add_argument("--save-embeddings", action="store_true")
+    parser.add_argument("--embeddings-file", type=str, default=None)
     args = parser.parse_args()
 
     num_clusters = args.num_clusters
@@ -229,7 +246,15 @@ def get_cluster_info(texts, labels, topk_indices):
     )
     print(f"#text: {len(texts)}")
 
-    embeddings = get_embeddings(texts, args.model, args.batch_size)
+    if args.embeddings_file is None:
+        embeddings = get_embeddings(texts, args.model, args.batch_size)
+        if args.save_embeddings:
+            # allow saving embeddings to save time and money
+            torch.save(embeddings, "embeddings.pt")
+    else:
+        embeddings = torch.load(args.embeddings_file)
+    print(f"embeddings shape: {embeddings.shape}")
+
     if args.cluster_alg == "kmeans":
         centers, labels = run_k_means(embeddings, num_clusters)
     elif args.cluster_alg == "aggcls":
```
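Together, the two new flags let one run pay for OpenAI embeddings once and recluster from the cached tensor afterwards. A hypothetical invocation; only `--save-embeddings`, `--embeddings-file`, and the model name appear in the diff, and the remaining flags are assumptions about the existing CLI:

```bash
# First run: embed via OpenAI and cache the tensor to embeddings.pt.
python3 topic_clustering.py --input-file conv.json \
    --model text-embedding-ada-002 --save-embeddings

# Later runs: load the cache and skip the embedding step entirely.
python3 topic_clustering.py --input-file conv.json \
    --embeddings-file embeddings.pt
```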
```diff
@@ -249,7 +274,7 @@ def get_cluster_info(texts, labels, topk_indices):
     with open(filename_prefix + "_topk.txt", "w") as fout:
         fout.write(topk_str)
 
-    with open(filename_prefix + "_all.txt", "w") as fout:
+    with open(filename_prefix + "_all.jsonl", "w") as fout:
         for i in range(len(centers)):
             tmp_indices = labels == i
             tmp_embeddings = embeddings[tmp_indices]
```
