
Commit 1fa180f

Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP into update_predict_new
2 parents: 7a8e9c9 + 77f6e98

File tree

23 files changed: +1250 -121 lines
Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
{
    "model_name_or_path": "facebook/llama-7b",
    "dataset_name_or_path": "./data",
    "output_dir": "./checkpoints/vera_ckpts",
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "per_device_eval_batch_size": 8,
    "eval_accumulation_steps": 16,
    "num_train_epochs": 1,
    "learning_rate": 3e-04,
    "warmup_steps": 30,
    "logging_steps": 1,
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "src_length": 1024,
    "max_length": 2048,
    "fp16": true,
    "fp16_opt_level": "O2",
    "do_train": true,
    "do_eval": true,
    "disable_tqdm": true,
    "load_best_model_at_end": true,
    "eval_with_do_generation": false,
    "metric_for_best_model": "accuracy",
    "recompute": true,
    "save_total_limit": 10,
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 1,
    "vera": true,
    "zero_padding": false,
    "use_flash_attention": false
}
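A quick sanity check on a config like this before launching training is to load it and confirm the derived batch settings. A minimal stdlib-only sketch (the file name vera_argument.json is an assumption; the diff does not show the config's path):

import json

# Load the fine-tuning config shown above.
with open("vera_argument.json") as f:
    cfg = json.load(f)

# Samples consumed per optimizer step on each device:
# per_device_train_batch_size (4) x gradient_accumulation_steps (4) = 16.
effective_batch = cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"]
print(effective_batch)  # 16

# src_length caps the prompt; max_length caps prompt plus generated target.
assert cfg["src_length"] <= cfg["max_length"]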

llm/docs/quantization.md

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ python run_finetune.py ./config/llama/awq_argument.json

 <div>

 - `quant_type`: quantization type for PTQ/QAT; defaults to A8W8. Supported values: A8W8, WINT4, WINT8. A8W8 quantizes activations (inputs) to INT8 and model weights to INT8; WINT4 quantizes only model weights to INT4, with WeightOnly used for inference afterwards; WINT8 quantizes only model weights to INT8, with WeightOnly used for inference afterwards. (The commit's only change to this line is punctuation normalization in the original Chinese text.)
 - `do_ptq`: whether to run PTQ quantization; defaults to False.
 - `weight_quant_method`: weight quantization method; currently groupwise or abs_max_channel_wise.
 - `ptq_step`: number of PTQ steps, i.e. the number of model forward passes; defaults to 32.
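To make the WINT8 / abs_max_channel_wise combination concrete, here is an illustrative sketch (not PaddleNLP's internal code) of per-channel abs-max INT8 weight quantization, where only the weights are stored as INT8 and activations stay in floating point:

import numpy as np

def quantize_wint8_channelwise(w):
    # w: [in_features, out_features]; one scale per output channel.
    scale = np.abs(w).max(axis=0) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

w = np.random.randn(512, 256).astype(np.float32)
q, scale = quantize_wint8_channelwise(w)
print(np.abs(w - dequantize(q, scale)).max())  # small per-channel rounding error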

llm/predict/export_model.py

Lines changed: 0 additions & 38 deletions

@@ -19,53 +19,16 @@
 import paddle
 from paddle.distributed import fleet
 from predict.predictor import ModelArgument, PredictorArgument, create_predictor
-from tqdm import tqdm
 from utils.utils import generate_rank_mapping, get_infer_model_path

 from paddlenlp.trainer import PdArgumentParser
-from paddlenlp.utils.log import logger


 @dataclass
 class ExportArgument:
     output_path: str = field(default=None, metadata={"help": "The output path of model."})


-def load_inference_model(model_path, model_name, param_name, exe):
-    model_abs_path = os.path.join(model_path, model_name)
-    param_abs_path = os.path.join(model_path, param_name)
-    if os.path.exists(model_abs_path) and os.path.exists(param_abs_path):
-        return paddle.static.io.load_inference_model(model_path, exe, model_name, param_name)
-    else:
-        return paddle.static.io.load_inference_model(model_path, exe)
-
-
-def validate_pdmodel(model_path, model_prefix, device):
-    paddle.enable_static()
-    if device == "gpu":
-        place = paddle.CUDAPlace(0)
-    elif device == "cpu":
-        place = paddle.CPUPlace()
-    else:
-        place = paddle.CustomPlace(device, 0)
-    exe = paddle.static.Executor(place)
-    scope = paddle.static.Scope()
-
-    with paddle.static.scope_guard(scope):
-        net_program, feed_target_names, fetch_targets = paddle.static.io.load_inference_model(
-            os.path.join(model_path, model_prefix), exe
-        )
-
-        if not paddle.framework.use_pir_api():
-            for block in net_program.blocks:
-                ops: list[paddle.framework.Operator] = block.ops
-                for op in tqdm(ops, desc="checking the validation of ops"):
-                    if op.type.lower() == "print":
-                        logger.warning(
-                            f"UNEXPECTED OP<{op.type}> which will reduce the performance of the static model"
-                        )
-
-
 def main():
     parser = PdArgumentParser((PredictorArgument, ModelArgument, ExportArgument))
     predictor_args, model_args, export_args = parser.parse_args_into_dataclasses()

@@ -104,7 +67,6 @@ def main():

     if tensor_parallel_degree > 1:
         export_args.output_path = os.path.join(export_args.output_path, f"rank_{tensor_parallel_rank}")
-    validate_pdmodel(export_args.output_path, predictor_args.model_prefix, predictor_args.device)

     if predictor_args.device == "npu":
         from npu.llama.export_utils import process_params

llm/predict/predictor.py

Lines changed: 8 additions & 6 deletions

@@ -331,9 +331,9 @@ def stream_predict(self, inputs: dict[str, paddle.Tensor]):
             bos_token_id=self.tokenizer.bos_token_id,
             eos_token_id=get_eos_token_id(self.tokenizer, self.generation_config),
             pad_token_id=self.tokenizer.pad_token_id,
-            decode_strategy="greedy_search"
-            if self.config.top_k == 1 and self.config.top_p == 1.0
-            else self.config.decode_strategy,
+            decode_strategy=(
+                "greedy_search" if self.config.top_k == 1 and self.config.top_p == 1.0 else self.config.decode_strategy
+            ),
             temperature=self.config.temperature,
             top_k=self.config.top_k,
             top_p=self.config.top_p,

@@ -1238,7 +1238,9 @@ def predict(self, input_texts: str | list[str], return_tokens=False):
     def _preprocess(self, source):
         BlockInferencePredictorMixin._preprocess(self, source)
         for i, text in enumerate(source):
-            tokens = self.tokenizer(text, return_tensors="np", padding=False, truncation=True, max_length=(self.config.src_length))
+            tokens = self.tokenizer(
+                text, return_tensors="np", padding=False, truncation=True, max_length=(self.config.src_length)
+            )
             input_ids = tokens["input_ids"][0]
             length = len(input_ids)
             need_block_nums = (

@@ -1650,8 +1652,8 @@ def predict():
             target_texts.append("")

     else:
-        source_texts = ["解释一下“温故而知新”", "你好,请问你是谁?"]
-        target_texts = ["", ""]
+        source_texts = ["你好,请问你是谁?"] * predictor_args.batch_size
+        target_texts = [""] * predictor_args.batch_size

     batch_source_texts = batchfy_text(source_texts, predictor_args.batch_size)
     batch_target_texts = batchfy_text(target_texts, predictor_args.batch_size)
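The first hunk's behavior is easy to isolate: when top_k == 1 and top_p == 1.0, sampling is degenerate, so the predictor falls back to greedy search regardless of the configured strategy. A standalone sketch of that selection logic:

def pick_decode_strategy(top_k, top_p, configured):
    # Mirrors the conditional in stream_predict: sampling with top_k=1 and
    # top_p=1.0 always selects the argmax token, so greedy search is
    # equivalent and cheaper.
    return "greedy_search" if top_k == 1 and top_p == 1.0 else configured

print(pick_decode_strategy(1, 1.0, "sampling"))   # greedy_search
print(pick_decode_strategy(50, 0.8, "sampling"))  # sampling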

llm/run_finetune.py

Lines changed: 38 additions & 10 deletions

@@ -14,7 +14,6 @@
 import json
 import os
 import sys
-import inspect
 from functools import partial

 import paddle

@@ -42,7 +41,14 @@
     load_dataset,
 )
 from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
-from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
+from paddlenlp.peft import (
+    LoRAConfig,
+    LoRAModel,
+    PrefixConfig,
+    PrefixModelForCausalLM,
+    VeRAConfig,
+    VeRAModel,
+)
 from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint
 from paddlenlp.trainer.trainer_callback import TrainerState
 from paddlenlp.transformers import (

@@ -51,9 +57,9 @@
     AutoModelForCausalLMPipe,
     AutoTokenizer,
     Llama3Tokenizer,
-    LlamaTokenizer,
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
+    LlamaTokenizer,
 )
 from paddlenlp.transformers.configuration_utils import LlmMetaConfig
 from paddlenlp.utils.log import logger

@@ -82,7 +88,6 @@ def main():
         raise ValueError(
             "--do_train, --do_ptq, --do_gptq and --do_qat cannot work at the same time. Please choose only one at a time"
         )
-

     # Setup GPU & distributed training
     paddle.set_device(training_args.device)

@@ -168,9 +173,7 @@ def main():
         model = model_class.from_config(model_config, dtype=dtype)

     if model_args.flash_mask and (not data_args.zero_padding or not model.config.use_flash_attention):
-        logger.warning(
-            "`flash_mask` must use with zero padding and flash attention."
-        )
+        logger.warning("`flash_mask` must use with zero padding and flash attention.")
         data_args.zero_padding = True
         model.config.use_flash_attention = True

@@ -346,12 +349,16 @@ def neft_post_hook(module, input, output):
             "Zero Padding data stream is only implemented for LLaMA, Bloom, ChatGLM, QWen and Mistral so far."
         )
     train_ds = (
-        train_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
+        train_ds.map(
+            partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)
+        )
         if train_ds is not None
        else None
    )
    ptq_ds = (
-        ptq_ds.map(partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask))
+        ptq_ds.map(
+            partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)
+        )
         if ptq_ds is not None
         else None
     )

@@ -362,7 +369,14 @@ def neft_post_hook(module, input, output):
     )
     eval_zero_padding = False
     dev_ds = (
-        dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, zero_padding=eval_zero_padding, flash_mask=model_args.flash_mask))
+        dev_ds.map(
+            partial(
+                trans_func,
+                is_test=data_args.eval_with_do_generation,
+                zero_padding=eval_zero_padding,
+                flash_mask=model_args.flash_mask,
+            )
+        )
         if dev_ds is not None
         else None
     )

@@ -486,6 +500,20 @@ def compute_metrics_do_generation(eval_preds):
             "bleu4": bleu4.score(),
         }

+    if model_args.vera:
+        target_modules = get_lora_target_modules(model)
+        vera_config = VeRAConfig(
+            target_modules=target_modules,
+            r=model_args.vera_rank,
+            vera_alpha=model_args.vera_rank,
+            dtype=dtype,
+            base_model_name_or_path=model_args.model_name_or_path,
+            pissa_init=True,
+        )
+        model = VeRAModel(model, vera_config)
+        model.mark_only_vera_as_trainable(notfreezeB=True)
+        model.print_trainable_parameters()
+
     # Create trainer
     max_length = (
         data_args.max_length
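After model = VeRAModel(model, vera_config), only the VeRA parameters (and, with notfreezeB=True, the B matrices) remain trainable. A generic sketch of what print_trainable_parameters reports (an approximation, not the PEFT implementation itself):

import numpy as np
import paddle

def count_trainable(model: paddle.nn.Layer):
    # In Paddle, frozen parameters have stop_gradient=True.
    trainable = sum(int(np.prod(p.shape)) for p in model.parameters() if not p.stop_gradient)
    total = sum(int(np.prod(p.shape)) for p in model.parameters())
    print(f"trainable params: {trainable} || all params: {total} || trainable%: {100 * trainable / total:.4f}")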

llm/tools/merge_vera_params.py

Lines changed: 103 additions & 0 deletions

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle

from paddlenlp.peft import VeRAConfig, VeRAModel
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from paddlenlp.utils.env import CONFIG_NAME


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path", default=None, help="The directory of the pretrained model.")
    parser.add_argument("--vera_path", default="", help="The directory of the VeRA parameters. Defaults to None.")
    parser.add_argument(
        "--merge_vera_model_path",
        default="",
        help="The directory for the merged parameters. Defaults to None.",
    )
    parser.add_argument("--device", type=str, default="gpu", help="Device")
    parser.add_argument(
        "--low_gpu_mem", type=bool, default=True, help="Whether to use low GPU memory. Defaults to True."
    )
    return parser.parse_args()


def weight_process(name, vera_config, state_dict):
    # Fold the VeRA adapter into the base weight:
    # W' = W + vera_A @ diag(vera_d) @ vera_B @ diag(vera_b) * (vera_alpha / r)
    weight = state_dict.pop(name + ".weight").cuda()
    vera_A = state_dict.pop(name + ".vera_A").cuda()
    vera_B = state_dict.pop(name + ".vera_B").cuda()
    vera_b = state_dict.pop(name + ".vera_b").cuda()
    vera_d = state_dict.pop(name + ".vera_d").cuda()
    diag_b = paddle.diag(vera_b)
    diag_d = paddle.diag(vera_d)

    scaling = vera_config.vera_alpha / vera_config.r
    state_dict[name + ".weight"] = (weight + vera_A @ diag_d @ vera_B @ diag_b * scaling).cpu()


def merge():
    args = parse_arguments()
    paddle.set_device(args.device)

    vera_config = VeRAConfig.from_pretrained(args.vera_path)
    if vera_config.base_model_name_or_path is None:
        if args.model_name_or_path is None:
            raise ValueError("We can not find a valid model_name_or_path.")
        else:
            vera_config.base_model_name_or_path = args.model_name_or_path

    if os.path.isfile(os.path.join(args.vera_path, CONFIG_NAME)):
        config = AutoConfig.from_pretrained(args.vera_path)
    elif args.model_name_or_path is not None:
        config = AutoConfig.from_pretrained(args.model_name_or_path)
    else:
        raise ValueError(
            f"We can not find config.json in vera_path: {args.vera_path} or find a valid model_name_or_path."
        )
    config.dtype = vera_config.dtype
    if (
        vera_config.dtype == "bfloat16" or config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
    ) and args.device == "cpu":
        raise ValueError("We can not apply bfloat16 or nf4/fp4 vera merge on cpu.")

    # Using device_guard() here would make the SVD decomposition fail.
    model = AutoModelForCausalLM.from_pretrained(
        vera_config.base_model_name_or_path,
        config=config,
        low_cpu_mem_usage=True,
    )
    model = VeRAModel.from_pretrained(model=model, vera_path=args.vera_path, vera_config=vera_config)

    model.eval()
    model_state_dict = model.model.state_dict()
    vera_name_list = []
    for key in model_state_dict.keys():
        if "vera_A" in key:
            vera_name_list.append(key[:-7])  # strip the ".vera_A" suffix

    for name in vera_name_list:
        weight_process(name, vera_config, model_state_dict)

    model.model.save_pretrained(args.merge_vera_model_path, state_dict=model_state_dict)
    tokenizer = AutoTokenizer.from_pretrained(vera_config.base_model_name_or_path)
    tokenizer.save_pretrained(args.merge_vera_model_path)


if __name__ == "__main__":
    merge()
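The merge in weight_process can be verified numerically: once W' = W + A @ diag(d) @ B @ diag(b) * (vera_alpha / r) is folded in, a plain matmul against the merged weight reproduces the base output plus the adapter output. A small numpy check (shapes are illustrative):

import numpy as np

in_f, out_f, r = 16, 12, 4
W = np.random.randn(in_f, out_f)
A = np.random.randn(in_f, r)   # in VeRA, A and B are frozen and shared across layers
B = np.random.randn(r, out_f)
d = np.random.randn(r)         # trainable per-layer scaling vector
b = np.random.randn(out_f)     # trainable per-layer scaling vector
alpha = r                      # this commit sets vera_alpha = vera_rank, so scaling = 1

merged = W + A @ np.diag(d) @ B @ np.diag(b) * (alpha / r)
x = np.random.randn(2, in_f)
adapter_out = (x @ A @ np.diag(d) @ B @ np.diag(b)) * (alpha / r)
assert np.allclose(x @ merged, x @ W + adapter_out)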

llm/utils/argument.py

Lines changed: 5 additions & 3 deletions

@@ -200,6 +200,10 @@ class ModelArgument:
     lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+ technique"})
     pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})

+    # vera related parameters
+    vera: bool = field(default=False, metadata={"help": "Whether to use vera technique"})
+    vera_rank: int = field(default=8, metadata={"help": "Vera attention dimension"})
+
     # prefix tuning related parameters
     prefix_tuning: bool = field(default=False, metadata={"help": "Whether to use Prefix technique"})
     prefix_path: str = field(default=None, metadata={"help": "Initialize prefix state dict."})

@@ -213,9 +217,7 @@ class ModelArgument:
     aistudio_token: str = field(default=None, metadata={"help": "The token of aistudio"})
     neftune: bool = field(default=False, metadata={"help": "Whether to apply NEFT"})
     neftune_noise_alpha: float = field(default=5.0, metadata={"help": "NEFT noise alpha"})
-    flash_mask: bool = field(
-        default=False, metadata={"help": "Whether to use flash_mask in flash attention."}
-    )
+    flash_mask: bool = field(default=False, metadata={"help": "Whether to use flash_mask in flash attention."})


 @dataclass

paddlenlp/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -14,13 +14,17 @@

 import os
 import sys
+from datetime import datetime

 PADDLENLP_STABLE_VERSION = "PADDLENLP_STABLE_VERSION"


 __version__ = "3.0.0b0.post"
 if os.getenv(PADDLENLP_STABLE_VERSION):
     __version__ = __version__.replace(".post", "")
+else:
+    formatted_date = datetime.now().date().strftime("%Y%m%d")
+    __version__ = __version__.replace(".post", ".post{}".format(formatted_date))

 if "datasets" in sys.modules.keys():
     from paddlenlp.utils.log import logger
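The effect of the version branch above, shown concretely (the datestamp is illustrative):

from datetime import datetime

version = "3.0.0b0.post"
# Stable builds (PADDLENLP_STABLE_VERSION set): the ".post" suffix is dropped.
print(version.replace(".post", ""))  # 3.0.0b0
# Dev builds: the suffix gains a datestamp, e.g. 3.0.0b0.post20240607.
stamp = datetime.now().date().strftime("%Y%m%d")
print(version.replace(".post", ".post{}".format(stamp)))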
