12
12
import torch
13
13
from gradio import Accordion , Tab
14
14
15
- from swift .llm import SftArguments
15
+ from swift .llm import RLHFArguments , SftArguments
16
16
from swift .ui .base import BaseUI
17
17
from swift .ui .llm_train .advanced import Advanced
18
18
from swift .ui .llm_train .dataset import Dataset
23
23
from swift .ui .llm_train .lora import LoRA
24
24
from swift .ui .llm_train .model import Model
25
25
from swift .ui .llm_train .quantization import Quantization
26
+ from swift .ui .llm_train .rlhf import RLHF
26
27
from swift .ui .llm_train .runtime import Runtime
27
28
from swift .ui .llm_train .save import Save
28
29
from swift .ui .llm_train .self_cog import SelfCog
@@ -53,6 +54,7 @@ class LLMTrain(BaseUI):
53
54
Quantization ,
54
55
SelfCog ,
55
56
Advanced ,
57
+ RLHF ,
56
58
]
57
59
58
60
locale_dict : Dict [str , Dict ] = {
@@ -62,6 +64,16 @@ class LLMTrain(BaseUI):
62
64
'en' : 'LLM Training' ,
63
65
}
64
66
},
67
+ 'train_type' : {
68
+ 'label' : {
69
+ 'zh' : '训练Stage' ,
70
+ 'en' : 'Train Stage'
71
+ },
72
+ 'info' : {
73
+ 'zh' : '请注意选择于此匹配的数据集,人类对齐配置在页面下方' ,
74
+ 'en' : 'Please choose matched dataset, RLHF settings is at the bottom of the page'
75
+ }
76
+ },
65
77
'submit_alert' : {
66
78
'value' : {
67
79
'zh' :
@@ -185,9 +197,9 @@ class LLMTrain(BaseUI):
185
197
},
186
198
}
187
199
188
- choice_dict = BaseUI .get_choices_from_dataclass (SftArguments )
189
- default_dict = BaseUI .get_default_value_from_dataclass (SftArguments )
190
- arguments = BaseUI .get_argument_names (SftArguments )
200
+ choice_dict = BaseUI .get_choices_from_dataclass (RLHFArguments )
201
+ default_dict = BaseUI .get_default_value_from_dataclass (RLHFArguments )
202
+ arguments = BaseUI .get_argument_names (RLHFArguments )
191
203
192
204
@classmethod
193
205
def do_build_ui (cls , base_tab : Type ['BaseUI' ]):
@@ -200,17 +212,19 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
200
212
with gr .Blocks ():
201
213
Model .build_ui (base_tab )
202
214
Dataset .build_ui (base_tab )
203
- Hyper .build_ui (base_tab )
204
- Save .build_ui (base_tab )
205
- Runtime .build_ui (base_tab )
206
215
with gr .Row ():
207
- gr .Dropdown (elem_id = 'sft_type' , scale = 4 )
208
- gr .Dropdown (elem_id = 'tuner_backend' , scale = 4 )
209
- gr .Textbox (elem_id = 'sequence_parallel_size' , scale = 4 )
216
+ gr .Dropdown (elem_id = 'train_type' , choices = ['pretrain/sft' , 'rlhf' ], value = 'pretrain/sft' , scale = 3 )
217
+ gr .Dropdown (elem_id = 'sft_type' , scale = 2 )
218
+ gr .Dropdown (elem_id = 'tuner_backend' , scale = 2 )
219
+ gr .Textbox (elem_id = 'sequence_parallel_size' , scale = 3 )
220
+ with gr .Row ():
210
221
gr .Textbox (elem_id = 'seed' , scale = 4 )
211
222
gr .Dropdown (elem_id = 'dtype' , scale = 4 )
212
223
gr .Checkbox (elem_id = 'use_ddp' , value = False , scale = 4 )
213
224
gr .Textbox (elem_id = 'ddp_num' , value = '2' , scale = 4 )
225
+ Hyper .build_ui (base_tab )
226
+ Save .build_ui (base_tab )
227
+ Runtime .build_ui (base_tab )
214
228
with gr .Row ():
215
229
gr .Dropdown (
216
230
elem_id = 'gpu_id' ,
@@ -230,6 +244,7 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
230
244
Lisa .build_ui (base_tab )
231
245
LlamaPro .build_ui (base_tab )
232
246
Quantization .build_ui (base_tab )
247
+ RLHF .build_ui (base_tab )
233
248
SelfCog .build_ui (base_tab )
234
249
Advanced .build_ui (base_tab )
235
250
@@ -273,26 +288,24 @@ def update_runtime(cls):
273
288
274
289
@classmethod
275
290
def train (cls , * args ):
276
- ignore_elements = ('model_type' , 'logging_dir' , 'more_params' )
277
- sft_args = cls .get_default_value_from_dataclass (SftArguments )
291
+ ignore_elements = ('model_type' , 'logging_dir' , 'more_params' , 'train_type' )
292
+ sft_args = cls .get_default_value_from_dataclass (RLHFArguments )
278
293
kwargs = {}
279
294
kwargs_is_list = {}
280
295
other_kwargs = {}
281
296
more_params = {}
282
297
keys = [key for key , value in cls .elements ().items () if not isinstance (value , (Tab , Accordion ))]
283
298
model_type = None
299
+ do_rlhf = False
284
300
for key , value in zip (keys , args ):
285
301
compare_value = sft_args .get (key )
286
- compare_value_arg = str (compare_value ) if not isinstance (compare_value , (list , dict )) else compare_value
287
- compare_value_ui = str (value ) if not isinstance (value , (list , dict )) else value
288
-
289
302
if isinstance (value , str ) and re .fullmatch (cls .int_regex , value ):
290
303
value = int (value )
291
304
elif isinstance (value , str ) and re .fullmatch (cls .float_regex , value ):
292
305
value = float (value )
293
306
elif isinstance (value , str ) and re .fullmatch (cls .bool_regex , value ):
294
307
value = True if value .lower () == 'true' else False
295
- if key not in ignore_elements and key in sft_args and compare_value_ui != compare_value_arg and value :
308
+ if key not in ignore_elements and key in sft_args and compare_value != value and value :
296
309
kwargs [key ] = value if not isinstance (value , list ) else ' ' .join (value )
297
310
kwargs_is_list [key ] = isinstance (value , list ) or getattr (cls .element (key ), 'is_list' , False )
298
311
else :
@@ -303,14 +316,18 @@ def train(cls, *args):
303
316
if key == 'model_type' :
304
317
model_type = value
305
318
319
+ if key == 'train_type' :
320
+ do_rlhf = value == 'rlhf'
321
+
306
322
if os .path .exists (kwargs ['model_id_or_path' ]):
307
323
kwargs ['model_type' ] = model_type
308
324
309
325
kwargs .update (more_params )
310
326
if 'dataset' not in kwargs and 'custom_train_dataset_path' not in kwargs :
311
327
raise gr .Error (cls .locale ('dataset_alert' , cls .lang )['value' ])
312
328
313
- sft_args = SftArguments (
329
+ cmd = 'rlhf' if do_rlhf else 'sft'
330
+ sft_args = RLHFArguments (
314
331
** {
315
332
key : value .split (' ' ) if kwargs_is_list .get (key , False ) and isinstance (value , str ) else value
316
333
for key , value in kwargs .items ()
@@ -323,7 +340,7 @@ def train(cls, *args):
323
340
else :
324
341
params += f'--{ e } "{ kwargs [e ]} " '
325
342
params += f'--add_output_dir_suffix False --output_dir { sft_args .output_dir } ' \
326
- f'--logging_dir { sft_args .logging_dir } '
343
+ f'--logging_dir { sft_args .logging_dir } --ignore_args_error True '
327
344
ddp_param = ''
328
345
devices = other_kwargs ['gpu_id' ]
329
346
devices = [d for d in devices if d ]
@@ -344,9 +361,9 @@ def train(cls, *args):
344
361
ddp_param = f'set { ddp_param } && '
345
362
run_command = f'{ cuda_param } { ddp_param } start /b swift sft { params } > { log_file } 2>&1'
346
363
elif cls .is_studio :
347
- run_command = f'{ cuda_param } { ddp_param } swift sft { params } '
364
+ run_command = f'{ cuda_param } { ddp_param } swift { cmd } { params } '
348
365
else :
349
- run_command = f'{ cuda_param } { ddp_param } nohup swift sft { params } > { log_file } 2>&1 &'
366
+ run_command = f'{ cuda_param } { ddp_param } nohup swift { cmd } { params } > { log_file } 2>&1 &'
350
367
logger .info (f'Run training: { run_command } ' )
351
368
return run_command , sft_args , other_kwargs
352
369
0 commit comments