Commit 90cbb21

add belle finetune (#5836)
1 parent 8561751 commit 90cbb21

File tree (5 files changed: +119 −10 lines)

- examples/language_model/chatglm/README.md
- examples/language_model/chatglm/data.py
- examples/language_model/chatglm/finetune_generation.py
- examples/language_model/chatglm/utils.py
- paddlenlp/layers/lora.py

examples/language_model/chatglm/README.md

Lines changed: 27 additions & 2 deletions

@@ -20,7 +20,7 @@ python -m pip install paddlepaddle-gpu==0.0.0.post112 -f https://www.paddlepaddl
 ```
 python -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py \
 --model_name_or_path THUDM/chatglm-6b \
---task_path AdvertiseGen/ \
+--task_name_or_path AdvertiseGen/ \
 --max_steps 3000 \
 --learning_rate 3e-5 \
 --warmup_steps 20 \
@@ -45,7 +45,7 @@ python -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py \
 The parameters are described as follows:

 - `model_name_or_path`: Built-in name of the pretrained model, or the directory containing the model; defaults to `THUDM/chatglm-6b`.
-- `task_path`: Directory where the dataset is stored.
+- `task_name_or_path`: Directory where the dataset is stored.
 - `max_steps`: Number of training steps.
 - `learning_rate`: Learning rate for parameter updates.
 - `warmup_steps`: Number of learning-rate warmup steps.
@@ -64,6 +64,31 @@ python -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py \
 - `do_eval`: Whether to evaluate the model.
 - `tensor_parallel_degree`: Degree of model parallelism.

+## BelleGroup/school_math_0.25M
+
+```
+python -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py \
+--output_dir ./checkpoints/chatglm-6b \
+--per_device_train_batch_size 4 \
+--gradient_accumulation_steps 8 \
+--save_steps 500 \
+--model_name_or_path THUDM/chatglm-6b \
+--task_name_or_path school_math_0.25M \
+--num_train_epochs 2 \
+--learning_rate 3e-5 \
+--warmup_ratio 0.03 \
+--logging_steps 1 \
+--evaluation_strategy no \
+--src_length 128 \
+--tgt_length 512 \
+--fp16 \
+--fp16_opt_level O2 \
+--recompute True \
+--do_train \
+--disable_tqdm True \
+--metric_for_best_model ppl \
+--greater_is_better False
+```

 ## Model Prediction
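
For orientation (not part of the commit): the new converter in data.py below expects instruction-tuning records with `instruction` and `output` fields, plus an optional `input`, and folds them into a "Human: ... Assistant:" prompt for non-"chat" tasks. A minimal sketch with a hypothetical record (the actual school_math_0.25M entries are Chinese math word problems; an English placeholder is used here):

```python
# Illustration only, not from the commit. The record below is a hypothetical
# stand-in for a BelleGroup/school_math_0.25M entry; only the
# instruction/input/output keys are assumed, matching the checks in data.py.
example = {
    "instruction": "A student has 3 apples and buys 5 more. How many apples are there now?",
    "input": "",
    "output": "3 + 5 = 8, so the student now has 8 apples.",
}

instruction = example["instruction"]
extra = example.get("input", "")
response = example["output"]

# Mirrors the non-"chat" branch of custom_instruction_convert_example:
prompt = "Human: " + instruction + extra + "\n Assistant: "
print(prompt)
print(response)
```

The prompt (truncated to `src_length`) becomes the bidirectionally attended prefix, and `output` (truncated to `tgt_length`) supplies the labels, as the data.py diff below shows.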

examples/language_model/chatglm/data.py

Lines changed: 62 additions & 0 deletions

@@ -79,3 +79,65 @@ def convert_example(example, tokenizer, data_args, is_test=True):
         "labels": labels,
     }
     return inputs
+
+
+def custom_instruction_convert_example(example, tokenizer, data_args, is_test=True):
+    instruction = ""
+    input = ""
+    response = ""
+    if "instruction" in example and "output" in example:
+        instruction = example["instruction"]
+        response = example["output"]
+    else:
+        assert False, "instruction and output are not in the input dictionary."
+    if "input" in example:
+        input = example["input"]
+
+    if "chat" in data_args.task_name_or_path:
+        prompt = instruction + input
+    else:
+        prompt = "Human: " + instruction + input + "\n Assistant: "
+
+    # dataset for evaluation
+    if is_test:
+        inputs = {
+            **tokenizer(prompt, max_length=data_args.src_length, truncation=True, padding="max_length"),
+            "labels": tokenizer(response, max_length=data_args.tgt_length, truncation=True, padding="max_length")[
+                "input_ids"
+            ],
+        }
+    # dataset for training
+    else:
+        src_ids = tokenizer(
+            prompt,
+            add_special_tokens=False,
+            max_length=data_args.src_length - 1,
+            truncation=True,
+            truncation_side="left",
+        )["input_ids"]
+        tgt_ids = tokenizer(
+            response,
+            add_special_tokens=False,
+            max_length=data_args.tgt_length - 2,
+            truncation=True,
+            truncation_side="right",
+        )["input_ids"]
+
+        input_ids = tokenizer.build_inputs_with_special_tokens(src_ids, tgt_ids)
+
+        context_length = input_ids.index(tokenizer.bos_token_id)
+        mask_position = context_length - 1
+
+        attention_mask = np.tri(len(input_ids), len(input_ids))
+        attention_mask[:, :context_length] = 1
+        attention_mask = attention_mask[None, :, :]
+        attention_mask = (attention_mask < 0.5).astype("int64")
+
+        labels = [-100] * context_length + input_ids[mask_position + 1 :]
+
+        inputs = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+        }
+    return inputs
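
The attention-mask and label construction above is the least obvious part of the new converter, so here is a small self-contained sketch (numpy only, with made-up token ids standing in for real ChatGLMTokenizer output) showing the effect: prompt positions are visible to every token, response positions attend causally, and only the response contributes to the loss.

```python
import numpy as np

# Hypothetical token ids; in data.py these come from ChatGLMTokenizer and
# build_inputs_with_special_tokens, with a bos token marking the start of
# the response segment.
BOS = 100
input_ids = [11, 12, 13, BOS, 21, 22]  # 3 prompt tokens, then bos + response
bos_token_id = BOS

context_length = input_ids.index(bos_token_id)  # 3
mask_position = context_length - 1              # 2

# Same construction as custom_instruction_convert_example: start from a causal
# (lower-triangular) mask, then make the prompt columns visible everywhere.
attention_mask = np.tri(len(input_ids), len(input_ids))
attention_mask[:, :context_length] = 1
attention_mask = attention_mask[None, :, :]
attention_mask = (attention_mask < 0.5).astype("int64")  # after inversion, 1 = blocked

# Prompt positions get -100 (ignored by the loss); the rest copies input_ids
# from the bos position onward.
labels = [-100] * context_length + input_ids[mask_position + 1 :]

print(attention_mask[0])
print(labels)  # [-100, -100, -100, 100, 21, 22]
```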

examples/language_model/chatglm/finetune_generation.py

Lines changed: 18 additions & 7 deletions

@@ -17,7 +17,7 @@
 from functools import partial

 import paddle
-from data import convert_example, read_local_dataset
+from data import convert_example, custom_instruction_convert_example, read_local_dataset
 from utils import ChatGLMTrainer

 from paddlenlp.data import DataCollatorWithPadding
@@ -36,7 +36,7 @@

 @dataclass
 class DataArgument:
-    task_path: str = field(default="./data/", metadata={"help": "Path to data"})
+    task_name_or_path: str = field(default="./data/", metadata={"help": "Path to data"})
     src_length: int = field(default=128, metadata={"help": "The max length of source text."})
     tgt_length: int = field(default=180, metadata={"help": "The max length of target text."})
     num_beams: int = field(default=5, metadata={"help": "The number of beams."})
@@ -113,8 +113,8 @@ def main():
     if model_args.lora:
         lora_config = LoRAConfig(
             target_modules=[".*query_key_value.*"],
-            r=4,
-            lora_alpha=8,
+            r=8,
+            lora_alpha=16,
             merge_weights=True,
             enable_lora_list=[[True, False, True]],
             tensor_parallel_degree=training_args.tensor_parallel_degree,
@@ -126,9 +126,20 @@ def main():
     tokenizer = ChatGLMTokenizer.from_pretrained(model_args.model_name_or_path)

     # Load the dataset.
-    train_ds = load_dataset(read_local_dataset, path=os.path.join(data_args.task_path, "train.json"), lazy=False)
-    dev_ds = load_dataset(read_local_dataset, path=os.path.join(data_args.task_path, "dev.json"), lazy=False)
-    trans_func = partial(convert_example, tokenizer=tokenizer, data_args=data_args)
+    if os.path.exists(os.path.join(data_args.task_name_or_path, "train.json")) and os.path.exists(
+        os.path.join(data_args.task_name_or_path, "dev.json")
+    ):
+        train_ds = load_dataset(
+            read_local_dataset, path=os.path.join(data_args.task_name_or_path, "train.json"), lazy=False
+        )
+        dev_ds = load_dataset(
+            read_local_dataset, path=os.path.join(data_args.task_name_or_path, "dev.json"), lazy=False
+        )
+        trans_func = partial(convert_example, tokenizer=tokenizer, data_args=data_args)
+    else:
+        train_ds, dev_ds = load_dataset("bellegroup", data_args.task_name_or_path, splits=["train", "dev"])
+        trans_func = partial(custom_instruction_convert_example, tokenizer=tokenizer, data_args=data_args)
+
     train_ds = train_ds.map(partial(trans_func, is_test=False))
     test_ds = dev_ds.map(trans_func)
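
A side note on the LoRA change above (r 4→8, lora_alpha 8→16): trainable adapter parameters scale linearly with r, while the update scale lora_alpha / r stays at 2, so the change buys capacity without altering the effective scaling. A rough sketch under assumed dimensions (the 4096 → 12288 `query_key_value` shape below is an assumption for illustration; the exact PaddleNLP count also depends on `enable_lora_list`, which here skips the key slice):

```python
# Back-of-the-envelope sketch, not from the commit. Standard LoRA adds A (d_in x r)
# and B (r x d_out) per targeted weight; the dimensions are illustrative
# assumptions, not read from the ChatGLM-6B config.
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    return r * (d_in + d_out)

for r, alpha in [(4, 8), (8, 16)]:
    params = lora_param_count(4096, 3 * 4096, r)
    print(f"r={r:<2} lora_alpha={alpha:<2} -> ~{params:,} params per targeted weight, scaling alpha/r = {alpha / r}")
```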

examples/language_model/chatglm/utils.py

Lines changed: 11 additions & 0 deletions

@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Dict
+
+import numpy as np
 import paddle

 from paddlenlp.trainer import Trainer
@@ -92,3 +95,11 @@ def prediction_step(
             all_labels = None

         return (loss, all_preds, all_labels)
+
+    def log(self, logs: Dict[str, float], **kwargs) -> None:
+        if "loss" in logs:
+            logs["ppl"] = np.exp(logs["loss"])
+        if "eval_loss" in logs:
+            logs["eval_ppl"] = np.exp(logs["eval_loss"])
+
+        super(ChatGLMTrainer, self).log(logs, **kwargs)
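
The `log` override above only adds derived metrics: perplexity is the exponential of the mean cross-entropy loss, computed just before delegating to the base `Trainer.log`. It is also what makes `--metric_for_best_model ppl --greater_is_better False` in the README command meaningful. A quick sketch with hypothetical loss values:

```python
import numpy as np

# Hypothetical log payloads, mirroring the added ChatGLMTrainer.log override.
for logs in ({"loss": 2.0, "learning_rate": 3e-05}, {"eval_loss": 1.5}):
    if "loss" in logs:
        logs["ppl"] = np.exp(logs["loss"])
    if "eval_loss" in logs:
        logs["eval_ppl"] = np.exp(logs["eval_loss"])
    print(logs)
# loss 2.0      -> ppl ≈ 7.39
# eval_loss 1.5 -> eval_ppl ≈ 4.48
```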

paddlenlp/layers/lora.py

Lines changed: 1 addition & 1 deletion

@@ -735,7 +735,7 @@ def _convert_tensor_parallel(self, lora_state_dict):
             lora_state_dict[name] = action(tensor)
         return lora_state_dict

-    def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False):
+    def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
         assert not os.path.isfile(
             save_directory
         ), f"Saving directory ({save_directory}) should be a directory, not a file"