import torch
import torch.distributed as dist
import transformers
+from datasets import Dataset as HfDataset
from datasets import concatenate_datasets
from packaging import version
from torch import dtype as Dtype
@@ -76,7 +77,7 @@ class SftArguments:
    dataset_seed: int = 42
    dataset_test_ratio: float = 0.01
    train_dataset_sample: int = 20000  # -1: all dataset
-    train_dataset_mix_ratio: Optional[float] = None
+    train_dataset_mix_ratio: float = 0.
    train_dataset_mix_ds: List[str] = field(
        default_factory=lambda: ['ms-bench'])
    val_dataset_sample: Optional[int] = None  # -1: all dataset
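
With the new default of `0.`, mixing is opt-in: a positive `train_dataset_mix_ratio` requests `len(train_dataset) * ratio` extra samples from `train_dataset_mix_ds`. A minimal sketch of the arithmetic, with illustrative values that are not part of the commit:

    train_len = 20000                            # matches train_dataset_sample above
    ratio = 2.0                                  # e.g. --train_dataset_mix_ratio 2.0
    mix_dataset_sample = int(train_len * ratio)
    assert mix_dataset_sample == 40000           # rows requested from 'ms-bench'
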
@@ -165,7 +166,7 @@ class SftArguments:
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    learning_rate: Optional[float] = None
-    weight_decay: float = 0.01
+    weight_decay: float = 0.1
    gradient_accumulation_steps: Optional[int] = None
    max_grad_norm: float = 0.5
    predict_with_generate: bool = False
@@ -286,6 +287,8 @@ def __post_init__(self) -> None:
        set_model_type(self)
        if isinstance(self.dataset, str):
            self.dataset = [self.dataset]
+        if isinstance(self.train_dataset_mix_ds, str):
+            self.train_dataset_mix_ds = [self.train_dataset_mix_ds]
        register_custom_dataset(self)
        check_flash_attn(self)
        handle_generation_config(self)
@@ -653,6 +656,8 @@ def __post_init__(self) -> None:
        else:
            assert self.load_dataset_config is False, 'You need to first set `--load_args_from_ckpt_dir true`.'
        set_model_type(self)
+        if isinstance(self.dataset, str):
+            self.dataset = [self.dataset]
        register_custom_dataset(self)
        check_flash_attn(self)
        handle_generation_config(self)
@@ -661,8 +666,6 @@ def __post_init__(self) -> None:
        if self.template_type == 'AUTO':
            self.template_type = get_default_template_type(self.model_type)
            logger.info(f'Setting template_type: {self.template_type}')
-        if isinstance(self.dataset, str):
-            self.dataset = [self.dataset]
        has_dataset = (
            len(self.dataset) > 0 or len(self.custom_train_dataset_path) > 0
            or len(self.custom_val_dataset_path) > 0)
@@ -1078,25 +1081,38 @@ def handle_path(args: Union[SftArguments, InferArguments]) -> None:
        setattr(args, k, value)


+def _register_local_dataset(dataset_name: str, train_dataset_path: List[str],
+                            val_dataset_path: List[str]) -> None:
+    register_dataset(
+        dataset_name,
+        '_',
+        train_dataset_path,
+        val_dataset_path,
+        get_function=get_custom_dataset,
+        exists_ok=True)
+
+
def register_custom_dataset(args: Union[SftArguments, InferArguments]) -> None:
+    dataset = []
+    for d in args.dataset:
+        if os.path.exists(d):
+            args.custom_train_dataset_path.append(d)
+        else:
+            dataset.append(d)
+    args.dataset = dataset
+
    for key in ['custom_train_dataset_path', 'custom_val_dataset_path']:
        value = getattr(args, key)
        if isinstance(value, str):
            setattr(args, key, [value])
    if len(args.custom_train_dataset_path) == 0 and len(
            args.custom_val_dataset_path) == 0:
        return
-    register_dataset(
-        '_custom_dataset',
-        '_custom_dataset',
-        args.custom_train_dataset_path,
-        args.custom_val_dataset_path,
-        get_function=get_custom_dataset,
-        exists_ok=True)
-    if args.dataset is None:
-        args.dataset = ['_custom_dataset']
-    elif '_custom_dataset' not in args.dataset:
-        args.dataset.append('_custom_dataset')
+
+    dataset_name = '_custom_dataset'
+    _register_local_dataset(dataset_name, args.custom_train_dataset_path,
+                            args.custom_val_dataset_path)
+    args.dataset.append(dataset_name)


def load_from_ckpt_dir(args: InferArguments) -> None:
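
With this refactor, any entry of `--dataset` that resolves to an existing local path is moved into `custom_train_dataset_path` and served through the shared `_custom_dataset` registration. A minimal usage sketch, assuming `__post_init__` calls `register_custom_dataset` as above; the file path is hypothetical:

    # '/data/train.jsonl' stands in for any local file that exists on disk.
    args = SftArguments(dataset=['ms-bench', '/data/train.jsonl'])
    # After register_custom_dataset(args):
    #   args.dataset                   == ['ms-bench', '_custom_dataset']
    #   args.custom_train_dataset_path == ['/data/train.jsonl']
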
@@ -1147,34 +1163,48 @@ def handle_generation_config(
    )


-def handle_dataset_mixture(args: SftArguments, train_dataset,
-                           mix_dataset_sample) -> None:
+def handle_dataset_mixture(args: SftArguments,
+                           train_dataset: HfDataset) -> Optional[HfDataset]:
    if train_dataset is None:
        return train_dataset
-    train_length = len(train_dataset)
+    if args.train_dataset_mix_ratio <= 0 or len(
+            args.train_dataset_mix_ds) == 0:
+        return train_dataset
+
    random_state = np.random.RandomState(args.dataset_seed)
-    if mix_dataset_sample:
-        assert args.train_dataset_mix_ds is not None
-        train_dataset_mix_ds = [args.train_dataset_mix_ds] if isinstance(
-            args.train_dataset_mix_ds, str) else args.train_dataset_mix_ds
-        mixed_dataset = get_dataset(
-            train_dataset_mix_ds,
-            0.0,
-            random_state,
-            check_dataset_strategy=args.check_dataset_strategy)[0]
-        if len(mixed_dataset) < mix_dataset_sample:
-            logger.warn(
-                f'The length of dataset used for mixin: {train_dataset_mix_ds} are '
-                'lesser than the ratio required by the `train_dataset_mix_ratio` '
-                f'argument:{args.train_dataset_mix_ratio} '
-                f'the actual ratio is : {len(mixed_dataset)/float(train_length)}'
-            )
+    train_dataset_mix_ds = []
+    custom_mix_ds = []
+    for mix_ds in args.train_dataset_mix_ds:
+        if os.path.exists(mix_ds):
+            custom_mix_ds.append(mix_ds)
        else:
-            train_idxs = random_state.permutation(mix_dataset_sample)
-            mixed_dataset = mixed_dataset.select(train_idxs)
-        return concatenate_datasets([train_dataset, mixed_dataset])
+            train_dataset_mix_ds.append(mix_ds)
+
+    if len(custom_mix_ds) > 0:
+        dataset_name = '_custom_mixture'
+        _register_local_dataset(dataset_name, custom_mix_ds, [])
+        train_dataset_mix_ds.append(dataset_name)
+    mix_dataset_sample = int(len(train_dataset) * args.train_dataset_mix_ratio)
+    logger.info(f'train_dataset_mix_ds: {train_dataset_mix_ds}')
+    logger.info(
+        f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}'
+    )
+    mixed_dataset = get_dataset(
+        train_dataset_mix_ds,
+        0.0,
+        random_state,
+        check_dataset_strategy=args.check_dataset_strategy)[0]
+    if len(mixed_dataset) < mix_dataset_sample:
+        logger.warn(
+            f'The length of the datasets used for mixing ({train_dataset_mix_ds}) is '
+            'smaller than required by the `train_dataset_mix_ratio` '
+            f'argument: {args.train_dataset_mix_ratio}. '
+            f'The actual ratio is: {len(mixed_dataset) / len(train_dataset):.6}.'
+        )
    else:
-        return train_dataset
+        train_idxs = random_state.permutation(mix_dataset_sample)
+        mixed_dataset = mixed_dataset.select(train_idxs)
+    return concatenate_datasets([train_dataset, mixed_dataset])


def swift_to_peft_format(lora_checkpoint_path: str) -> str:
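
Taken together, the mixture size is now derived from `train_dataset_mix_ratio` instead of being passed in as a parameter. A minimal call sketch under the new signature; the values and the local file path are illustrative, and `args` is assumed to be a fully initialized `SftArguments`:

    args.train_dataset_mix_ratio = 2.0            # default 0. keeps mixing a no-op
    args.train_dataset_mix_ds = ['ms-bench', '/data/extra.jsonl']  # local file is hypothetical
    train_dataset = handle_dataset_mixture(args, train_dataset)
    # Draws up to mix_dataset_sample rows from the mix datasets and concatenates
    # them; the local file is registered as '_custom_mixture' on the fly.
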