import requests
import ray  # using Ray for overall ease of process management, parallel requests, and debugging.
import openai  # use the official client for correctness check
+from huggingface_hub import snapshot_download  # downloading lora to test lora requests

MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for the server to start
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here
+LORA_NAME = "typeof/zephyr-7b-beta-lora"  # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here

pytestmark = pytest.mark.asyncio
@@ -54,7 +56,12 @@ def __del__(self):
@pytest.fixture(scope="session")
-def server():
+def zephyr_lora_files():
+    return snapshot_download(repo_id=LORA_NAME)
+
+
+@pytest.fixture(scope="session")
+def server(zephyr_lora_files):
    ray.init()
    server_runner = ServerRunner.remote([
        "--model",
@@ -64,6 +71,17 @@ def server():
        "--max-model-len",
        "8192",
        "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        "--max-num-seqs",
+        "128"
    ])
    ray.get(server_runner.ready.remote())
    yield server_runner
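# Aside, not part of the patch: a sketch of launching the same configuration by hand,
# assuming ServerRunner simply forwards these arguments to
# vllm.entrypoints.openai.api_server (that module path is an assumption here, not
# something this diff shows).
import subprocess
import sys

from huggingface_hub import snapshot_download

lora_path = snapshot_download(repo_id="typeof/zephyr-7b-beta-lora")
subprocess.Popen([
    sys.executable, "-m", "vllm.entrypoints.openai.api_server",
    "--model", "HuggingFaceH4/zephyr-7b-beta",
    "--max-model-len", "8192",
    "--enforce-eager",
    "--enable-lora",
    "--lora-modules",
    f"zephyr-lora={lora_path}",
    f"zephyr-lora2={lora_path}",
    "--max-lora-rank", "64",
    "--max-cpu-loras", "2",
    "--max-num-seqs", "128",
])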
@@ -79,8 +97,25 @@ def client():
    yield client


-async def test_single_completion(server, client: openai.AsyncOpenAI):
-    completion = await client.completions.create(model=MODEL_NAME,
+async def test_check_models(server, client: openai.AsyncOpenAI):
+    models = await client.models.list()
+    models = models.data
+    served_model = models[0]
+    lora_models = models[1:]
+    assert served_model.id == MODEL_NAME
+    assert all(model.root == MODEL_NAME for model in models)
+    assert lora_models[0].id == "zephyr-lora"
+    assert lora_models[1].id == "zephyr-lora2"
+
+
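# Aside, not part of the patch: the same check test_check_models makes, expressed
# with the requests import at the top of the file. The host/port is an assumption
# (vLLM's usual default), not something this diff pins down.
def check_models_by_hand(base_url: str = "http://localhost:8000"):
    data = requests.get(f"{base_url}/v1/models").json()["data"]
    ids = [m["id"] for m in data]
    # expected: the base model first, then the two registered adapters
    assert ids == [MODEL_NAME, "zephyr-lora", "zephyr-lora2"]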
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_single_completion(server, client: openai.AsyncOpenAI,
+                                 model_name: str):
+    completion = await client.completions.create(model=model_name,
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
                                                  temperature=0.0)
@@ -104,7 +139,13 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
        completion.choices[0].text) >= 5


-async def test_single_chat_session(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_single_chat_session(server, client: openai.AsyncOpenAI,
+                                   model_name: str):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
@@ -115,7 +156,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
    # test single completion
    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        messages=messages,
        max_tokens=10,
    )
@@ -139,11 +180,17 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
    assert message.content is not None and len(message.content) >= 0


-async def test_completion_streaming(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_completion_streaming(server, client: openai.AsyncOpenAI,
+                                    model_name: str):
    prompt = "What is an LLM?"

    single_completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
@@ -152,7 +199,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
    single_usage = single_completion.usage

    stream = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
@@ -166,7 +213,13 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
    assert "".join(chunks) == single_output


-async def test_chat_streaming(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_chat_streaming(server, client: openai.AsyncOpenAI,
+                              model_name: str):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
@@ -177,7 +230,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
    # test single completion
    chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
@@ -187,7 +240,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
    # test streaming
    stream = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
@@ -204,10 +257,16 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
    assert "".join(chunks) == output


-async def test_batch_completions(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_batch_completions(server, client: openai.AsyncOpenAI,
+                                 model_name: str):
    # test simple list
    batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        prompt=["Hello, my name is", "Hello, my name is"],
        max_tokens=5,
        temperature=0.0,
@@ -217,7 +276,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
    # test n = 2
    batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        prompt=["Hello, my name is", "Hello, my name is"],
        n=2,
        max_tokens=5,
@@ -236,7 +295,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
    # test streaming
    batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
        prompt=["Hello, my name is", "Hello, my name is"],
        max_tokens=5,
        temperature=0.0,
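# Aside, not part of the patch: with the server configured as in the fixture above,
# a LoRA adapter is selected simply by passing its registered name as the model.
# The base URL and the placeholder API key are assumptions, not taken from this diff.
import asyncio

async def _demo():
    demo_client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                     api_key="unused-placeholder")
    completion = await demo_client.completions.create(model="zephyr-lora",
                                                       prompt="Hello, my name is",
                                                       max_tokens=5,
                                                       temperature=0.0)
    print(completion.choices[0].text)

if __name__ == "__main__":
    asyncio.run(_demo())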