
Commit f08ff93

Support llm pt deploy (modelscope#467)
1 parent 3843e78 commit f08ff93

File tree: 6 files changed, +255 -27 lines

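This commit wires a native PyTorch ('pt') inference backend into the OpenAI-compatible deploy server, alongside the existing vLLM path: `llm_deploy` now dispatches `/v1/chat/completions` and `/v1/completions` to `inference_pt_async` whenever `infer_backend` is not `vllm`. A minimal launch sketch, assuming the module paths shown in the diff below and an example `model_type` that is not taken from this commit:

# Hedged sketch: start the deploy server with the new PyTorch backend.
# `llm_deploy` and `DeployArguments` come from the files changed below;
# the concrete model_type and any other required fields are assumptions.
from swift.llm.deploy import llm_deploy
from swift.llm.utils import DeployArguments

args = DeployArguments(
    model_type='qwen-7b-chat',  # assumed example model_type
    infer_backend='pt',         # new in this commit: PyTorch backend instead of vLLM
)
llm_deploy(args)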

swift/llm/deploy.py

Lines changed: 241 additions & 18 deletions
@@ -1,14 +1,18 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import inspect
+import logging
 import time
 from dataclasses import asdict
 from http import HTTPStatus
+from types import MethodType
 from typing import List, Optional, Union

 import json
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, StreamingResponse
+from modelscope import GenerationConfig

-from swift.utils import get_main, seed_everything
+from swift.utils import get_logger, get_main, seed_everything
 from .infer import merge_lora, prepare_model_template
 from .utils import ChatCompletionResponse  # noqa
 from .utils import (ChatCompletionRequest, ChatCompletionResponseChoice,
@@ -18,10 +22,13 @@
                     CompletionResponseChoice, CompletionResponseStreamChoice,
                     CompletionStreamResponse, DeltaMessage, DeployArguments,
                     Model, ModelList, UsageInfo, VllmGenerationConfig,
-                    messages_to_history, prepare_vllm_engine_template,
-                    random_uuid)
+                    inference, inference_stream, messages_to_history,
+                    prepare_vllm_engine_template, random_uuid)
+
+logger = get_logger()

 app = FastAPI()
+_args = None
 model = None
 llm_engine = None
 template = None
@@ -35,15 +42,18 @@ def create_error_response(status_code: Union[int, str, HTTPStatus],

 @app.get('/v1/models')
 async def get_available_models():
-    global llm_engine
-    return ModelList(data=[Model(id=llm_engine.model_type)])
+    global _args
+    return ModelList(data=[Model(id=_args.model_type)])


 async def check_length(request: Union[ChatCompletionRequest,
                                       CompletionRequest],
                        input_ids: List[int]) -> Optional[str]:
-    global llm_engine
-    max_model_len = llm_engine.model_config.max_model_len
+    global llm_engine, model, _args
+    if _args.infer_backend == 'vllm':
+        max_model_len = llm_engine.model_config.max_model_len
+    else:
+        max_model_len = model.max_model_len
     num_tokens = len(input_ids)
     max_tokens = request.max_tokens
     if max_tokens is None:
@@ -69,6 +79,13 @@ async def check_model(
     return f'`{request.model}` is not in the model_list: `{model_type_list}`.'


+def is_generation_template(template_type: str) -> bool:
+    if 'generation' in template_type:
+        return True
+    else:
+        return False
+
+
 async def inference_vllm_async(request: Union[ChatCompletionRequest,
                                               CompletionRequest],
                                raw_request: Request):
@@ -78,7 +95,8 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest,
         return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)

     if request.seed is not None:
-        seed_everything(request.seed)
+        seed_everything(request.seed, verbose=False)
+    _request = {'model': request.model}
     if isinstance(request, ChatCompletionRequest):
         if is_generation_template(template.template_type):
             return create_error_response(
@@ -89,19 +107,26 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest,
         example = messages_to_history(request.messages)
         input_ids = template.encode(example)[0]['input_ids']
         request_id = f'chatcmpl-{random_uuid()}'
+        _request['messages'] = request.messages
     else:
         if not is_generation_template(template.template_type):
             return create_error_response(
                 HTTPStatus.BAD_REQUEST,
                 f'The chat template `{template.template_type}` corresponding to '
                 f'the model `{llm_engine.model_type}` is in chat format. '
                 'Please use the `chat.completions` API.')
-        input_ids = template.encode({'query': request.prompt})[0]['input_ids']
+        example = {'query': request.prompt}
+        input_ids = template.encode(example)[0]['input_ids']
         request_id = f'cmpl-{random_uuid()}'
+        _request['prompt'] = request.prompt
+
+    request_info = {'request_id': request_id}
+    request_info.update(_request)

     error_msg = await check_length(request, input_ids)
     if error_msg is not None:
         return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)
+
     kwargs = {'max_new_tokens': request.max_tokens}
     for key in [
             'n', 'stop', 'best_of', 'frequency_penalty', 'length_penalty',
@@ -114,6 +139,7 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest,
             kwargs[key] = getattr(llm_engine.generation_config, key)
         else:
             kwargs[key] = new_value
+
     generation_config = VllmGenerationConfig(**kwargs)
     if generation_config.use_beam_search is True and request.stream is True:
         error_msg = 'Streaming generation does not support beam search.'
@@ -124,6 +150,10 @@ async def inference_vllm_async(request: Union[ChatCompletionRequest,
     if isinstance(template.suffix[-1],
                   str) and template.suffix[-1] not in generation_config.stop:
         generation_config.stop.append(template.suffix[-1])
+    request_info['generation_config'] = generation_config
+    request_info.update({'seed': request.seed, 'stream': request.stream})
+    logger.info(request_info)
+
     created_time = int(time.time())
     result_generator = llm_engine.generate(None, generation_config, request_id,
                                            input_ids)
@@ -153,7 +183,7 @@ async def _generate_full():
                     finish_reason=output.finish_reason,
                 )
                 choices.append(choice)
-            return ChatCompletionResponse(
+            response = ChatCompletionResponse(
                 model=request.model,
                 choices=choices,
                 usage=usage_info,
@@ -168,12 +198,13 @@ async def _generate_full():
                     finish_reason=output.finish_reason,
                 )
                 choices.append(choice)
-            return CompletionResponse(
+            response = CompletionResponse(
                 model=request.model,
                 choices=choices,
                 usage=usage_info,
                 id=request_id,
                 created=created_time)
+        return response

     async def _generate_stream():
         print_idx_list = [0] * request.n
@@ -228,29 +259,221 @@ async def _generate_stream():
         return await _generate_full()


-def is_generation_template(template_type: str) -> bool:
-    if 'generation' in template_type:
-        return True
+class _GenerationConfig(GenerationConfig):
+
+    def __repr__(self) -> str:
+        parameters = inspect.signature(self.to_json_string).parameters
+        kwargs = {}
+        if 'ignore_metadata' in parameters:
+            kwargs['ignore_metadata'] = True
+        gen_kwargs = json.loads(self.to_json_string(**kwargs))
+        gen_kwargs.pop('transformers_version', None)
+        return f'GenerationConfig({gen_kwargs})'
+
+
+async def inference_pt_async(request: Union[ChatCompletionRequest,
+                                            CompletionRequest],
+                             raw_request: Request):
+    global model, template
+    error_msg = await check_model(request)
+    if error_msg is not None:
+        return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)
+
+    if request.seed is not None:
+        seed_everything(request.seed, verbose=False)
+    _request = {'model': request.model}
+    if isinstance(request, ChatCompletionRequest):
+        if is_generation_template(template.template_type):
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                f'The chat template `{template.template_type}` corresponding to '
+                f'the model `{model.model_type}` is in text generation format. '
+                'Please use the `completions` API.')
+        example = messages_to_history(request.messages)
+        input_ids = template.encode(example)[0]['input_ids']
+        request_id = f'chatcmpl-{random_uuid()}'
+        _request['messages'] = request.messages
     else:
-        return False
+        if not is_generation_template(template.template_type):
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                f'The chat template `{template.template_type}` corresponding to '
+                f'the model `{model.model_type}` is in chat format. '
+                'Please use the `chat.completions` API.')
+        example = {'query': request.prompt}
+        input_ids = template.encode(example)[0]['input_ids']
+        request_id = f'cmpl-{random_uuid()}'
+        _request['prompt'] = request.prompt
+
+    request_info = {'request_id': request_id}
+    request_info.update(_request)
+
+    error_msg = await check_length(request, input_ids)
+    if error_msg is not None:
+        return create_error_response(HTTPStatus.BAD_REQUEST, error_msg)
+
+    kwargs = {'max_new_tokens': request.max_tokens}
+    # not use: 'n', 'best_of', 'frequency_penalty', 'presence_penalty'
+    for key in ['length_penalty', 'num_beams']:
+        kwargs[key] = getattr(request, key)
+    for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
+        new_value = getattr(request, key)
+        if new_value is None:
+            kwargs[key] = getattr(model.generation_config, key)
+        else:
+            kwargs[key] = new_value
+    if kwargs['temperature'] == 0:
+        kwargs['do_sample'] = False
+        kwargs['temperature'] = 1
+        kwargs['top_p'] = 1
+        kwargs['top_k'] = 50
+    else:
+        kwargs['do_sample'] = True
+
+    generation_config = _GenerationConfig(**kwargs)
+    request_info['generation_config'] = generation_config
+    request_info.update({
+        'seed': request.seed,
+        'stop': request.stop,
+        'stream': request.stream
+    })
+    logger.info(request_info)
+
+    created_time = int(time.time())
+
+    async def _generate_full():
+        generation_info = {}
+        response, _ = inference(
+            model,
+            template,
+            **example,
+            stop_words=request.stop,
+            generation_config=generation_config,
+            generation_info=generation_info)
+        num_prompt_tokens = generation_info['num_prompt_tokens']
+        num_generated_tokens = generation_info['num_generated_tokens']
+        usage_info = UsageInfo(
+            prompt_tokens=num_prompt_tokens,
+            completion_tokens=num_generated_tokens,
+            total_tokens=num_prompt_tokens + num_generated_tokens,
+        )
+        if isinstance(request, ChatCompletionRequest):
+            choices = [
+                ChatCompletionResponseChoice(
+                    index=0,
+                    message=ChatMessage(role='assistant', content=response),
+                    finish_reason=None,
+                )
+            ]
+            response = ChatCompletionResponse(
+                model=request.model,
+                choices=choices,
+                usage=usage_info,
+                id=request_id,
+                created=created_time)
+        else:
+            choices = [
+                CompletionResponseChoice(
+                    index=0,
+                    text=response,
+                    finish_reason=None,
+                )
+            ]
+            response = CompletionResponse(
+                model=request.model,
+                choices=choices,
+                usage=usage_info,
+                id=request_id,
+                created=created_time)
+        return response
+
+    def _generate_stream():
+        generation_info = {}
+        gen = inference_stream(
+            model,
+            template,
+            **example,
+            stop_words=request.stop,
+            generation_config=generation_config,
+            generation_info=generation_info)
+
+        print_idx = 0
+        for response, _ in gen:
+            num_prompt_tokens = generation_info['num_prompt_tokens']
+            num_generated_tokens = generation_info['num_generated_tokens']
+            usage_info = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=num_generated_tokens,
+                total_tokens=num_prompt_tokens + num_generated_tokens,
+            )
+            if isinstance(request, ChatCompletionRequest):
+                delta_text = response[print_idx:]
+                print_idx = len(response)
+                choices = [
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(
+                            role='assistant', content=delta_text),
+                        finish_reason=None)
+                ]
+                resp = ChatCompletionStreamResponse(
+                    model=request.model,
+                    choices=choices,
+                    usage=usage_info,
+                    id=request_id,
+                    created=created_time)
+            else:
+                delta_text = response[print_idx:]
+                print_idx = len(response)
+                choices = [
+                    CompletionResponseStreamChoice(
+                        index=0, text=delta_text, finish_reason=None)
+                ]
+                resp = CompletionStreamResponse(
+                    model=request.model,
+                    choices=choices,
+                    usage=usage_info,
+                    id=request_id,
+                    created=created_time)
+            yield f'data:{json.dumps(asdict(resp), ensure_ascii=False)}\n\n'
+        yield 'data:[DONE]\n\n'
+
+    if request.stream:
+        return StreamingResponse(_generate_stream())
+    else:
+        return await _generate_full()


 @app.post('/v1/chat/completions')
 async def create_chat_completion(
         request: ChatCompletionRequest,
         raw_request: Request) -> ChatCompletionResponse:
-    return await inference_vllm_async(request, raw_request)
+    global _args
+    assert _args is not None
+    if _args.infer_backend == 'vllm':
+        return await inference_vllm_async(request, raw_request)
+    else:
+        return await inference_pt_async(request, raw_request)


 @app.post('/v1/completions')
 async def create_completion(request: CompletionRequest,
                             raw_request: Request) -> CompletionResponse:
-    return await inference_vllm_async(request, raw_request)
+    global _args
+    assert _args is not None
+    if _args.infer_backend == 'vllm':
+        return await inference_vllm_async(request, raw_request)
+    else:
+        return await inference_pt_async(request, raw_request)


 def llm_deploy(args: DeployArguments) -> None:
+    logger_format = logging.Formatter(
+        '%(levelname)s: %(asctime)s %(filename)s:%(lineno)d] %(message)s')
+    logger.handlers[0].setFormatter(logger_format)
     import uvicorn
-    global llm_engine, model, template
+    global llm_engine, model, template, _args
+    _args = args
     if args.merge_lora:
         merge_lora(args, device_map='cpu')
     if args.infer_backend == 'vllm':
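
For reference, a client-side sketch of how the two endpoints wired above can be exercised once the server is running. The host and port are assumptions; the request fields (`model`, `messages`, `max_tokens`, `stream`, `seed`) and the SSE framing (`data:{...}` chunks followed by `data:[DONE]`) follow the request/response types used in this diff:

import json

import requests  # assumed to be available; any HTTP client works

BASE_URL = 'http://127.0.0.1:8000'  # assumed host/port of the deploy server

# Non-streaming chat completion (the server picks the pt or vllm backend from _args).
resp = requests.post(
    f'{BASE_URL}/v1/chat/completions',
    json={
        'model': 'qwen-7b-chat',  # must match the deployed model_type
        'messages': [{'role': 'user', 'content': 'Hello!'}],
        'max_tokens': 128,
        'seed': 42,
    })
print(resp.json()['choices'][0]['message']['content'])

# Streaming: the server yields 'data:{json}' chunks and finally 'data:[DONE]'.
with requests.post(
        f'{BASE_URL}/v1/chat/completions',
        json={
            'model': 'qwen-7b-chat',
            'messages': [{'role': 'user', 'content': 'Hello!'}],
            'stream': True,
        },
        stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith('data:'):
            continue
        payload = line[len('data:'):]
        if payload == '[DONE]':
            break
        chunk = json.loads(payload)
        print(chunk['choices'][0]['delta']['content'], end='', flush=True)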

swift/llm/utils/argument.py

Lines changed: 5 additions & 4 deletions
@@ -585,10 +585,11 @@ class DeployArguments(InferArguments):
     ssl_certfile: Optional[str] = None

     def __post_init__(self):
-        assert self.infer_backend != 'pt', 'The deployment only supports VLLM currently.'
-        if self.infer_backend == 'AUTO':
-            self.infer_backend = 'vllm'
-            logger.info('Setting self.infer_backend: vllm')
+        model_info = MODEL_MAPPING[self.model_type]
+        tags = model_info.get('tags', [])
+        if 'multi-modal' in tags:
+            raise ValueError(
+                'Deployment of multimodal models is currently not supported.')
         super().__post_init__()

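The `__post_init__` change above drops the hard block on `infer_backend='pt'` and instead rejects only multimodal models. A behavior sketch with hypothetical registry entries (the real MODEL_MAPPING and its tags live in swift's model registry, not in this diff):

# Hypothetical entries for illustration only.
MODEL_MAPPING = {
    'qwen-7b-chat': {'tags': ['chat']},
    'qwen-vl-chat': {'tags': ['multi-modal', 'vision']},
}

def check_deployable(model_type: str) -> None:
    tags = MODEL_MAPPING[model_type].get('tags', [])
    if 'multi-modal' in tags:
        raise ValueError(
            'Deployment of multimodal models is currently not supported.')

check_deployable('qwen-7b-chat')    # ok: deployable with either backend
# check_deployable('qwen-vl-chat')  # would raise ValueError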
594595
