Skip to content

Commit 82ef3a3

Browse files
authored
Openai API migrate (#2765)
1 parent c70bb3d commit 82ef3a3

File tree

4 files changed

+102
-38
lines changed

4 files changed

+102
-38
lines changed

docs/openai_api.md

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,30 @@ pip install --upgrade openai
3939

4040
Then, interact with model vicuna:
4141
```python
42-
import openai
42+
from openai import OpenAI
4343
# to get proper authentication, make sure to use a valid key that's listed in
4444
# the --api-keys flag. if no flag value is provided, the `api_key` will be ignored.
45-
openai.api_key = "EMPTY"
46-
openai.api_base = "http://localhost:8000/v1"
45+
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1", default_headers={"x-foo": "true"})
4746

4847
model = "vicuna-7b-v1.5"
4948
prompt = "Once upon a time"
5049

51-
# create a completion
52-
completion = openai.Completion.create(model=model, prompt=prompt, max_tokens=64)
50+
# create a completion (legacy)
51+
completion = client.completions.create(
52+
model=model,
53+
prompt=prompt
54+
)
5355
# print the completion
5456
print(prompt + completion.choices[0].text)
5557

5658
# create a chat completion
57-
completion = openai.ChatCompletion.create(
58-
model=model,
59-
messages=[{"role": "user", "content": "Hello! What is your name?"}]
59+
completion = client.chat.completions.create(
60+
model="vicuna-7b-v1.5",
61+
response_format={ "type": "json_object" },
62+
messages=[
63+
{"role": "system", "content": "You are a helpful assistant designed to output JSON."},
64+
{"role": "user", "content": "Who won the world series in 2020?"}
65+
]
6066
)
6167
# print the completion
6268
print(completion.choices[0].message.content)

fastchat/serve/openai_api_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ async def show_available_models():
375375
return ModelList(data=model_cards)
376376

377377

378+
@app.post("/v1chat/completions", dependencies=[Depends(check_api_key)])   ← NOTE(review): this added route is missing a slash (should be "/v1/chat/completions"); it duplicates the correct decorator on the next line and registers a bogus `/v1chat/completions` endpoint — flag for a follow-up fix upstream
378379
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
379380
async def create_chat_completion(request: ChatCompletionRequest):
380381
"""Creates a completion for the chat message"""

playground/test_embedding/test_sentence_similarity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from scipy.spatial.distance import cosine
88

99

10-
def get_embedding_from_api(word, model="vicuna-7b-v1.1"):
10+
def get_embedding_from_api(word, model="vicuna-7b-v1.5"):
1111
if "ada" in model:
1212
resp = openai.Embedding.create(
1313
model=model,

tests/test_openai_api.py

Lines changed: 86 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,53 @@
44
Launch:
55
python3 launch_openai_api_test_server.py
66
"""
7+
from distutils.version import LooseVersion
8+
import warnings
79

810
import openai
911

12+
try:
13+
from openai import OpenAI, AsyncOpenAI
14+
except ImportError:
15+
warnings.warn("openai<1.0 is deprecated")
16+
1017
from fastchat.utils import run_cmd
1118

1219
openai.api_key = "EMPTY" # Not support yet
1320
openai.api_base = "http://localhost:8000/v1"
1421

1522

1623
def test_list_models():
17-
model_list = openai.Model.list()
18-
names = [x["id"] for x in model_list["data"]]
24+
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
25+
model_list = openai.Model.list()
26+
else:
27+
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
28+
model_list = client.models.list()
29+
names = [x.id for x in model_list.data]
1930
return names
2031

2132

2233
def test_completion(model, logprob):
2334
prompt = "Once upon a time"
24-
completion = openai.Completion.create(
25-
model=model,
26-
prompt=prompt,
27-
logprobs=logprob,
28-
max_tokens=64,
29-
temperature=0,
30-
)
35+
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
36+
completion = openai.Completion.create(
37+
model=model,
38+
prompt=prompt,
39+
logprobs=logprob,
40+
max_tokens=64,
41+
temperature=0,
42+
)
43+
else:
44+
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
45+
# legacy
46+
completion = client.completions.create(
47+
model=model,
48+
prompt=prompt,
49+
logprobs=logprob,
50+
max_tokens=64,
51+
temperature=0,
52+
)
53+
3154
print(f"full text: {prompt + completion.choices[0].text}", flush=True)
3255
if completion.choices[0].logprobs is not None:
3356
print(
@@ -38,42 +61,76 @@ def test_completion(model, logprob):
3861

3962
def test_completion_stream(model):
4063
prompt = "Once upon a time"
41-
res = openai.Completion.create(
42-
model=model,
43-
prompt=prompt,
44-
max_tokens=64,
45-
stream=True,
46-
temperature=0,
47-
)
64+
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
65+
res = openai.Completion.create(
66+
model=model,
67+
prompt=prompt,
68+
max_tokens=64,
69+
stream=True,
70+
temperature=0,
71+
)
72+
else:
73+
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
74+
# legacy
75+
res = client.completions.create(
76+
model=model,
77+
prompt=prompt,
78+
max_tokens=64,
79+
stream=True,
80+
temperature=0,
81+
)
4882
print(prompt, end="")
4983
for chunk in res:
50-
content = chunk["choices"][0]["text"]
84+
content = chunk.choices[0].text
5185
print(content, end="", flush=True)
5286
print()
5387

5488

5589
def test_embedding(model):
56-
embedding = openai.Embedding.create(model=model, input="Hello world!")
57-
print(f"embedding len: {len(embedding['data'][0]['embedding'])}")
58-
print(f"embedding value[:5]: {embedding['data'][0]['embedding'][:5]}")
90+
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
91+
embedding = openai.Embedding.create(model=model, input="Hello world!")
92+
else:
93+
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
94+
embedding = client.embeddings.create(model=model, input="Hello world!")
95+
print(f"embedding len: {len(embedding.data[0].embedding)}")
96+
print(f"embedding value[:5]: {embedding.data[0].embedding[:5]}")
5997

6098

6199
def test_chat_completion(model):
62-
completion = openai.ChatCompletion.create(
63-
model=model,
64-
messages=[{"role": "user", "content": "Hello! What is your name?"}],
65-
temperature=0,
66-
)
100+
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
101+
completion = openai.ChatCompletion.create(
102+
model=model,
103+
messages=[{"role": "user", "content": "Hello! What is your name?"}],
104+
temperature=0,
105+
)
106+
else:
107+
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
108+
completion = client.chat.completions.create(
109+
model=model,
110+
messages=[{"role": "user", "content": "Hello! What is your name?"}],
111+
temperature=0,
112+
)
67113
print(completion.choices[0].message.content)
68114

69115

70116
def test_chat_completion_stream(model):
71117
messages = [{"role": "user", "content": "Hello! What is your name?"}]
72-
res = openai.ChatCompletion.create(
73-
model=model, messages=messages, stream=True, temperature=0
74-
)
118+
if LooseVersion(openai.__version__) < LooseVersion("1.0"):
119+
res = openai.ChatCompletion.create(
120+
model=model, messages=messages, stream=True, temperature=0
121+
)
122+
else:
123+
client = OpenAI(api_key=openai.api_key, base_url=openai.api_base)
124+
res = client.chat.completions.create(
125+
model=model, messages=messages, stream=True, temperature=0
126+
)
75127
for chunk in res:
76-
content = chunk["choices"][0]["delta"].get("content", "")
128+
try:
129+
content = chunk.choices[0].delta.content
130+
if content is None:
131+
content = ""
132+
except Exception as e:
133+
content = chunk.choices[0].delta.get("content", "")
77134
print(content, end="", flush=True)
78135
print()
79136

0 commit comments

Comments (0)