
Commit 75ac11e

Support bin packing and rope scaling (#1079)

Parent: ca4d2f6

11 files changed, 63 insertions(+), 22 deletions(-)

docs/source/LLM/命令行参数.md

Lines changed: 3 additions & 0 deletions
@@ -126,7 +126,9 @@
 - `--custom_dataset_info`: Default is `None`. Pass in the path to an external dataset_info.json, a JSON string, or a dict. Used to extend datasets. Format reference: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
 - `--device_map_config_path`: Manually configure the model's device_map from a local file. Default is None.

+### Long Context

+- `--rope_scaling`: Default is `None`. Supports two scaling methods, `linear` and `dynamic`. Used when `max_length` exceeds `max_position_embeddings`.

 ### FSDP Parameters

@@ -281,6 +283,7 @@ dpo parameters inherit from the sft parameters; in addition, the following parameters are added:
 - `--lora_modules`: Default is `[]`. The input format is `'{lora_name}={lora_path}'`, e.g. `--lora_modules lora_name1=lora_path1 lora_name2=lora_path2`. `ckpt_dir` is added to args.lora_modules as `f'default-lora={args.ckpt_dir}'`.
 - `--custom_register_path`: Default is `None`. Pass in a `.py` file used to register templates, models, and datasets.
 - `--custom_dataset_info`: Default is `None`. Pass in the path to an external dataset_info.json, a JSON string, or a dict. Used to extend datasets.
+- `--rope_scaling`: Default is `None`. Supports two scaling methods, `linear` and `dynamic`. Used when `max_length` exceeds `max_position_embeddings`.


 ## export Parameters

docs/source_en/LLM/Command-line-parameters.md

Lines changed: 5 additions & 0 deletions
@@ -126,6 +126,10 @@
 - `--custom_dataset_info`: Default is `None`. Pass in the path to an external `dataset_info.json`, a JSON string, or a dictionary. Used to register custom datasets. Format example: https://github.com/modelscope/swift/blob/main/swift/llm/data/dataset_info.json
 - `--device_map_config_path`: Manually configure the model's device map from a local file. Defaults to None.

+### Long Context
+
+- `--rope_scaling`: Default is `None`. Supports `linear` and `dynamic` scaling of positional embeddings. Use when `max_length` exceeds `max_position_embeddings`.
+
 ### FSDP Parameters

 - `--fsdp`: Default value `''`, the FSDP type; please check [this documentation](https://huggingface.co/docs/transformers/v4.39.3/en/main_classes/trainer#transformers.TrainingArguments.fsdp) for details.

@@ -280,6 +284,7 @@ dpo parameters inherit from sft parameters, with the following added parameters:
 - `--lora_modules`: Default is `[]`. The input format is `'{lora_name}={lora_path}'`, e.g. `--lora_modules lora_name1=lora_path1 lora_name2=lora_path2`. `ckpt_dir` will be added as `f'default-lora={args.ckpt_dir}'` by default.
 - `--custom_register_path`: Default is `None`. Pass in a `.py` file used to register templates, models, and datasets.
 - `--custom_dataset_info`: Default is `None`. Pass in the path to an external `dataset_info.json`, a JSON string, or a dictionary. Used for expanding datasets.
+- `--rope_scaling`: Default is `None`. Supports `linear` and `dynamic` scaling of positional embeddings. Use when `max_length` exceeds `max_position_embeddings`.


 ## export Parameters
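
Note: with the flag documented above, rope scaling can also be switched on from the Python API. The sketch below is illustrative only; it assumes the usual `SftArguments`/`sft_main` entry points from `swift.llm`, and the model_type and dataset values are placeholders rather than anything prescribed by this commit.

```python
# Hedged sketch: enabling linear rope scaling for a fine-tuning run whose
# max_length exceeds the base model's max_position_embeddings.
# model_type and dataset values are illustrative placeholders.
from swift.llm import SftArguments, sft_main

args = SftArguments(
    model_type='qwen-7b-chat',   # placeholder model_type
    dataset=['alpaca-en'],       # placeholder dataset
    max_length=16384,            # longer than the base context window
    rope_scaling='linear',       # or 'dynamic'
)
sft_main(args)
```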

requirements/framework.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 accelerate
+binpacking
 dacite
 jieba
 matplotlib
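
The new `binpacking` dependency corresponds to the bin-packing half of the commit title. The packing code itself is not in the hunks shown here, so the following is only a hedged sketch of how the library groups samples by token length; the sample lengths and packing length are made up for illustration.

```python
# Hedged sketch: grouping tokenized samples into bins whose total token
# count stays around a packing length, using the binpacking dependency.
import binpacking

lengths = {0: 1200, 1: 800, 2: 3000, 3: 512, 4: 2048, 5: 96}  # index -> token count
packing_length = 4096

# Greedy constant-volume packing: each returned bin is a dict of index -> length.
bins = binpacking.to_constant_volume(lengths, packing_length)
for i, b in enumerate(bins):
    print(f'bin {i}: samples {sorted(b)} -> {sum(b.values())} tokens')
```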

swift/llm/dpo.py

Lines changed: 3 additions & 0 deletions
@@ -82,6 +82,9 @@ def llm_dpo(args: DPOArguments) -> str:
     }
     if args.use_flash_attn is not None:
         kwargs['use_flash_attn'] = args.use_flash_attn
+    if args.rope_scaling:
+        kwargs['rope_scaling'] = args.rope_scaling
+        kwargs['max_length'] = args.max_length
     model, tokenizer = get_model_tokenizer(
         args.model_type,
         args.torch_dtype,

swift/llm/infer.py

Lines changed: 3 additions & 0 deletions
@@ -163,6 +163,9 @@ def prepare_model_template(args: InferArguments,
         kwargs['automodel_class'] = automodel_class
     if args.local_repo_path:
         kwargs['local_repo_path'] = args.local_repo_path
+    if args.rope_scaling:
+        kwargs['rope_scaling'] = args.rope_scaling
+        kwargs['max_length'] = args.max_length
     model, tokenizer = get_model_tokenizer(
         args.model_type,
         args.torch_dtype,

swift/llm/orpo.py

Lines changed: 3 additions & 0 deletions
@@ -83,6 +83,9 @@ def llm_orpo(args: ORPOArguments) -> str:
     }
     if args.use_flash_attn is not None:
         kwargs['use_flash_attn'] = args.use_flash_attn
+    if args.rope_scaling:
+        kwargs['rope_scaling'] = args.rope_scaling
+        kwargs['max_length'] = args.max_length
     model, tokenizer = get_model_tokenizer(
         args.model_type,
         args.torch_dtype,

swift/llm/sft.py

Lines changed: 4 additions & 0 deletions
@@ -108,6 +108,10 @@ def llm_sft(args: SftArguments) -> Dict[str, Union[str, Any]]:
     elif args.quant_method == 'gptq':
         kwargs['is_gptq'] = True

+    if args.rope_scaling:
+        kwargs['rope_scaling'] = args.rope_scaling
+        kwargs['max_length'] = args.max_length
+
     model, tokenizer = get_model_tokenizer(
         args.model_type,
         args.torch_dtype,
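
The same two keys are threaded through dpo, orpo, simpo, and infer in the hunks above. To show what ends up being forwarded, here is a hedged standalone call; it assumes `get_model_tokenizer` from `swift.llm` accepts and forwards these extra keyword arguments as the hunks indicate, and the model_type string is a placeholder.

```python
# Hedged sketch: the rope-scaling kwargs that llm_sft builds above, passed
# directly to get_model_tokenizer. model_type is a placeholder.
from swift.llm import get_model_tokenizer

model, tokenizer = get_model_tokenizer(
    'qwen-7b-chat',          # placeholder model_type
    rope_scaling='linear',   # popped later in get_model_tokenizer_from_repo
    max_length=16384,        # used there to derive the scaling factor
)
```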

swift/llm/simpo.py

Lines changed: 3 additions & 0 deletions
@@ -82,6 +82,9 @@ def llm_simpo(args: SimPOArguments) -> str:
     }
     if args.use_flash_attn is not None:
         kwargs['use_flash_attn'] = args.use_flash_attn
+    if args.rope_scaling:
+        kwargs['rope_scaling'] = args.rope_scaling
+        kwargs['max_length'] = args.max_length
     model, tokenizer = get_model_tokenizer(
         args.model_type,
         args.torch_dtype,

swift/llm/utils/argument.py

Lines changed: 6 additions & 0 deletions
@@ -487,6 +487,9 @@ class SftArguments(ArgumentsBase):
     # Literal['gaussian', 'pissa', 'pissa_niter_[number of iters]', 'loftq', 'true', 'false']
     init_lora_weights: str = 'true'

+    # rope-scaling
+    rope_scaling: Literal['linear', 'dynamic'] = None
+
     # BOFT
     boft_block_size: int = 4
     boft_block_num: int = 0

@@ -1036,6 +1039,9 @@ class InferArguments(ArgumentsBase):
     num_beams: int = 1
     stop_words: List[str] = None

+    # rope-scaling
+    rope_scaling: Literal['linear', 'dynamic'] = None
+
     # other
     use_flash_attn: Optional[bool] = None
     ignore_args_error: bool = False  # True: notebook compatibility

swift/llm/utils/model.py

Lines changed: 9 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import inspect
+import math
 import os
 import sys
 from contextlib import nullcontext

@@ -864,6 +865,14 @@ def get_model_tokenizer_from_repo(model_dir: str,
         tokenizer.placeholder_tokens = placeholder_tokens
         tokenizer.placeholder_tokens_id = [tokenizer.convert_tokens_to_ids(token) for token in placeholder_tokens]
     model = None
+
+    rope_scaling = kwargs.pop('rope_scaling', None)
+    max_position_embeddings = getattr(model_config, 'max_position_embeddings', None)
+    if rope_scaling and max_position_embeddings:
+        max_length = kwargs.get('max_length') or max_position_embeddings
+        rope_scaling_factor = max(float(math.ceil(max_length / max_position_embeddings)), 1.0)
+        setattr(model_config, 'rope_scaling', {'type': rope_scaling, 'factor': rope_scaling_factor})
+
     if load_model:
         if kwargs.get('use_unsloth', False):
             assert is_unsloth_available(), 'please install unsloth if using `use_unsloth=True`'
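
For concreteness, here is the factor computation from the hunk above run standalone on example numbers; only `math` from the standard library is used, and the concrete values are illustrative.

```python
# Worked example of the rope-scaling factor added above: a 4096-position
# base config with a requested max_length of 10000 yields factor 3.0,
# i.e. rope_scaling = {'type': 'linear', 'factor': 3.0} on the model config.
import math

max_position_embeddings = 4096   # read from the model config
max_length = 10000               # requested sequence length
rope_scaling = 'linear'          # value of --rope_scaling

factor = max(float(math.ceil(max_length / max_position_embeddings)), 1.0)
print({'type': rope_scaling, 'factor': factor})  # {'type': 'linear', 'factor': 3.0}
```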
