
Commit 7e1834e

support train_dataset_mix_ds using custom_local_path (#582)
1 parent: fc804dc
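In short, after this commit `train_dataset_mix_ds` accepts local file paths alongside registered dataset names; local entries are registered on the fly as a custom mixture dataset. A minimal sketch of the intended usage, assuming the dataset name and the local JSONL path shown here are illustrative placeholders:

```python
from swift.llm import SftArguments, sft_main

# Mix 2x the training-set size of general-knowledge data into training.
# 'ms-bench' is a registered dataset name; the .jsonl entry is a local
# file, which this commit now registers on the fly ('_custom_mixture').
sft_args = SftArguments(
    model_type='qwen-7b-chat',
    dataset=['blossom-math-zh'],  # illustrative dataset choice
    train_dataset_mix_ds=['ms-bench', '/path/to/general_knowledge.jsonl'],
    train_dataset_mix_ratio=2.0)
output = sft_main(sft_args)
```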

File tree

4 files changed: +80 −53 lines

docs/source/LLM/命令行参数.md
swift/llm/sft.py
swift/llm/utils/argument.py
tests/llm/test_run.py

docs/source/LLM/命令行参数.md

Lines changed: 4 additions & 4 deletions

@@ -58,11 +58,11 @@
 - `--deepspeed`: Specifies the path to a deepspeed configuration file, or the configuration passed directly in JSON format. Defaults to `None`, i.e. deepspeed is not enabled. deepspeed can reduce GPU memory usage. We provide a default [ZeRO-2 configuration file](https://github.com/modelscope/swift/blob/main/swift/llm/ds_config/zero2.json) and [ZeRO-3 configuration file](https://github.com/modelscope/swift/blob/main/swift/llm/ds_config/zero3.json): simply pass 'default-zero2' to use the default ZeRO-2 configuration, or 'default-zero3' to use the default ZeRO-3 configuration.
 - `--batch_size`: Batch size during training. Defaults to `1`. Increasing the batch size improves GPU utilization but does not necessarily speed up training, because within a batch shorter sentences are padded to the length of the longest sentence in that batch, introducing wasted computation.
 - `--eval_batch_size`: Batch size during evaluation. Defaults to `None`, i.e. set to 1 when `predict_with_generate` is True, and to `batch_size` when it is False.
-- `--num_train_epochs`: Number of training epochs. Defaults to `1`. If `max_steps >= 0`, it overrides `num_train_epochs`. Usually set to 3 ~ 5.
+- `--num_train_epochs`: Number of training epochs. Defaults to `1`. If `max_steps >= 0`, it overrides `num_train_epochs`. You can set it to 3, 5, 10, etc.
 - `--max_steps`: Maximum number of training steps. Defaults to `-1`. If `max_steps >= 0`, it overrides `num_train_epochs`.
 - `--optim`: Defaults to `'adamw_torch'`.
 - `--learning_rate`: Defaults to `None`, i.e. set to 1e-4 if `sft_type` is lora, and to 1e-5 if `sft_type` is full.
-- `--weight_decay`: Defaults to `0.01`. `0.1` or `0.01` is recommended.
+- `--weight_decay`: Defaults to `0.1`.
 - `--gradient_accumulation_steps`: Gradient accumulation. Defaults to `None`, i.e. set to `math.ceil(16 / self.batch_size / world_size)`. `total_batch_size = batch_size * gradient_accumulation_steps * world_size`.
 - `--max_grad_norm`: Gradient clipping. Defaults to `0.5`.
 - `--predict_with_generate`: Whether to evaluate with generation. Defaults to `False`. If False, evaluation uses `loss`; if True, it uses metrics such as `ROUGE-L`. Generative evaluation is very time-consuming, so choose carefully.
@@ -100,8 +100,8 @@
 - `--repetition_penalty`: Defaults to `1.`. Only takes effect when `predict_with_generate` is True.
 - `--num_beams`: Defaults to `1`. Only takes effect when `predict_with_generate` is True.
 - `--gpu_memory_fraction`: Defaults to `None`. Runs training with the maximum usable fraction of GPU memory capped at the given value; intended for stress testing.
-- `--train_dataset_mix_ratio`: Defaults to `0`. Defines how datasets are mixed for training. When this argument is specified, the training set is mixed with the general-knowledge dataset specified by `train_dataset_mix_ds` at a multiple of `train_dataset_mix_ratio`, so that the overall dataset length reaches `train_dataset_sample`.
-- `--train_dataset_mix_ds`: Defaults to `ms-bench`. A general-knowledge dataset used to prevent knowledge forgetting.
+- `--train_dataset_mix_ratio`: Defaults to `0.`. Defines how datasets are mixed for training. When this argument is specified, `train_dataset_mix_ratio` times the size of the training set is sampled from the general-knowledge dataset(s) specified by `train_dataset_mix_ds` and mixed in.
+- `--train_dataset_mix_ds`: Defaults to `['ms-bench']`. General-knowledge dataset(s) used to prevent knowledge forgetting.
 - `--use_loss_scale`: Defaults to `False`. When enabled, the loss weight of some Agent fields (the Action/Action Input parts) is increased to strengthen CoT; it has no effect in ordinary SFT scenarios.

 ### LoRA+ Fine-tuning Parameters
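The documentation change above reflects the new mixing semantics: the mixed-in sample count is now derived from the training set's size rather than carved out of `train_dataset_sample`. A small worked example of the arithmetic, with illustrative sizes:

```python
# New semantics: mix in ratio * len(train_dataset) general-knowledge samples.
train_dataset_len = 1000
train_dataset_mix_ratio = 2.0

mix_dataset_sample = train_dataset_len * train_dataset_mix_ratio  # 2000.0
total_len = train_dataset_len + mix_dataset_sample                # 3000.0

# Old semantics (before this commit): the mixture was carved out of the
# total sample budget, so the overall length stayed at that cap:
total_dataset_sample = 3000  # min(train_dataset_sample, len(train_dataset))
old_train_part = round(
    1. / (1 + train_dataset_mix_ratio) * total_dataset_sample)  # 1000
old_mix_part = total_dataset_sample - old_train_part             # 2000
```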

swift/llm/sft.py

Lines changed: 3 additions & 11 deletions

@@ -40,7 +40,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
                 f'world_size: {world_size}, local_world_size: {local_world_size}')
     seed_everything(args.seed)

-    if args.gpu_memory_fraction:
+    if args.gpu_memory_fraction is not None:
         for device_id in range(torch.cuda.device_count()):
             torch.cuda.set_per_process_memory_fraction(
                 max(min(args.gpu_memory_fraction, 1.0), 0.01),
@@ -116,16 +116,9 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         random_state,
         check_dataset_strategy=args.check_dataset_strategy)
     val_dataset_sample = args.val_dataset_sample
-    mix_dataset_sample = 0 if not args.train_dataset_mix_ratio else round(
-        len(train_dataset) * args.train_dataset_mix_ratio)
     if train_dataset is not None and args.train_dataset_sample >= 0:
-        total_dataset_sample = min(args.train_dataset_sample,
+        train_dataset_sample = min(args.train_dataset_sample,
                                    train_dataset.shape[0])
-        train_dataset_sample = total_dataset_sample
-        if args.train_dataset_mix_ratio:
-            train_dataset_sample = round(
-                1. / (1 + args.train_dataset_mix_ratio) * total_dataset_sample)
-            mix_dataset_sample = total_dataset_sample - train_dataset_sample
         if train_dataset.shape[0] > train_dataset_sample:
             logger.info(f'train_dataset_sample: {train_dataset_sample}')
             train_idxs = random_state.permutation(train_dataset_sample)
@@ -139,8 +132,7 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
         val_idxs = random_state.permutation(val_dataset_sample)
         val_dataset = val_dataset.select(val_idxs)

-    train_dataset = handle_dataset_mixture(args, train_dataset,
-                                           mix_dataset_sample)
+    train_dataset = handle_dataset_mixture(args, train_dataset)

     # add self-cognition dataset
     if args.self_cognition_sample > 0:
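Besides moving the mixture sizing into `handle_dataset_mixture`, the first hunk replaces a truthiness test with an explicit `None` check: under the old guard, a configured fraction that happened to be falsy (e.g. `0.0`) would silently skip `set_per_process_memory_fraction`. A minimal standalone illustration of the difference (the function names here are hypothetical, not swift code):

```python
def should_cap(gpu_memory_fraction):
    # Old guard: falsy values such as 0.0 are skipped along with None.
    return bool(gpu_memory_fraction)

def should_cap_fixed(gpu_memory_fraction):
    # New guard: only an unset (None) value is skipped; a 0.0 fraction
    # still reaches the clamp max(min(fraction, 1.0), 0.01).
    return gpu_memory_fraction is not None

assert should_cap(0.0) is False       # bug: a cap was requested, none applied
assert should_cap_fixed(0.0) is True  # fixed: the cap path runs
```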

swift/llm/utils/argument.py

Lines changed: 68 additions & 38 deletions

@@ -10,6 +10,7 @@
 import torch
 import torch.distributed as dist
 import transformers
+from datasets import Dataset as HfDataset
 from datasets import concatenate_datasets
 from packaging import version
 from torch import dtype as Dtype
@@ -76,7 +77,7 @@ class SftArguments:
     dataset_seed: int = 42
     dataset_test_ratio: float = 0.01
     train_dataset_sample: int = 20000  # -1: all dataset
-    train_dataset_mix_ratio: Optional[float] = None
+    train_dataset_mix_ratio: float = 0.
     train_dataset_mix_ds: List[str] = field(
         default_factory=lambda: ['ms-bench'])
     val_dataset_sample: Optional[int] = None  # -1: all dataset
@@ -165,7 +166,7 @@ class SftArguments:
     adam_beta1: float = 0.9
     adam_beta2: float = 0.999
     learning_rate: Optional[float] = None
-    weight_decay: float = 0.01
+    weight_decay: float = 0.1
     gradient_accumulation_steps: Optional[int] = None
     max_grad_norm: float = 0.5
     predict_with_generate: bool = False
@@ -286,6 +287,8 @@ def __post_init__(self) -> None:
         set_model_type(self)
         if isinstance(self.dataset, str):
             self.dataset = [self.dataset]
+        if isinstance(self.train_dataset_mix_ds, str):
+            self.train_dataset_mix_ds = [self.train_dataset_mix_ds]
         register_custom_dataset(self)
         check_flash_attn(self)
         handle_generation_config(self)
@@ -653,6 +656,8 @@ def __post_init__(self) -> None:
         else:
             assert self.load_dataset_config is False, 'You need to first set `--load_args_from_ckpt_dir true`.'
         set_model_type(self)
+        if isinstance(self.dataset, str):
+            self.dataset = [self.dataset]
         register_custom_dataset(self)
         check_flash_attn(self)
         handle_generation_config(self)
@@ -661,8 +666,6 @@ def __post_init__(self) -> None:
         if self.template_type == 'AUTO':
             self.template_type = get_default_template_type(self.model_type)
             logger.info(f'Setting template_type: {self.template_type}')
-        if isinstance(self.dataset, str):
-            self.dataset = [self.dataset]
         has_dataset = (
             len(self.dataset) > 0 or len(self.custom_train_dataset_path) > 0
             or len(self.custom_val_dataset_path) > 0)
@@ -1078,25 +1081,38 @@ def handle_path(args: Union[SftArguments, InferArguments]) -> None:
         setattr(args, k, value)


+def _register_local_dataset(dataset_name: str, train_dataset_path: List[str],
+                            val_dataset_path: List[str]) -> None:
+    register_dataset(
+        dataset_name,
+        '_',
+        train_dataset_path,
+        val_dataset_path,
+        get_function=get_custom_dataset,
+        exists_ok=True)
+
+
 def register_custom_dataset(args: Union[SftArguments, InferArguments]) -> None:
+    dataset = []
+    for d in args.dataset:
+        if os.path.exists(d):
+            args.custom_train_dataset_path.append(d)
+        else:
+            dataset.append(d)
+    args.dataset = dataset
+
     for key in ['custom_train_dataset_path', 'custom_val_dataset_path']:
         value = getattr(args, key)
         if isinstance(value, str):
             setattr(args, key, [value])
     if len(args.custom_train_dataset_path) == 0 and len(
             args.custom_val_dataset_path) == 0:
         return
-    register_dataset(
-        '_custom_dataset',
-        '_custom_dataset',
-        args.custom_train_dataset_path,
-        args.custom_val_dataset_path,
-        get_function=get_custom_dataset,
-        exists_ok=True)
-    if args.dataset is None:
-        args.dataset = ['_custom_dataset']
-    elif '_custom_dataset' not in args.dataset:
-        args.dataset.append('_custom_dataset')
+
+    dataset_name = '_custom_dataset'
+    _register_local_dataset(dataset_name, args.custom_train_dataset_path,
+                            args.custom_val_dataset_path)
+    args.dataset.append(dataset_name)


 def load_from_ckpt_dir(args: InferArguments) -> None:
@@ -1147,34 +1163,48 @@ def handle_generation_config(
     )


-def handle_dataset_mixture(args: SftArguments, train_dataset,
-                           mix_dataset_sample) -> None:
+def handle_dataset_mixture(args: SftArguments,
+                           train_dataset: HfDataset) -> None:
     if train_dataset is None:
         return train_dataset
-    train_length = len(train_dataset)
+    if args.train_dataset_mix_ratio <= 0 or len(
+            args.train_dataset_mix_ds) == 0:
+        return train_dataset
+
     random_state = np.random.RandomState(args.dataset_seed)
-    if mix_dataset_sample:
-        assert args.train_dataset_mix_ds is not None
-        train_dataset_mix_ds = [args.train_dataset_mix_ds] if isinstance(
-            args.train_dataset_mix_ds, str) else args.train_dataset_mix_ds
-        mixed_dataset = get_dataset(
-            train_dataset_mix_ds,
-            0.0,
-            random_state,
-            check_dataset_strategy=args.check_dataset_strategy)[0]
-        if len(mixed_dataset) < mix_dataset_sample:
-            logger.warn(
-                f'The length of dataset used for mixin: {train_dataset_mix_ds} are '
-                'lesser than the ratio required by the `train_dataset_mix_ratio` '
-                f'argument:{args.train_dataset_mix_ratio}'
-                f'the actual ratio is : {len(mixed_dataset)/float(train_length)}'
-            )
+    train_dataset_mix_ds = []
+    custom_mix_ds = []
+    for mix_ds in args.train_dataset_mix_ds:
+        if os.path.exists(mix_ds):
+            custom_mix_ds.append(mix_ds)
         else:
-            train_idxs = random_state.permutation(mix_dataset_sample)
-            mixed_dataset = mixed_dataset.select(train_idxs)
-        return concatenate_datasets([train_dataset, mixed_dataset])
+            train_dataset_mix_ds.append(mix_ds)
+
+    if len(custom_mix_ds) > 0:
+        dataset_name = '_custom_mixture'
+        _register_local_dataset(dataset_name, custom_mix_ds, [])
+        train_dataset_mix_ds.append(dataset_name)
+    mix_dataset_sample = len(train_dataset) * args.train_dataset_mix_ratio
+    logger.info(f'train_dataset_mix_ds: {train_dataset_mix_ds}')
+    logger.info(
+        f'len(train_dataset): {len(train_dataset)}, mix_dataset_sample: {mix_dataset_sample}'
+    )
+    mixed_dataset = get_dataset(
+        train_dataset_mix_ds,
+        0.0,
+        random_state,
+        check_dataset_strategy=args.check_dataset_strategy)[0]
+    if len(mixed_dataset) < mix_dataset_sample:
+        logger.warn(
+            f'The length of dataset used for mixin: {train_dataset_mix_ds} are '
+            'lesser than the ratio required by the `train_dataset_mix_ratio` '
+            f'argument: {args.train_dataset_mix_ratio}. '
+            f'the actual ratio is: {len(mixed_dataset)/len(train_dataset):.6}.'
        )
     else:
-        return train_dataset
+        train_idxs = random_state.permutation(mix_dataset_sample)
+        mixed_dataset = mixed_dataset.select(train_idxs)
+    return concatenate_datasets([train_dataset, mixed_dataset])


 def swift_to_peft_format(lora_checkpoint_path: str) -> str:
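The `register_custom_dataset` rewrite partitions `--dataset` entries into local paths and registered names, and `handle_dataset_mixture` applies the same test to `--train_dataset_mix_ds`. A simplified standalone sketch of that partition logic (the helper name is illustrative, not swift's API):

```python
import os
from typing import List, Tuple


def split_local_and_registered(
        entries: List[str]) -> Tuple[List[str], List[str]]:
    """Split dataset entries into local file paths and registered names."""
    local_paths, registered = [], []
    for entry in entries:
        # Anything that exists on disk is treated as a custom local dataset;
        # everything else is assumed to be a name in the dataset registry.
        if os.path.exists(entry):
            local_paths.append(entry)
        else:
            registered.append(entry)
    return local_paths, registered


local, named = split_local_and_registered(['ms-bench', 'data/alpaca.jsonl'])
# -> local == ['data/alpaca.jsonl'] (if that file exists), named == ['ms-bench']
```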

tests/llm/test_run.py

Lines changed: 5 additions & 0 deletions

@@ -182,12 +182,17 @@ def test_custom_dataset(self):
             'alpaca.jsonl', 'alpaca2.csv', 'conversations.jsonl',
             'swift_pre.csv', 'swift_single.jsonl'
         ]
+        mixture_dataset = val_dataset_fnames
         folder = os.path.join(os.path.dirname(__file__), 'data')
         sft_args = SftArguments(
             model_type='qwen-7b-chat',
             custom_train_dataset_path=[
                 os.path.join(folder, fname) for fname in train_dataset_fnames
             ],
+            train_dataset_mix_ds=[
+                os.path.join(folder, fname) for fname in mixture_dataset
+            ],
+            train_dataset_mix_ratio=2,
             check_dataset_strategy='warning')
         torch.cuda.empty_cache()
         best_model_checkpoint = sft_main(sft_args)['best_model_checkpoint']
