
Commit 3f1c753

support torchrun_args for dpo cli and support web_ui model deployment (#496)

1 parent: db24a6f

File tree: 3 files changed, +318 −60 lines

swift/cli/main.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -50,7 +50,7 @@ def cli_main() -> None:
     argv = argv[1:]
     file_path = importlib.util.find_spec(ROUTE_MAPPING[method_name]).origin
     torchrun_args = get_torchrun_args()
-    if torchrun_args is None or method_name != 'sft':
+    if torchrun_args is None or method_name not in ('sft', 'dpo'):
         args = ['python', file_path, *argv]
     else:
         args = ['torchrun', *torchrun_args, file_path, *argv]
```
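For context: `get_torchrun_args` collects distributed-launch settings, and the dispatcher now routes both `sft` and `dpo` through `torchrun` when they are present. A minimal sketch of the same dispatch pattern; the helper below and the file names are simplified assumptions, not the real ms-swift implementation:

```python
import os
import subprocess
import sys


def get_torchrun_args_sketch():
    """Hypothetical stand-in for swift's get_torchrun_args.

    Returns torchrun flags derived from the environment, or None when no
    distributed launch was requested.
    """
    nproc = os.environ.get('NPROC_PER_NODE')
    if nproc is None:
        return None
    return ['--nproc_per_node', nproc]


def dispatch(method_name: str, file_path: str, argv: list) -> int:
    torchrun_args = get_torchrun_args_sketch()
    # Only the training entrypoints ('sft', 'dpo') are launched under
    # torchrun; every other subcommand runs as a plain Python process.
    if torchrun_args is None or method_name not in ('sft', 'dpo'):
        args = ['python', file_path, *argv]
    else:
        args = ['torchrun', *torchrun_args, file_path, *argv]
    return subprocess.run(args).returncode


if __name__ == '__main__':
    # e.g. NPROC_PER_NODE=2 python dispatch_sketch.py ...
    sys.exit(dispatch('dpo', 'dpo_entry.py', sys.argv[1:]))
```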

swift/ui/llm_infer/llm_infer.py

Lines changed: 141 additions & 59 deletions

```diff
@@ -1,5 +1,8 @@
+import collections
 import os
 import re
+import sys
+from subprocess import PIPE, STDOUT, Popen
 from typing import Type
 
 import gradio as gr
@@ -8,17 +11,20 @@
 from gradio import Accordion, Tab
 
 from swift import snapshot_download
-from swift.llm import (InferArguments, inference_stream, limit_history_length,
-                       prepare_model_template)
+from swift.llm import (DeployArguments, InferArguments, XRequestConfig,
+                       inference_client)
 from swift.ui.base import BaseUI
 from swift.ui.llm_infer.model import Model
+from swift.ui.llm_infer.runtime import Runtime
+from swift.utils import get_logger
 
+logger = get_logger()
 
-class LLMInfer(BaseUI):
 
+class LLMInfer(BaseUI):
     group = 'llm_infer'
 
-    sub_ui = [Model]
+    sub_ui = [Model, Runtime]
 
     locale_dict = {
         'generate_alert': {
@@ -92,8 +98,9 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
         gpu_count = torch.cuda.device_count()
         default_device = '0'
         with gr.Blocks():
-            model_and_template = gr.State([])
+            model_and_template_type = gr.State([])
             Model.build_ui(base_tab)
+            Runtime.build_ui(base_tab)
             gr.Dropdown(
                 elem_id='gpu_id',
                 multiselect=True,
```
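A note on the renamed state: Gradio's `gr.State` keeps a per-session value that event callbacks can read and write. It now carries only the `model_type`/`template_type` strings rather than a live model object, since inference moves to a separately deployed process. A standalone sketch of how `gr.State` threads a value through callbacks (hypothetical widgets, not the ms-swift ones):

```python
import gradio as gr


def remember(state, name):
    # Append to the per-session list and echo it back.
    state = state + [name]
    return state, ', '.join(state)


with gr.Blocks() as demo:
    seen = gr.State([])  # per-session value, like model_and_template_type
    box = gr.Textbox(label='name')
    out = gr.Textbox(label='seen so far')
    box.submit(remember, inputs=[seen, box], outputs=[seen, out])

# demo.launch()
```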
```diff
@@ -112,7 +119,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
                 submit.click(
                     cls.generate_chat,
                     inputs=[
-                        model_and_template,
+                        model_and_template_type,
                         cls.element('template_type'), prompt, chatbot,
                         cls.element('max_new_tokens'),
                         cls.element('system')
```
```diff
@@ -121,18 +128,40 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
                     queue=True)
                 clear_history.click(
                     fn=cls.clear_session, inputs=[], outputs=[prompt, chatbot])
-                cls.element('load_checkpoint').click(
-                    cls.reset_memory, [], [model_and_template])\
-                    .then(cls.reset_loading_button, [], [cls.element('load_checkpoint')]).then(
-                    cls.prepare_checkpoint, [
-                        value for value in cls.elements().values()
-                        if not isinstance(value, (Tab, Accordion))
-                    ], [model_and_template]).then(cls.change_interactive, [],
-                                                  [prompt]).then(  # noqa
-                    cls.clear_session,
-                    inputs=[],
-                    outputs=[prompt, chatbot],
-                    queue=True).then(cls.reset_load_button, [], [cls.element('load_checkpoint')])
+
+                if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
+                    cls.element('load_checkpoint').click(
+                        cls.update_runtime, [],
+                        [cls.element('runtime_tab'),
+                         cls.element('log')]).then(
+                             cls.deploy_studio, [
+                                 value for value in cls.elements().values()
+                                 if not isinstance(value, (Tab, Accordion))
+                             ], [cls.element('log')],
+                             queue=True)
+                else:
+                    cls.element('load_checkpoint').click(
+                        cls.reset_memory, [], [model_and_template_type]).then(
+                            cls.reset_loading_button, [],
+                            [cls.element('load_checkpoint')
+                             ]).then(cls.get_model_template_type, [
+                                 value for value in cls.elements().values()
+                                 if not isinstance(value, (Tab, Accordion))
+                             ], [model_and_template_type]).then(
+                                 cls.deploy_local, [
+                                     value
+                                     for value in cls.elements().values()
+                                     if not isinstance(value, (Tab, Accordion))
+                                 ], []).then(
+                                     cls.change_interactive, [],
+                                     [prompt]).then(  # noqa
+                                         cls.clear_session,
+                                         inputs=[],
+                                         outputs=[prompt,
+                                                  chatbot],
+                                         queue=True).then(
+                                             cls.reset_load_button, [],
+                                             [cls.element('load_checkpoint')])
 
     @classmethod
     def reset_load_button(cls):
```
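The branch above relies on Gradio event chaining: each `.click(...)` returns a dependency object whose `.then(...)` schedules the next callback to run after the previous one finishes. A minimal standalone sketch of that mechanism, with hypothetical handlers in place of the ms-swift ones:

```python
import gradio as gr


def disable():
    # Step 1: grey the button out while work is running.
    return gr.update(interactive=False)


def work(count):
    # Step 2: do the actual work; here we just bump a counter.
    return count + 1, f'ran {count + 1} time(s)'


def enable():
    # Step 3: re-enable the button once the chain completes.
    return gr.update(interactive=True)


with gr.Blocks() as demo:
    count = gr.State(0)
    out = gr.Textbox()
    btn = gr.Button('run')
    # .click() returns a dependency; each .then() runs after the
    # preceding step finishes, mirroring the load_checkpoint chain.
    btn.click(disable, [], [btn]) \
       .then(work, [count], [count, out]) \
       .then(enable, [], [btn])

# demo.launch()
```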
```diff
@@ -148,9 +177,46 @@ def reset_memory(cls):
         return []
 
     @classmethod
-    def prepare_checkpoint(cls, *args):
-        torch.cuda.empty_cache()
-        infer_args = cls.get_default_value_from_dataclass(InferArguments)
+    def clear_session(cls):
+        return '', None
+
+    @classmethod
+    def change_interactive(cls):
+        return gr.update(interactive=True)
+
+    @classmethod
+    def generate_chat(cls,
+                      model_and_template_type,
+                      template_type,
+                      prompt: str,
+                      history,
+                      max_new_tokens,
+                      system,
+                      seed=42):
+        model_type = model_and_template_type[0]
+        old_history, history = history, []
+        request_config = XRequestConfig(seed=seed)
+        request_config.stream = True
+        stream_resp_with_history = ''
+        if not template_type.endswith('generation'):
+            stream_resp = inference_client(
+                model_type,
+                prompt,
+                old_history,
+                system=system,
+                request_config=request_config)
+        else:
+            stream_resp = inference_client(
+                model_type, prompt, request_config=request_config)
+        for chunk in stream_resp:
+            stream_resp_with_history += chunk.choices[0].delta.content
+            qr_pair = [prompt, stream_resp_with_history]
+            total_history = old_history + [qr_pair]
+            yield '', total_history
+
+    @classmethod
+    def deploy(cls, *args):
+        deploy_args = cls.get_default_value_from_dataclass(DeployArguments)
         kwargs = {}
         kwargs_is_list = {}
         other_kwargs = {}
```
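Note that `generate_chat` no longer runs the model in-process: it streams from the deployed server via `inference_client`, accumulating OpenAI-style deltas (`chunk.choices[0].delta.content`). A minimal sketch of that accumulation loop against a generic OpenAI-compatible endpoint; using the `openai` package, the URL, and the model name here are illustrative assumptions, not what ms-swift does internally:

```python
from openai import OpenAI

# Hypothetical: a locally deployed server exposing an OpenAI-compatible API.
client = OpenAI(base_url='http://127.0.0.1:8000/v1', api_key='EMPTY')

stream = client.chat.completions.create(
    model='qwen-7b-chat',  # hypothetical model name
    messages=[{'role': 'user', 'content': 'Hello!'}],
    stream=True)

text = ''
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:  # the final chunk may carry None
        text += delta
        print(text)
```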
```diff
@@ -160,12 +226,12 @@ def prepare_checkpoint(cls, *args):
             if not isinstance(value, (Tab, Accordion))
         ]
         for key, value in zip(keys, args):
-            compare_value = infer_args.get(key)
+            compare_value = deploy_args.get(key)
             compare_value_arg = str(compare_value) if not isinstance(
                 compare_value, (list, dict)) else compare_value
             compare_value_ui = str(value) if not isinstance(
                 value, (list, dict)) else value
-            if key in infer_args and compare_value_ui != compare_value_arg and value:
+            if key in deploy_args and compare_value_ui != compare_value_arg and value:
                 if isinstance(value, str) and re.fullmatch(
                         cls.int_regex, value):
                     value = int(value)
```
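The loop above is a defaults-diffing pattern: every UI widget value is compared, as a string, against the dataclass default, and only values the user actually changed become CLI flags. A condensed sketch of the idea, with a hypothetical dataclass standing in for `DeployArguments`:

```python
from dataclasses import dataclass, fields


@dataclass
class Args:  # hypothetical stand-in for DeployArguments
    model_type: str = None
    max_new_tokens: int = 2048
    port: int = 8000


defaults = {f.name: f.default for f in fields(Args)}
ui_values = {'model_type': 'qwen-7b-chat',
             'max_new_tokens': '2048',
             'port': '8001'}

overrides = {
    key: value
    for key, value in ui_values.items()
    # string-compare so '2048' from the UI matches the 2048 default
    if key in defaults and str(value) != str(defaults[key]) and value
}
print(overrides)  # {'model_type': 'qwen-7b-chat', 'port': '8001'}
```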
```diff
@@ -190,50 +256,66 @@ def prepare_checkpoint(cls, *args):
                 'model_id_or_path' in kwargs
                 and not os.path.exists(kwargs['model_id_or_path'])):
             kwargs.pop('model_type', None)
-
+        deploy_args = DeployArguments(
+            **{
+                key: value.split(' ')
+                if key in kwargs_is_list and kwargs_is_list[key] else value
+                for key, value in kwargs.items()
+            })
+        params = ''
+        for e in kwargs:
+            if e in kwargs_is_list and kwargs_is_list[e]:
+                params += f'--{e} {kwargs[e]} '
+            else:
+                params += f'--{e} "{kwargs[e]}" '
         devices = other_kwargs['gpu_id']
         devices = [d for d in devices if d]
         assert (len(devices) == 1 or 'cpu' not in devices)
         gpus = ','.join(devices)
+        cuda_param = ''
         if gpus != 'cpu':
-            os.environ['CUDA_VISIBLE_DEVICES'] = gpus
-        infer_args = InferArguments(**kwargs)
-        model, template = prepare_model_template(infer_args)
-        gr.Info(cls.locale('loaded_alert', cls.lang)['value'])
-        return [model, template]
+            cuda_param = f'CUDA_VISIBLE_DEVICES={gpus}'
+
+        log_file = os.path.join(os.getcwd(), 'run_deploy.log')
+        if sys.platform == 'win32':
+            if cuda_param:
+                cuda_param = f'set {cuda_param} && '
+            run_command = f'{cuda_param}start /b swift deploy {params} > {log_file} 2>&1'
+        elif os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
+            run_command = f'{cuda_param} swift deploy {params}'
+        else:
+            run_command = f'{cuda_param} nohup swift deploy {params} > {log_file} 2>&1 &'
+        return run_command, deploy_args
 
     @classmethod
-    def clear_session(cls):
-        return '', None
+    def deploy_studio(cls, *args):
+        run_command, deploy_args = cls.deploy(*args)
+        if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
+            lines = collections.deque(
+                maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
+            logger.info(f'Run deploying: {run_command}')
+            process = Popen(
+                run_command, shell=True, stdout=PIPE, stderr=STDOUT)
+            with process.stdout:
+                for line in iter(process.stdout.readline, b''):
+                    line = line.decode('utf-8')
+                    lines.append(line)
+                    yield '\n'.join(lines)
 
     @classmethod
-    def change_interactive(cls):
-        return gr.update(interactive=True)
+    def deploy_local(cls, *args):
+        run_command, deploy_args = cls.deploy(*args)
+        lines = collections.deque(
+            maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
+        logger.info(f'Run deploying: {run_command}')
+        process = Popen(run_command, shell=True, stdout=PIPE, stderr=STDOUT)
+        with process.stdout:
+            for line in iter(process.stdout.readline, b''):
+                line = line.decode('utf-8')
+                lines.append(line)
+                yield '\n'.join(lines)
 
     @classmethod
-    def generate_chat(cls, model_and_template, template_type, prompt: str,
-                      history, max_new_tokens, system):
-        if not model_and_template:
-            gr.Warning(cls.locale('generate_alert', cls.lang)['value'])
-            return '', None
-        model, template = model_and_template
-        if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
-            model.cuda()
-        if not template_type.endswith('generation'):
-            old_history, history = limit_history_length(
-                template, prompt, history, int(max_new_tokens))
-        else:
-            old_history = []
-            history = []
-        gen = inference_stream(
-            model,
-            template,
-            prompt,
-            history,
-            system=system,
-            stop_words=['Observation:'])
-        for _, history in gen:
-            total_history = old_history + history
-            yield '', total_history
-        if os.environ.get('MODELSCOPE_ENVIRONMENT') == 'studio':
-            model.cpu()
+    def get_model_template_type(cls, *args):
+        run_command, deploy_args = cls.deploy(*args)
+        return [deploy_args.model_type, deploy_args.template_type]
```
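`deploy_local` and `deploy_studio` launch `swift deploy` as a background shell command (e.g. something like `CUDA_VISIBLE_DEVICES=0 nohup swift deploy --model_type ... > run_deploy.log 2>&1 &`, with flags depending on the UI values) and then tail its output, keeping only the most recent `MAX_LOG_LINES` lines in a bounded `collections.deque`. A minimal sketch of that tail-the-subprocess pattern, using a harmless stand-in command instead of `swift deploy`:

```python
import collections
import os
from subprocess import PIPE, STDOUT, Popen


def tail_command(cmd: str):
    # Bounded buffer: old lines fall off the left once maxlen is reached,
    # so a UI log pane only ever shows the most recent output.
    lines = collections.deque(maxlen=int(os.environ.get('MAX_LOG_LINES', 50)))
    process = Popen(cmd, shell=True, stdout=PIPE, stderr=STDOUT)
    with process.stdout:
        # iter(readline, b'') stops at EOF, i.e. when the process exits.
        for raw in iter(process.stdout.readline, b''):
            lines.append(raw.decode('utf-8', errors='replace'))
            yield ''.join(lines)


# Stand-in for 'swift deploy ...'; any chatty command works.
for snapshot in tail_command('echo starting && echo ready'):
    print(snapshot, end='')
```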
