 # Copyright (c) Alibaba, Inc. and its affiliates.
+import inspect
+import logging
 import time
 from dataclasses import asdict
 from http import HTTPStatus
+from types import MethodType
 from typing import List, Optional, Union

 import json
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, StreamingResponse
+from modelscope import GenerationConfig

-from swift.utils import get_main, seed_everything
+from swift.utils import get_logger, get_main, seed_everything
 from .infer import merge_lora, prepare_model_template
 from .utils import ChatCompletionResponse  # noqa
 from .utils import (ChatCompletionRequest, ChatCompletionResponseChoice,
                     ChatCompletionResponseStreamChoice,
                     ChatCompletionStreamResponse, ChatMessage,
                     CompletionRequest, CompletionResponse,
                     CompletionResponseChoice, CompletionResponseStreamChoice,
                     CompletionStreamResponse, DeltaMessage, DeployArguments,
                     Model, ModelList, UsageInfo, VllmGenerationConfig,
-                    messages_to_history, prepare_vllm_engine_template,
-                    random_uuid)
+                    inference, inference_stream, messages_to_history,
+                    prepare_vllm_engine_template, random_uuid)
+
+logger = get_logger()

 app = FastAPI()
+_args = None
 model = None
 llm_engine = None
 template = None
@@ -35,15 +42,18 @@ def create_error_response(status_code: Union[int, str, HTTPStatus],

 @app.get('/v1/models')
 async def get_available_models():
-    global llm_engine
-    return ModelList(data=[Model(id=llm_engine.model_type)])
+    global _args
+    return ModelList(data=[Model(id=_args.model_type)])


 async def check_length(request: Union[ChatCompletionRequest,
                                       CompletionRequest],
                        input_ids: List[int]) -> Optional[str]:
-    global llm_engine
-    max_model_len = llm_engine.model_config.max_model_len
+    global llm_engine, model, _args
+    if _args.infer_backend == 'vllm':
+        max_model_len = llm_engine.model_config.max_model_len
+    else:
+        max_model_len = model.max_model_len
     num_tokens = len(input_ids)
     max_tokens = request.max_tokens
     if max_tokens is None:
@@ -69,6 +79,13 @@ async def check_model(
         return f'`{request.model}` is not in the model_list: `{model_type_list}`.'


+def is_generation_template(template_type: str) -> bool:
+    if 'generation' in template_type:
+        return True
+    else:
+        return False
+
+
 async def inference_vllm_async(request: Union[ChatCompletionRequest,
                                               CompletionRequest],
                                raw_request: Request):
@@ -78,7 +95,8 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest,
         return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)

     if request.seed is not None:
-        seed_everything(request.seed)
+        seed_everything(request.seed, verbose=False)
+    _request = {'model': request.model}
     if isinstance(request, ChatCompletionRequest):
         if is_generation_template(template.template_type):
             return create_error_response(
@@ -89,19 +107,26 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest,
         example = messages_to_history(request.messages)
         input_ids = template.encode(example)[0]['input_ids']
         request_id = f'chatcmpl-{random_uuid()}'
+        _request['messages'] = request.messages
     else:
         if not is_generation_template(template.template_type):
             return create_error_response(
                 HTTPStatus.BAD_REQUEST,
                 f'The chat template `{template.template_type}` corresponding to '
                 f'the model `{llm_engine.model_type}` is in chat format. '
                 'Please use the `chat.completions` API.')
-        input_ids = template.encode({'query': request.prompt})[0]['input_ids']
+        example = {'query': request.prompt}
+        input_ids = template.encode(example)[0]['input_ids']
         request_id = f'cmpl-{random_uuid()}'
+        _request['prompt'] = request.prompt
+
+    request_info = {'request_id': request_id}
+    request_info.update(_request)

     error_msg = await check_length(request, input_ids)
     if error_msg is not None:
         return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)
+
     kwargs = {'max_new_tokens': request.max_tokens}
     for key in [
             'n', 'stop', 'best_of', 'frequency_penalty', 'length_penalty',
@@ -114,6 +139,7 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest,
             kwargs[key] = getattr(llm_engine.generation_config, key)
         else:
             kwargs[key] = new_value
+
     generation_config = VllmGenerationConfig(**kwargs)
     if generation_config.use_beam_search is True and request.stream is True:
         error_msg = 'Streaming generation does not support beam search.'
@@ -124,6 +150,10 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest,
     if isinstance(template.suffix[-1],
                   str) and template.suffix[-1] not in generation_config.stop:
         generation_config.stop.append(template.suffix[-1])
+    request_info['generation_config'] = generation_config
+    request_info.update({'seed': request.seed, 'stream': request.stream})
+    logger.info(request_info)
+
     created_time = int(time.time())
     result_generator = llm_engine.generate(None, generation_config, request_id,
                                            input_ids)
@@ -153,7 +183,7 @@ async def _generate_full():
                     finish_reason=output.finish_reason,
                 )
                 choices.append(choice)
-            return ChatCompletionResponse(
+            response = ChatCompletionResponse(
                 model=request.model,
                 choices=choices,
                 usage=usage_info,
@@ -168,12 +198,13 @@ async def _generate_full():
                     finish_reason=output.finish_reason,
                 )
                 choices.append(choice)
-            return CompletionResponse(
+            response = CompletionResponse(
                 model=request.model,
                 choices=choices,
                 usage=usage_info,
                 id=request_id,
                 created=created_time)
+        return response

     async def _generate_stream():
         print_idx_list = [0] * request.n
@@ -228,29 +259,221 @@ async def _generate_stream():
         return await _generate_full()


-def is_generation_template(template_type: str) -> bool:
-    if 'generation' in template_type:
-        return True
+class _GenerationConfig(GenerationConfig):
+
+    def __repr__(self) -> str:
+        parameters = inspect.signature(self.to_json_string).parameters
+        kwargs = {}
+        if 'ignore_metadata' in parameters:
+            kwargs['ignore_metadata'] = True
+        gen_kwargs = json.loads(self.to_json_string(**kwargs))
+        gen_kwargs.pop('transformers_version', None)
+        return f'GenerationConfig({gen_kwargs})'
+
+
+async def inference_pt_async(request: Union[ChatCompletionRequest,
+                                            CompletionRequest],
+                             raw_request: Request):
+    global model, template
+    error_msg = await check_model(request)
+    if error_msg is not None:
+        return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)
+
+    if request.seed is not None:
+        seed_everything(request.seed, verbose=False)
+    _request = {'model': request.model}
+    if isinstance(request, ChatCompletionRequest):
+        if is_generation_template(template.template_type):
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                f'The chat template `{template.template_type}` corresponding to '
+                f'the model `{model.model_type}` is in text generation format. '
+                'Please use the `completions` API.')
+        example = messages_to_history(request.messages)
+        input_ids = template.encode(example)[0]['input_ids']
+        request_id = f'chatcmpl-{random_uuid()}'
+        _request['messages'] = request.messages
     else:
-        return False
+        if not is_generation_template(template.template_type):
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                f'The chat template `{template.template_type}` corresponding to '
+                f'the model `{model.model_type}` is in chat format. '
+                'Please use the `chat.completions` API.')
+        example = {'query': request.prompt}
+        input_ids = template.encode(example)[0]['input_ids']
+        request_id = f'cmpl-{random_uuid()}'
+        _request['prompt'] = request.prompt
+
+    request_info = {'request_id': request_id}
+    request_info.update(_request)
+
+    error_msg = await check_length(request, input_ids)
+    if error_msg is not None:
+        return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)
+
+    kwargs = {'max_new_tokens': request.max_tokens}
+    # not use: 'n', 'best_of', 'frequency_penalty', 'presence_penalty'
+    for key in ['length_penalty', 'num_beams']:
+        kwargs[key] = getattr(request, key)
+    for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
+        new_value = getattr(request, key)
+        if new_value is None:
+            kwargs[key] = getattr(model.generation_config, key)
+        else:
+            kwargs[key] = new_value
+    if kwargs['temperature'] == 0:
+        kwargs['do_sample'] = False
+        kwargs['temperature'] = 1
+        kwargs['top_p'] = 1
+        kwargs['top_k'] = 50
+    else:
+        kwargs['do_sample'] = True
+
+    generation_config = _GenerationConfig(**kwargs)
+    request_info['generation_config'] = generation_config
+    request_info.update({
+        'seed': request.seed,
+        'stop': request.stop,
+        'stream': request.stream
+    })
+    logger.info(request_info)
+
+    created_time = int(time.time())
+
+    async def _generate_full():
+        generation_info = {}
+        response, _ = inference(
+            model,
+            template,
+            **example,
+            stop_words=request.stop,
+            generation_config=generation_config,
+            generation_info=generation_info)
+        num_prompt_tokens = generation_info['num_prompt_tokens']
+        num_generated_tokens = generation_info['num_generated_tokens']
+        usage_info = UsageInfo(
+            prompt_tokens=num_prompt_tokens,
+            completion_tokens=num_generated_tokens,
+            total_tokens=num_prompt_tokens + num_generated_tokens,
+        )
+        if isinstance(request, ChatCompletionRequest):
+            choices = [
+                ChatCompletionResponseChoice(
+                    index=0,
+                    message=ChatMessage(role='assistant', content=response),
+                    finish_reason=None,
+                )
+            ]
+            response = ChatCompletionResponse(
+                model=request.model,
+                choices=choices,
+                usage=usage_info,
+                id=request_id,
+                created=created_time)
+        else:
+            choices = [
+                CompletionResponseChoice(
+                    index=0,
+                    text=response,
+                    finish_reason=None,
+                )
+            ]
+            response = CompletionResponse(
+                model=request.model,
+                choices=choices,
+                usage=usage_info,
+                id=request_id,
+                created=created_time)
+        return response
+
+    def _generate_stream():
+        generation_info = {}
+        gen = inference_stream(
+            model,
+            template,
+            **example,
+            stop_words=request.stop,
+            generation_config=generation_config,
+            generation_info=generation_info)
+
+        print_idx = 0
+        for response, _ in gen:
+            num_prompt_tokens = generation_info['num_prompt_tokens']
+            num_generated_tokens = generation_info['num_generated_tokens']
+            usage_info = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=num_generated_tokens,
+                total_tokens=num_prompt_tokens + num_generated_tokens,
+            )
+            if isinstance(request, ChatCompletionRequest):
+                delta_text = response[print_idx:]
+                print_idx = len(response)
+                choices = [
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(
+                            role='assistant', content=delta_text),
+                        finish_reason=None)
+                ]
+                resp = ChatCompletionStreamResponse(
+                    model=request.model,
+                    choices=choices,
+                    usage=usage_info,
+                    id=request_id,
+                    created=created_time)
+            else:
+                delta_text = response[print_idx:]
+                print_idx = len(response)
+                choices = [
+                    CompletionResponseStreamChoice(
+                        index=0, text=delta_text, finish_reason=None)
+                ]
+                resp = CompletionStreamResponse(
+                    model=request.model,
+                    choices=choices,
+                    usage=usage_info,
+                    id=request_id,
+                    created=created_time)
+            yield f'data:{json.dumps(asdict(resp), ensure_ascii=False)}\n\n'
+        yield 'data:[DONE]\n\n'
+
+    if request.stream:
+        return StreamingResponse(_generate_stream())
+    else:
+        return await _generate_full()


 @app.post('/v1/chat/completions')
 async def create_chat_completion(
         request: ChatCompletionRequest,
         raw_request: Request) -> ChatCompletionResponse:
-    return await inference_vllm_async(request, raw_request)
+    global _args
+    assert _args is not None
+    if _args.infer_backend == 'vllm':
+        return await inference_vllm_async(request, raw_request)
+    else:
+        return await inference_pt_async(request, raw_request)


 @app.post('/v1/completions')
 async def create_completion(request: CompletionRequest,
                             raw_request: Request) -> CompletionResponse:
-    return await inference_vllm_async(request, raw_request)
+    global _args
+    assert _args is not None
+    if _args.infer_backend == 'vllm':
+        return await inference_vllm_async(request, raw_request)
+    else:
+        return await inference_pt_async(request, raw_request)


 def llm_deploy(args: DeployArguments) -> None:
+    logger_format = logging.Formatter(
+        '%(levelname)s: %(asctime)s %(filename)s:%(lineno)d] %(message)s')
+    logger.handlers[0].setFormatter(logger_format)
     import uvicorn
-    global llm_engine, model, template
+    global llm_engine, model, template, _args
+    _args = args
     if args.merge_lora:
         merge_lora(args, device_map='cpu')
     if args.infer_backend == 'vllm':
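
For reference, a minimal client-side sketch of calling the deployed chat endpoint. The host, port, and model name below are assumptions (they depend on the DeployArguments passed to llm_deploy); the payload fields mirror the ChatCompletionRequest handled above, and the reply is read from the ChatCompletionResponse structure.

import requests

# Assumed endpoint: host and port depend on how the server was launched.
url = 'http://127.0.0.1:8000/v1/chat/completions'
payload = {
    'model': 'qwen-7b-chat',  # hypothetical model_type; check GET /v1/models
    'messages': [{'role': 'user', 'content': 'Hello!'}],
    'max_tokens': 128,
    'temperature': 0.7,
    'seed': 42,
    'stream': False,
}
resp = requests.post(url, json=payload)
resp.raise_for_status()
print(resp.json()['choices'][0]['message']['content'])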
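Likewise, a sketch of consuming the streaming path. _generate_stream frames every chunk as a 'data:{json}' event followed by a blank line and ends with 'data:[DONE]' (note there is no space after 'data:'), so the client strips that prefix per line; the URL and model name are the same assumptions as above.

import json
import requests

url = 'http://127.0.0.1:8000/v1/chat/completions'  # assumed host/port
payload = {
    'model': 'qwen-7b-chat',  # hypothetical model_type
    'messages': [{'role': 'user', 'content': 'Tell me a joke.'}],
    'stream': True,
}
with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue  # skip the blank separator lines between events
        chunk = line[len('data:'):]  # the server writes 'data:' with no space
        if chunk == '[DONE]':
            break
        delta = json.loads(chunk)['choices'][0]['delta']['content']
        print(delta, end='', flush=True)
print()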