
Commit 7370d72

wtmlon and lugimzzz authored
add rslora & lora+ (#8111)
* add rslora & lora+
* remove print
* reformat
* update
* fix bug
* add rslora+ ci
* remove magic number
* update
* empty

Co-authored-by: lugimzzz <[email protected]>
1 parent 1cd270a commit 7370d72

File tree: 7 files changed, +120 -7 lines

llm/argument.py

Lines changed: 2 additions & 0 deletions
@@ -126,6 +126,8 @@ class ModelArgument:
     lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
     lora_path: str = field(default=None, metadata={"help": "Initialize lora state dict."})
     lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})
+    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
+    lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
 
     # prefix tuning related parameters
     prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
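
The two new ModelArgument fields are ordinary dataclass fields, so they surface as --rslora and --lora_plus_scale command-line flags once the dataclass is handed to the argument parser. A minimal sketch, assuming PaddleNLP's PdArgumentParser and trimming ModelArgument down to the fields relevant here:

# Minimal sketch: parsing the new flags into a dataclass.
# PdArgumentParser is assumed to behave like paddlenlp.trainer.PdArgumentParser;
# this reduced ModelArgument only keeps the LoRA-related fields from the diff.
from dataclasses import dataclass, field

from paddlenlp.trainer import PdArgumentParser


@dataclass
class ModelArgument:
    lora: bool = field(default=False, metadata={"help": "Whether to use LoRA technique"})
    lora_rank: int = field(default=8, metadata={"help": "Lora attention dimension"})
    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
    lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})


parser = PdArgumentParser(ModelArgument)
# e.g. python finetune_generation.py --lora true --rslora true --lora_plus_scale 4
(model_args,) = parser.parse_args_into_dataclasses()
print(model_args.rslora, model_args.lora_plus_scale)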

llm/finetune_generation.py

Lines changed: 3 additions & 1 deletion
@@ -418,7 +418,9 @@ def neft_post_hook(module, input, output):
         lora_config = LoRAConfig(
             target_modules=target_modules,
             r=model_args.lora_rank,
-            lora_alpha=2 * model_args.lora_rank,
+            lora_alpha=2 * model_args.lora_rank if not model_args.rslora else 4,
+            rslora=model_args.rslora,
+            lora_plus_scale=model_args.lora_plus_scale,
             merge_weights=False,
             tensor_parallel_degree=training_args.tensor_parallel_degree,
             dtype=dtype,
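
Note the interaction between the two branches: standard LoRA keeps lora_alpha = 2 * r, which always gives a scaling of 2, while the RsLoRA branch pins lora_alpha to 4 and the layer divides by sqrt(r) instead of r. A quick worked check with the default lora_rank of 8:

# Worked check of the scaling factor produced by the two branches above.
import math

lora_rank = 8  # default from llm/argument.py

# Standard LoRA: lora_alpha = 2 * r, scaling = alpha / r
alpha_std = 2 * lora_rank
scaling_std = alpha_std / lora_rank            # 2.0

# RsLoRA: lora_alpha is fixed to 4, scaling = alpha / sqrt(r)
alpha_rs = 4
scaling_rs = alpha_rs / math.sqrt(lora_rank)   # 4 / sqrt(8) ~= 1.414

print(scaling_std, scaling_rs)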

paddlenlp/peft/lora/lora_config.py

Lines changed: 2 additions & 0 deletions
@@ -72,6 +72,8 @@ class LoRAConfig:
         },
     )
     do_qat: bool = field(default=False, metadata={"help": "Whether the lora model would do quant-aware training"})
+    rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
+    lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"})
     base_model_name_or_path: Optional[str] = field(
         default=None, metadata={"help": "The name of the base model to use."}
     )

paddlenlp/peft/lora/lora_layers.py

Lines changed: 33 additions & 6 deletions
@@ -35,6 +35,8 @@ def __init__(
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
         merge_weights: bool = True,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
         **kwargs
     ):
         nn.Linear.__init__(self, in_features, out_features, **kwargs)
@@ -62,9 +64,16 @@ def __init__(
             shape=[r, out_features],
             dtype=self._dtype,
             is_bias=False,
-            default_initializer=nn.initializer.Constant(value=0.0),
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
         )
-        self.scaling = self.lora_alpha / self.r
+
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
 
         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True
@@ -104,6 +113,8 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
         merge_weights: bool = True,
         **kwargs
     ):
@@ -137,12 +148,19 @@ def __init__(
             shape=[r, self.out_features],
             dtype=self._dtype,
             is_bias=False,
-            default_initializer=nn.initializer.Constant(value=0.0),
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
         )
+
         self.lora_A.is_distributed = True
         self.lora_A.split_axis = 0
         self.lora_B.is_distributed = False
-        self.scaling = self.lora_alpha / self.r
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
 
         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True
@@ -208,6 +226,8 @@ def __init__(
         r: int = 0,
         lora_alpha: int = 1,
         lora_dropout: float = 0.0,
+        rslora: bool = False,
+        lora_plus_scale: float = 1.0,
         merge_weights: bool = True,
         lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
         **kwargs
@@ -241,11 +261,18 @@ def __init__(
             shape=[r, self.output_size_per_partition],
             dtype=self._dtype,
             is_bias=False,
-            default_initializer=nn.initializer.Constant(value=0.0),
+            attr=paddle.ParamAttr(
+                initializer=paddle.nn.initializer.Constant(value=0.0),
+                learning_rate=lora_plus_scale,
+            ),
         )
+
         self.lora_B.is_distributed = True
         self.lora_B.split_axis = 1
-        self.scaling = self.lora_alpha / self.r
+        if not rslora:
+            self.scaling = self.lora_alpha / self.r
+        else:
+            self.scaling = self.lora_alpha / math.sqrt(self.r)
 
         # Freezing the pre-trained weight matrix
         self.weight.stop_gradient = True
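
All three layer variants repeat the same two changes: lora_B is created with a paddle.ParamAttr whose learning_rate acts as a per-parameter multiplier (the LoRA+ idea of updating B faster than A), and the scaling switches from lora_alpha / r to lora_alpha / sqrt(r) when rslora is set. The following is a condensed sketch of just that logic in a toy layer, not the real LoRALinear classes; TinyLoRALinear and its defaults are illustrative only:

# Condensed sketch of the rsLoRA / LoRA+ additions in a toy layer.
import math

import paddle
import paddle.nn as nn


class TinyLoRALinear(nn.Linear):
    def __init__(self, in_features, out_features, r=8, lora_alpha=16,
                 rslora=False, lora_plus_scale=1.0, **kwargs):
        super().__init__(in_features, out_features, **kwargs)
        self.lora_A = self.create_parameter(
            shape=[in_features, r], dtype=self._dtype, is_bias=False
        )
        # LoRA+: ParamAttr's learning_rate is a per-parameter LR multiplier,
        # so lora_B is updated lora_plus_scale times faster than lora_A.
        self.lora_B = self.create_parameter(
            shape=[r, out_features],
            dtype=self._dtype,
            is_bias=False,
            attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=0.0),
                learning_rate=lora_plus_scale,
            ),
        )
        # rsLoRA: divide by sqrt(r) instead of r so the update magnitude
        # stays stable as the rank grows.
        self.scaling = lora_alpha / math.sqrt(r) if rslora else lora_alpha / r
        # Freeze the pre-trained weight matrix, as in the real layers.
        self.weight.stop_gradient = True

    def forward(self, x):
        base = super().forward(x)
        return base + (x @ self.lora_A @ self.lora_B) * self.scaling


layer = TinyLoRALinear(16, 16, r=8, lora_alpha=4, rslora=True, lora_plus_scale=4.0)
out = layer(paddle.randn([2, 16]))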

paddlenlp/peft/lora/lora_model.py

Lines changed: 6 additions & 0 deletions
@@ -382,6 +382,8 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 lora_alpha=lora_config.lora_alpha,
                 lora_dropout=lora_config.lora_dropout,
                 merge_weights=lora_config.merge_weights,
+                rslora=lora_config.rslora,
+                lora_plus_scale=lora_config.lora_plus_scale,
                 bias_attr=False if module.bias is None else None,
             )
         if isinstance(module, nn.Conv2D):
@@ -412,6 +414,8 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 r=lora_config.r,
                 lora_alpha=lora_config.lora_alpha,
                 lora_dropout=lora_config.lora_dropout,
+                rslora=lora_config.rslora,
+                lora_plus_scale=lora_config.lora_plus_scale,
                 merge_weights=lora_config.merge_weights,
                 lora_A_weight_attr=paddle.ParamAttr(
                     initializer=nn.initializer.KaimingUniform(
@@ -437,6 +441,8 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora)
                 r=lora_config.r,
                 lora_alpha=lora_config.lora_alpha,
                 lora_dropout=lora_config.lora_dropout,
+                rslora=lora_config.rslora,
+                lora_plus_scale=lora_config.lora_plus_scale,
                 merge_weights=lora_config.merge_weights,
             )
         # Lora column parallel will spilt lora A matrix
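
Since _find_and_replace_module now forwards both fields, user code only has to set them on the config. A hedged usage sketch; the model name, target_modules patterns, and helper calls below reflect typical PaddleNLP PEFT usage and are not taken from this diff:

# Hedged usage sketch: wrap a base model with rsLoRA + LoRA+ enabled.
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

# Placeholder base model (the tiny test checkpoint used by the CI fixture).
model = AutoModelForCausalLM.from_pretrained("__internal_testing__/tiny-random-llama")

lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # placeholder regex patterns
    r=8,
    lora_alpha=4,          # fixed alpha, as finetune_generation.py does when rslora is on
    rslora=True,           # scaling becomes lora_alpha / sqrt(r)
    lora_plus_scale=4.0,   # lora_B trains with a 4x learning-rate multiplier
    merge_weights=False,
)
model = LoRAModel(model, lora_config)
model.mark_only_lora_as_trainable()
model.print_trainable_parameters()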

tests/fixtures/llm/lora.yaml

Lines changed: 45 additions & 0 deletions
@@ -41,6 +41,51 @@ lora:
     baichuan:
       model_name_or_path: __internal_testing__/tiny-fused-baichuan
 
+rslora_plus:
+  base:
+    dataset_name_or_path: "./data"
+    per_device_train_batch_size: 4
+    gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 8
+    eval_accumulation_steps: 16
+    num_train_epochs: 3
+    learning_rate: 3e-04
+    warmup_steps: 30
+    logging_steps: 1
+    evaluation_strategy: "epoch"
+    save_strategy: "epoch"
+    src_length: 1024
+    max_length: 2048
+    fp16: true
+    fp16_opt_level: "O2"
+    do_train: true
+    do_eval: true
+    disable_tqdm: true
+    load_best_model_at_end: true
+    eval_with_do_generation: false
+    metric_for_best_model: "accuracy"
+    recompute: true
+    save_total_limit: 1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    lora: true
+    lora_plus_scale: 4
+    rslora: true
+
+  default:
+    llama:
+      model_name_or_path: __internal_testing__/tiny-random-llama
+    chatglm:
+      model_name_or_path: __internal_testing__/tiny-fused-chatglm
+    chatglm2:
+      model_name_or_path: __internal_testing__/tiny-fused-chatglm2
+    bloom:
+      model_name_or_path: __internal_testing__/tiny-fused-bloom
+    qwen:
+      model_name_or_path: __internal_testing__/tiny-fused-qwen
+    baichuan:
+      model_name_or_path: __internal_testing__/tiny-fused-baichuan
+
 inference-predict:
   default:
     mode: dynamic
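
The rslora_plus fixture turns both features on at once. One implied number worth keeping in mind when reading the CI run, assuming ParamAttr.learning_rate acts as a multiplier on the optimizer's learning rate:

# What the fixture's LoRA+ setting means for lora_B's step size.
base_lr = 3e-04        # learning_rate from the rslora_plus fixture
lora_plus_scale = 4    # lora_plus_scale from the fixture

effective_lora_B_lr = base_lr * lora_plus_scale  # 1.2e-03, before any LR schedule
print(effective_lora_B_lr)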

tests/llm/test_lora.py

Lines changed: 29 additions & 0 deletions
@@ -79,6 +79,35 @@ def test_lora(self):
 
         self.run_predictor({"inference_model": False})
 
+    def test_rslora_plus(self):
+        self.disable_static()
+        paddle.set_default_dtype("float32")
+
+        lora_config = load_test_config(self.config_path, "rslora_plus", self.model_dir)
+        lora_config["output_dir"] = self.output_dir
+        lora_config["dataset_name_or_path"] = self.data_dir
+
+        with argv_context_guard(lora_config):
+            from finetune_generation import main
+
+            main()
+
+        # merge weights
+        merge_lora_weights_config = {
+            "lora_path": lora_config["output_dir"],
+            "merge_lora_model_path": lora_config["output_dir"],
+        }
+        with argv_context_guard(merge_lora_weights_config):
+            from merge_lora_params import merge
+
+            merge()
+
+        # TODO(wj-Mcat): disable chatglm2 test temporarily
+        if self.model_dir not in ["qwen", "baichuan", "chatglm2"]:
+            self.run_predictor({"inference_model": True})
+
+        self.run_predictor({"inference_model": False})
+
 
 # @parameterized_class(
 #     ["model_dir"],
