Commit cd8500b

[WIP]Support Q-Galore (modelscope#1440)
1 parent ef31538 commit cd8500b

8 files changed (+98 -7 lines)

README.md

Lines changed: 2 additions & 1 deletion
@@ -55,6 +55,7 @@ You can contact us and communicate with us by adding our group:
 <img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">

 ## 🎉 News
+- 2024.07.19: Support [Q-Galore](https://arxiv.org/abs/2407.08296); this algorithm can reduce the training memory cost by 60% (qwen-7b-chat, full, 80G -> 35G). Use `swift sft --model_type xxx --use_galore true --galore_quantization true` to begin!
 - 2024.07.17: Support newly released InternVL2 models: `model_type` are internvl2-1b, internvl2-40b, internvl2-llama3-76b. For best practices, refer to [here](docs/source_en/Multi-Modal/internvl-best-practice.md).
 - 2024.07.17: Support the training and inference of [NuminaMath-7B-TIR](https://huggingface.co/AI-MO/NuminaMath-7B-TIR). Use with model_type `numina-math-7b`.
 - 🔥2024.07.16: Support exporting for ollama and bitsandbytes. Use `swift export --model_type xxx --to_ollama true` or `swift export --model_type xxx --quant_method bnb --quant_bits 4`

@@ -454,7 +455,7 @@ swift sft \
 NPROC_PER_NODE=4 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 swift pt \
-    --model_type qwen1half-7b-chat \
+    --model_type qwen1half-7b \
     --dataset chinese_c4#10000 \
     --num_train_epochs 1 \
     --sft_type full \

README_CN.md

Lines changed: 2 additions & 1 deletion
@@ -56,6 +56,7 @@ SWIFT has rich and comprehensive documentation; please check our documentation site:


 ## 🎉 News
+- 🔥2024.07.19: Support the [Q-Galore](https://arxiv.org/abs/2407.08296) algorithm, which can reduce memory usage by about 60% (qwen-7b-chat, full, 80G -> 35G). Use the command line `swift sft --model_type xxx --use_galore true --galore_quantization true` to start training!
 - 2024.07.17: Support new models in the InternVL2 series: `model_type` are internvl2-1b, internvl2-40b, internvl2-llama3-76b. For best practices, see [here](docs/source/Multi-Modal/internvl最佳实践.md).
 - 2024.07.17: Support training and inference of [NuminaMath-7B-TIR](https://www.modelscope.cn/models/AI-ModelScope/NuminaMath-7B-TIR). The model_type to use is `numina-math-7b`.
 - 🔥2024.07.16: Support export for ollama and bitsandbytes. Use the command `swift export --model_type xxx --to_ollama true` or `swift export --model_type xxx --quant_method bnb --quant_bits 4`.

@@ -448,7 +449,7 @@ swift sft \
 NPROC_PER_NODE=4 \
 CUDA_VISIBLE_DEVICES=0,1,2,3 \
 swift pt \
-    --model_type qwen1half-7b-chat \
+    --model_type qwen1half-7b \
     --dataset chinese_c4#10000 \
     --num_train_epochs 1 \
     --sft_type full \

docs/source/LLM/命令行参数.md

Lines changed: 7 additions & 0 deletions
@@ -181,6 +181,13 @@
 - `--galore_proj_type: str` : Default `std`. GaLore matrix decomposition type.
 - `--galore_optim_per_parameter: bool` : Default False. Whether to set a separate optimizer for each GaLore target parameter.
 - `--galore_with_embedding: bool` : Default False. Whether to apply GaLore to the embedding.
+- `--galore_quantization`: Whether to use Q-GaLore. Default `False`.
+- `--galore_proj_quant`: Whether to quantize the SVD decomposition matrices. Default `False`.
+- `--galore_proj_bits`: Number of bits for SVD quantization.
+- `--galore_proj_group_size`: Group size for SVD quantization.
+- `--galore_cos_threshold`: Cosine-similarity threshold for updating the projection matrices. Default 0.4.
+- `--galore_gamma_proj`: As the projection matrices gradually become similar, the update interval is lengthened; this parameter is the factor by which the interval grows each time. Default 2.
+- `--galore_queue_size`: Length of the queue used to compute projection-matrix similarity. Default 5.

 ### LISA Fine-tuning Parameters

docs/source_en/LLM/Command-line-parameters.md

Lines changed: 7 additions & 0 deletions
@@ -183,6 +183,13 @@
 - `--galore_proj_type: str` : Default `std`. GaLore matrix decomposition type.
 - `--galore_optim_per_parameter: bool` : Default False. Whether to set a separate optimizer for each GaLore target parameter.
 - `--galore_with_embedding: bool` : Default False. Whether to apply GaLore to the embedding.
+- `--galore_quantization`: Whether to use Q-GaLore. Default `False`.
+- `--galore_proj_quant`: Whether to quantize the SVD decomposition matrices. Default `False`.
+- `--galore_proj_bits`: Number of bits for SVD quantization.
+- `--galore_proj_group_size`: Group size for SVD quantization.
+- `--galore_cos_threshold`: Cosine-similarity threshold for updating the projection matrices. Default 0.4.
+- `--galore_gamma_proj`: As the projection matrices gradually become similar, the update interval is lengthened; this parameter is the factor by which the interval grows each time. Default 2.
+- `--galore_queue_size`: Length of the queue used to compute projection-matrix similarity. Default 5.

 ### LISA Fine-tuning Parameters
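
Each of these flags maps onto a field of `SftArguments` (see `swift/llm/utils/argument.py` further below), so they can also be set programmatically. A minimal sketch, assuming the `SftArguments`/`sft_main` entry points that `swift.llm` exports; the model matches the news entry, the dataset is illustrative, and the `galore_*` values restate the documented defaults:

# Hedged sketch: programmatic equivalent of the CLI flags documented above.
# Assumes swift.llm exports SftArguments/sft_main; the dataset is illustrative.
from swift.llm import SftArguments, sft_main

args = SftArguments(
    model_type='qwen-7b-chat',   # model quoted in the Q-Galore news entry
    dataset=['alpaca-zh#2000'],  # illustrative dataset choice
    sft_type='full',
    use_galore=True,
    galore_quantization=True,    # enable Q-GaLore
    galore_proj_quant=True,      # additionally quantize the SVD projections
    galore_proj_bits=4,
    galore_proj_group_size=256,
    galore_cos_threshold=0.4,
    galore_gamma_proj=2,
    galore_queue_size=5,
)
sft_main(args)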

scripts/benchmark/config/tuner.json

Lines changed: 33 additions & 0 deletions
@@ -138,6 +138,39 @@
       "sft_type": "full"
     }
   },
+  {
+    "name": "full+galore128+quantize",
+    "requirements":{
+      "gpu": "1",
+      "ddp": "1"
+    },
+    "args": {
+      "sft_type": "full",
+      "use_galore": "true",
+      "galore_rank": "128",
+      "galore_update_proj_gap": "200",
+      "galore_optim_per_parameter": "false",
+      "galore_with_embedding": "false",
+      "galore_quantization": "true"
+    }
+  },
+  {
+    "name": "full+galore128+quantize+proj_quant",
+    "requirements":{
+      "gpu": "1",
+      "ddp": "1"
+    },
+    "args": {
+      "sft_type": "full",
+      "use_galore": "true",
+      "galore_rank": "128",
+      "galore_update_proj_gap": "200",
+      "galore_optim_per_parameter": "false",
+      "galore_with_embedding": "false",
+      "galore_quantization": "true",
+      "galore_proj_quant": "true"
+    }
+  },
   {
     "name": "full+galore128",
     "requirements":{

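Each benchmark entry's "args" dict is a flat name/value mapping onto `swift sft` flags. A sketch of that flattening, purely for illustration (an assumption about the runner, not the benchmark script's actual code):

# Hypothetical illustration: how an entry's "args" would flatten into CLI flags.
entry_args = {
    'sft_type': 'full',
    'use_galore': 'true',
    'galore_rank': '128',
    'galore_update_proj_gap': '200',
    'galore_quantization': 'true',
}
flags = ' '.join(f'--{key} {value}' for key, value in entry_args.items())
print(f'swift sft {flags}')
# -> swift sft --sft_type full --use_galore true --galore_rank 128 --galore_update_proj_gap 200 --galore_quantization true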
swift/llm/tuner.py

Lines changed: 7 additions & 0 deletions
@@ -284,6 +284,13 @@ def prepare_model(model, args: SftArguments):
             galore_scale=args.galore_scale,
             proj_type=args.galore_proj_type,
             optim_per_parameter=args.galore_optim_per_parameter,
+            quantize=args.galore_quantization,
+            proj_quant=args.galore_proj_quant,
+            proj_bits=args.galore_proj_bits,
+            proj_group_size=args.galore_proj_group_size,
+            cos_threshold=args.galore_cos_threshold,
+            gamma_proj=args.galore_gamma_proj,
+            queue_size=args.galore_queue_size,
         )

     callbacks = []
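
For readers driving the trainer utilities directly rather than going through `swift sft`, the block above amounts to building a `GaLoreConfig` by hand. A sketch, assuming the fields not shown in the dataclass hunk below (rank, target modules, and so on) keep their dataclass defaults:

# Hedged sketch: the config object prepare_model() assembles above.
# Only fields confirmed by this commit's GaLoreConfig hunk are set;
# everything else is assumed to keep its dataclass default.
from swift.trainers.optimizers.galore.utils import GaLoreConfig

config = GaLoreConfig(
    optim_per_parameter=False,  # Q-GaLore does not support per-parameter optimizers
    quantize=True,              # --galore_quantization
    proj_quant=True,            # --galore_proj_quant
    proj_bits=4,                # --galore_proj_bits
    proj_group_size=256,        # --galore_proj_group_size
    cos_threshold=0.4,          # --galore_cos_threshold
    gamma_proj=2,               # --galore_gamma_proj
    queue_size=5,               # --galore_queue_size
)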

swift/llm/utils/argument.py

Lines changed: 7 additions & 0 deletions
@@ -561,6 +561,13 @@ class SftArguments(ArgumentsBase):
     galore_proj_type: str = 'std'
     galore_optim_per_parameter: bool = False
     galore_with_embedding: bool = False
+    galore_quantization: bool = False
+    galore_proj_quant: bool = False
+    galore_proj_bits: int = 4
+    galore_proj_group_size: int = 256
+    galore_cos_threshold: float = 0.4
+    galore_gamma_proj: int = 2
+    galore_queue_size: int = 5

     # adalora
     adalora_target_r: int = 8

swift/trainers/optimizers/galore/utils.py

Lines changed: 33 additions & 5 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import importlib
 from dataclasses import dataclass
 from typing import Any, Dict, List, Tuple, Union

@@ -41,6 +42,13 @@ class GaLoreConfig:
     galore_scale: float = 1.0
     proj_type: str = 'std'
     optim_per_parameter: bool = False
+    quantize: bool = False
+    proj_quant: bool = False
+    proj_bits: int = 4
+    proj_group_size: int = 256
+    cos_threshold: float = 0.4
+    gamma_proj: int = 2
+    queue_size: int = 5


 class GaloreOptimizerWrapper(Optimizer):

@@ -82,6 +90,7 @@ def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, co

             logger.info(f'Enable GaLore for weights in module: {module_name}')
             galore_params.append(module.weight)
+
     id_galore_params = [id(p) for p in galore_params]
     galore_defaults = {
         'rank': config.rank,

@@ -90,9 +99,17 @@ def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, co
         'proj_type': config.proj_type,
         **defaults
     }
-    optim_cls, optim_kwargs = get_optimizer(args)
-
-    if config.optim_per_parameter:
+    if config.quantize:
+        galore_defaults['quant'] = config.proj_quant
+        galore_defaults['quant_n_bit'] = config.proj_bits
+        galore_defaults['quant_group_size'] = config.proj_group_size
+        galore_defaults['cos_threshold'] = config.cos_threshold
+        galore_defaults['gamma_proj'] = config.gamma_proj
+        galore_defaults['queue_size'] = config.queue_size
+    optim_cls, optim_kwargs = get_optimizer(args, config)
+
+    if config.optim_per_parameter and not config.quantize:
+        # q-galore does not support optim_per_parameter
         optimizer_dict = {}
         galore_defaults['update_proj_gap'] = galore_defaults['update_proj_gap'] * 2
         for p in model.parameters():

@@ -150,7 +167,7 @@ def create_optimizer_and_scheduler(model: nn.Module, args: TrainingArguments, co
     return optim, scheduler


-def get_optimizer(args: TrainingArguments) -> Tuple[Any, Any]:
+def get_optimizer(args: TrainingArguments, config: GaLoreConfig) -> Tuple[Any, Any]:
     # parse args.optim_args
     optim_args = {}
     if args.optim_args:

@@ -169,7 +186,18 @@ def get_optimizer(args: TrainingArguments) -> Tuple[Any, Any]:
         optimizer_cls = GaLoreAdafactor
         optimizer_kwargs.update({'scale_parameter': False, 'relative_step': False})
     elif args.optim in ('adamw_hf', 'adamw_torch'):
-        from .adamw import GaLoreAdamW
+        if config.quantize:
+            assert importlib.util.find_spec("q_galore_torch") is not None, \
+                'Please install q-galore by `pip install q_galore_torch`'
+            from swift.utils import get_dist_setting
+            _, _, world_size, _ = get_dist_setting()
+            if world_size > 1:
+                # from q_galore_torch import QGaLoreAdamW8bit_simulate as GaLoreAdamW
+                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
+            else:
+                from q_galore_torch import QGaLoreAdamW8bit as GaLoreAdamW
+        else:
+            from .adamw import GaLoreAdamW
         optimizer_cls = GaLoreAdamW
         optimizer_kwargs.update(adam_kwargs)
     elif 'adamw' in args.optim and '8bit' in args.optim:
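
Read together, the three similarity parameters forwarded here describe an adaptive schedule: the optimizer tracks the last `queue_size` projection matrices, and once their cosine similarity stays above `cos_threshold`, refreshing the projection is treated as wasted work and the refresh interval is stretched by `gamma_proj`. A toy illustration of that schedule under this commit's defaults (not `q_galore_torch`'s actual code):

# Toy illustration of the adaptive projection-refresh schedule described by
# cos_threshold / gamma_proj / queue_size. NOT q_galore_torch's code.
from typing import List

def next_update_gap(gap: int, recent_cos_sims: List[float],
                    cos_threshold: float = 0.4, gamma_proj: int = 2,
                    queue_size: int = 5) -> int:
    """Steps until the next projection (SVD) refresh."""
    window = recent_cos_sims[-queue_size:]
    if len(window) == queue_size and min(window) > cos_threshold:
        # Projections barely change any more: refresh them less often.
        return gap * gamma_proj
    return gap

gap = 200  # galore_update_proj_gap
gap = next_update_gap(gap, [0.52, 0.61, 0.55, 0.58, 0.60])
print(gap)  # 400: every queued similarity exceeded the 0.4 threshold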
