Skip to content

Commit a5f5000

Browse files
RLHF UI (modelscope#1182)
1 parent cc4eda1 commit a5f5000

File tree

7 files changed

+280
-47
lines changed

7 files changed

+280
-47
lines changed

swift/llm/utils/argument.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,8 +1407,8 @@ class RLHFArguments(SftArguments):
14071407
max_prompt_length: int = 1024
14081408
beta: Optional[float] = None
14091409
label_smoothing: float = 0.0
1410-
loss_type: Optional[Literal['sigmoid', 'hinge', 'ipo', 'kto_pair', 'robust', 'bco_pair', 'sppo_hard', 'nca_pair',
1411-
'simpo', 'kto', 'bco']] = None
1410+
loss_type: Literal['sigmoid', 'hinge', 'ipo', 'kto_pair', 'robust', 'bco_pair', 'sppo_hard', 'nca_pair', 'simpo',
1411+
'kto', 'bco'] = None
14121412
sft_beta: float = 0.1
14131413
simpo_gamma: float = 1.0 # reward margin hyperparameter in SimPO
14141414
# KTO

swift/ui/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def wrapper(*args, **kwargs):
5151
if argument and 'label' in kwargs:
5252
kwargs['label'] = kwargs['label'] + f'({argument})'
5353

54+
kwargs['elem_classes'] = 'align'
5455
ret = fn(self, **kwargs)
5556
self.constructor_args.update(kwargs)
5657

swift/ui/llm_train/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
113113
gr.Textbox(elem_id='custom_val_dataset_path', is_list=True, scale=20)
114114
with gr.Row():
115115
gr.Slider(elem_id='dataset_test_ratio', minimum=0.0, maximum=1.0, step=0.05, scale=20)
116-
gr.Slider(elem_id='max_length', minimum=32, maximum=8192, step=32, scale=20)
116+
gr.Slider(elem_id='max_length', minimum=32, maximum=32768, step=32, scale=20)
117117
gr.Textbox(elem_id='train_dataset_sample', scale=20)
118118
gr.Textbox(elem_id='val_dataset_sample', scale=20)
119119
gr.Dropdown(elem_id='truncation_strategy', scale=20)

swift/ui/llm_train/llm_train.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
from gradio import Accordion, Tab
1414

15-
from swift.llm import SftArguments
15+
from swift.llm import RLHFArguments, SftArguments
1616
from swift.ui.base import BaseUI
1717
from swift.ui.llm_train.advanced import Advanced
1818
from swift.ui.llm_train.dataset import Dataset
@@ -23,6 +23,7 @@
2323
from swift.ui.llm_train.lora import LoRA
2424
from swift.ui.llm_train.model import Model
2525
from swift.ui.llm_train.quantization import Quantization
26+
from swift.ui.llm_train.rlhf import RLHF
2627
from swift.ui.llm_train.runtime import Runtime
2728
from swift.ui.llm_train.save import Save
2829
from swift.ui.llm_train.self_cog import SelfCog
@@ -53,6 +54,7 @@ class LLMTrain(BaseUI):
5354
Quantization,
5455
SelfCog,
5556
Advanced,
57+
RLHF,
5658
]
5759

5860
locale_dict: Dict[str, Dict] = {
@@ -62,6 +64,16 @@ class LLMTrain(BaseUI):
6264
'en': 'LLM Training',
6365
}
6466
},
67+
'train_type': {
68+
'label': {
69+
'zh': '训练Stage',
70+
'en': 'Train Stage'
71+
},
72+
'info': {
73+
'zh': '请注意选择于此匹配的数据集,人类对齐配置在页面下方',
74+
'en': 'Please choose matched dataset, RLHF settings is at the bottom of the page'
75+
}
76+
},
6577
'submit_alert': {
6678
'value': {
6779
'zh':
@@ -185,9 +197,9 @@ class LLMTrain(BaseUI):
185197
},
186198
}
187199

188-
choice_dict = BaseUI.get_choices_from_dataclass(SftArguments)
189-
default_dict = BaseUI.get_default_value_from_dataclass(SftArguments)
190-
arguments = BaseUI.get_argument_names(SftArguments)
200+
choice_dict = BaseUI.get_choices_from_dataclass(RLHFArguments)
201+
default_dict = BaseUI.get_default_value_from_dataclass(RLHFArguments)
202+
arguments = BaseUI.get_argument_names(RLHFArguments)
191203

192204
@classmethod
193205
def do_build_ui(cls, base_tab: Type['BaseUI']):
@@ -200,17 +212,19 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
200212
with gr.Blocks():
201213
Model.build_ui(base_tab)
202214
Dataset.build_ui(base_tab)
203-
Hyper.build_ui(base_tab)
204-
Save.build_ui(base_tab)
205-
Runtime.build_ui(base_tab)
206215
with gr.Row():
207-
gr.Dropdown(elem_id='sft_type', scale=4)
208-
gr.Dropdown(elem_id='tuner_backend', scale=4)
209-
gr.Textbox(elem_id='sequence_parallel_size', scale=4)
216+
gr.Dropdown(elem_id='train_type', choices=['pretrain/sft', 'rlhf'], value='pretrain/sft', scale=3)
217+
gr.Dropdown(elem_id='sft_type', scale=2)
218+
gr.Dropdown(elem_id='tuner_backend', scale=2)
219+
gr.Textbox(elem_id='sequence_parallel_size', scale=3)
220+
with gr.Row():
210221
gr.Textbox(elem_id='seed', scale=4)
211222
gr.Dropdown(elem_id='dtype', scale=4)
212223
gr.Checkbox(elem_id='use_ddp', value=False, scale=4)
213224
gr.Textbox(elem_id='ddp_num', value='2', scale=4)
225+
Hyper.build_ui(base_tab)
226+
Save.build_ui(base_tab)
227+
Runtime.build_ui(base_tab)
214228
with gr.Row():
215229
gr.Dropdown(
216230
elem_id='gpu_id',
@@ -230,6 +244,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
230244
Lisa.build_ui(base_tab)
231245
LlamaPro.build_ui(base_tab)
232246
Quantization.build_ui(base_tab)
247+
RLHF.build_ui(base_tab)
233248
SelfCog.build_ui(base_tab)
234249
Advanced.build_ui(base_tab)
235250

@@ -273,26 +288,24 @@ def update_runtime(cls):
273288

274289
@classmethod
275290
def train(cls, *args):
276-
ignore_elements = ('model_type', 'logging_dir', 'more_params')
277-
sft_args = cls.get_default_value_from_dataclass(SftArguments)
291+
ignore_elements = ('model_type', 'logging_dir', 'more_params', 'train_type')
292+
sft_args = cls.get_default_value_from_dataclass(RLHFArguments)
278293
kwargs = {}
279294
kwargs_is_list = {}
280295
other_kwargs = {}
281296
more_params = {}
282297
keys = [key for key, value in cls.elements().items() if not isinstance(value, (Tab, Accordion))]
283298
model_type = None
299+
do_rlhf = False
284300
for key, value in zip(keys, args):
285301
compare_value = sft_args.get(key)
286-
compare_value_arg = str(compare_value) if not isinstance(compare_value, (list, dict)) else compare_value
287-
compare_value_ui = str(value) if not isinstance(value, (list, dict)) else value
288-
289302
if isinstance(value, str) and re.fullmatch(cls.int_regex, value):
290303
value = int(value)
291304
elif isinstance(value, str) and re.fullmatch(cls.float_regex, value):
292305
value = float(value)
293306
elif isinstance(value, str) and re.fullmatch(cls.bool_regex, value):
294307
value = True if value.lower() == 'true' else False
295-
if key not in ignore_elements and key in sft_args and compare_value_ui != compare_value_arg and value:
308+
if key not in ignore_elements and key in sft_args and compare_value != value and value:
296309
kwargs[key] = value if not isinstance(value, list) else ' '.join(value)
297310
kwargs_is_list[key] = isinstance(value, list) or getattr(cls.element(key), 'is_list', False)
298311
else:
@@ -303,14 +316,18 @@ def train(cls, *args):
303316
if key == 'model_type':
304317
model_type = value
305318

319+
if key == 'train_type':
320+
do_rlhf = value == 'rlhf'
321+
306322
if os.path.exists(kwargs['model_id_or_path']):
307323
kwargs['model_type'] = model_type
308324

309325
kwargs.update(more_params)
310326
if 'dataset' not in kwargs and 'custom_train_dataset_path' not in kwargs:
311327
raise gr.Error(cls.locale('dataset_alert', cls.lang)['value'])
312328

313-
sft_args = SftArguments(
329+
cmd = 'rlhf' if do_rlhf else 'sft'
330+
sft_args = RLHFArguments(
314331
**{
315332
key: value.split(' ') if kwargs_is_list.get(key, False) and isinstance(value, str) else value
316333
for key, value in kwargs.items()
@@ -323,7 +340,7 @@ def train(cls, *args):
323340
else:
324341
params += f'--{e} "{kwargs[e]}" '
325342
params += f'--add_output_dir_suffix False --output_dir {sft_args.output_dir} ' \
326-
f'--logging_dir {sft_args.logging_dir} '
343+
f'--logging_dir {sft_args.logging_dir} --ignore_args_error True'
327344
ddp_param = ''
328345
devices = other_kwargs['gpu_id']
329346
devices = [d for d in devices if d]
@@ -344,9 +361,9 @@ def train(cls, *args):
344361
ddp_param = f'set {ddp_param} && '
345362
run_command = f'{cuda_param}{ddp_param}start /b swift sft {params} > {log_file} 2>&1'
346363
elif cls.is_studio:
347-
run_command = f'{cuda_param} {ddp_param} swift sft {params}'
364+
run_command = f'{cuda_param} {ddp_param} swift {cmd} {params}'
348365
else:
349-
run_command = f'{cuda_param} {ddp_param} nohup swift sft {params} > {log_file} 2>&1 &'
366+
run_command = f'{cuda_param} {ddp_param} nohup swift {cmd} {params} > {log_file} 2>&1 &'
350367
logger.info(f'Run training: {run_command}')
351368
return run_command, sft_args, other_kwargs
352369

swift/ui/llm_train/lora.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
8080
with gr.Accordion(elem_id='lora_tab', open=True):
8181
with gr.Blocks():
8282
with gr.Row():
83-
lora_target_modules = gr.Textbox(elem_id='lora_target_modules', lines=1, scale=20, is_list=True)
83+
lora_target_modules = gr.Textbox(elem_id='lora_target_modules', lines=1, scale=5, is_list=True)
84+
gr.Slider(elem_id='lora_rank', value=32, minimum=1, maximum=512, step=8, scale=2)
85+
gr.Slider(elem_id='lora_alpha', value=8, minimum=1, maximum=512, step=8, scale=2)
8486
with gr.Row():
85-
gr.Slider(elem_id='lora_rank', value=32, minimum=1, maximum=512, step=8)
86-
gr.Slider(elem_id='lora_alpha', value=8, minimum=1, maximum=512, step=8)
8787
gr.Dropdown(elem_id='lora_dtype')
8888
gr.Textbox(elem_id='lora_lr_ratio')
8989
gr.Checkbox(elem_id='use_rslora')

swift/ui/llm_train/rlhf.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from typing import Type
2+
3+
import gradio as gr
4+
5+
from swift.llm import MODEL_MAPPING, ModelType
6+
from swift.ui.base import BaseUI
7+
8+
9+
class RLHF(BaseUI):
10+
11+
group = 'llm_train'
12+
13+
locale_dict = {
14+
'rlhf_tab': {
15+
'label': {
16+
'zh': '人类对齐参数设置',
17+
'en': 'RLHF settings'
18+
},
19+
},
20+
'rlhf_type': {
21+
'label': {
22+
'zh': '人类对齐算法类型',
23+
'en': 'RLHF type'
24+
},
25+
},
26+
'ref_model_type': {
27+
'label': {
28+
'zh': '选择ref模型',
29+
'en': 'Select ref model'
30+
},
31+
'info': {
32+
'zh': 'SWIFT已支持的模型名称',
33+
'en': 'Base model supported by SWIFT'
34+
}
35+
},
36+
'ref_model_id_or_path': {
37+
'label': {
38+
'zh': 'ref模型id或路径',
39+
'en': 'Ref model id or path'
40+
},
41+
'info': {
42+
'zh': '实际的模型id或路径',
43+
'en': 'The actual model id or path'
44+
}
45+
},
46+
'max_prompt_length': {
47+
'label': {
48+
'zh': 'prompt最大token长度',
49+
'en': 'Max prompt length'
50+
},
51+
},
52+
'beta': {
53+
'label': {
54+
'zh': 'KL正则项系数',
55+
'en': 'KL regression ratio'
56+
},
57+
},
58+
'loss_type': {
59+
'label': {
60+
'zh': 'Loss类型',
61+
'en': 'Loss type'
62+
},
63+
},
64+
'sft_beta': {
65+
'label': {
66+
'zh': 'DPO中混合sft交叉熵的系数',
67+
'en': 'DPO Cross Entropy ratio'
68+
},
69+
},
70+
'simpo_gamma': {
71+
'label': {
72+
'zh': 'SimPO reward margin',
73+
'en': 'SimPO reward margin'
74+
},
75+
},
76+
'desirable_weight': {
77+
'label': {
78+
'zh': 'KTO符合项系数',
79+
'en': 'KTO desirable ratio'
80+
},
81+
},
82+
'undesirable_weight': {
83+
'label': {
84+
'zh': 'KTO不符合项系数',
85+
'en': 'KTO undesirable ratio'
86+
},
87+
}
88+
}
89+
90+
@classmethod
91+
def do_build_ui(cls, base_tab: Type['BaseUI']):
92+
with gr.Accordion(elem_id='rlhf_tab', open=False):
93+
with gr.Blocks():
94+
with gr.Row():
95+
rlhf_type = gr.Dropdown(elem_id='rlhf_type')
96+
ref_model_type = gr.Dropdown(
97+
elem_id='ref_model_type',
98+
choices=ModelType.get_model_name_list() + cls.get_custom_name_list(),
99+
scale=20)
100+
ref_model_id_or_path = gr.Textbox(elem_id='ref_model_id_or_path', lines=1, scale=20)
101+
model_state = gr.State({})
102+
with gr.Row():
103+
loss_type = gr.Dropdown(elem_id='loss_type')
104+
gr.Textbox(elem_id='max_prompt_length', lines=1, scale=20)
105+
beta = gr.Slider(elem_id='beta', minimum=0., maximum=5.0, step=0.1, scale=20)
106+
gr.Slider(elem_id='sft_beta', minimum=0., maximum=0.95, step=0.05, scale=20)
107+
gr.Slider(elem_id='simpo_gamma', minimum=0., maximum=2.0, step=0.1, scale=20)
108+
gr.Slider(elem_id='desirable_weight', minimum=0., maximum=2.0, step=0.1, scale=20)
109+
gr.Slider(elem_id='undesirable_weight', minimum=0., maximum=2.0, step=0.1, scale=20)
110+
111+
def update_input_model(choice, model_state=None):
112+
if choice is None:
113+
return None
114+
if model_state and choice in model_state:
115+
model_id_or_path = model_state[choice]
116+
else:
117+
model_id_or_path = MODEL_MAPPING[choice]['model_id_or_path']
118+
return model_id_or_path
119+
120+
def update_model_id_or_path(model_type, model_id_or_path, model_state):
121+
if model_type is None or isinstance(model_type, list):
122+
return model_state
123+
model_state[model_type] = model_id_or_path
124+
return model_state
125+
126+
def update_value(rlhf_type):
127+
beta = None
128+
if rlhf_type in ['dpo', 'orpo', 'kto', 'cpo']:
129+
beta = 0.1
130+
elif rlhf_type == 'simpo':
131+
beta = 2.0
132+
133+
loss_type = None
134+
if rlhf_type in ['dpo', 'cpo']:
135+
loss_type = 'sigmoid'
136+
elif rlhf_type == 'kto':
137+
loss_type = 'kto'
138+
139+
return beta, loss_type
140+
141+
rlhf_type.change(update_value, inputs=[rlhf_type], outputs=[beta, loss_type])
142+
143+
ref_model_type.change(
144+
update_input_model, inputs=[ref_model_type, model_state], outputs=[ref_model_id_or_path])
145+
146+
ref_model_id_or_path.change(
147+
update_model_id_or_path,
148+
inputs=[ref_model_type, ref_model_id_or_path, model_state],
149+
outputs=[model_state])

0 commit comments

Comments
 (0)