 from typing import Optional

 from datasets import load_dataset
+from utils import CustomTrainer

 from paddlenlp.data import DataCollatorForSeq2Seq
 from paddlenlp.peft import LoRAConfig, LoRAModel
-from paddlenlp.trainer import PdArgumentParser, Trainer, TrainingArguments
+from paddlenlp.trainer import PdArgumentParser, TrainingArguments
 from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

 """
 Single GPU
-python train_nl2sql.py --model_name_or_path bigscience/bloomz-7b1-mt \
-    --train_file nl2sql/dev.jsonl --validation_file nl2sql/dev.jsonl \
+python benchmark.py --model_name_or_path bigscience/bloomz-7b1-mt \
     --num_train_epochs 1 --per_device_train_batch_size 4 \
-    --evaluation_strategy epoch --save_strategy epoch \
-    --fp16 --fp16_opt_level O2 \
+    --evaluation_strategy no --save_strategy no \
+    --fp16 --fp16_opt_level O2 --lora \
     --logging_steps 50 --output_dir outputs

-Multi-GPU mp
-python train_nl2sql.py --model_name_or_path bigscience/bloomz-7b1-mt \
-    --train_file nl2sql/dev.jsonl --validation_file nl2sql/dev.jsonl \
-    --num_train_epochs 1 --per_device_train_batch_size 16 \
-    --evaluation_strategy epoch --save_strategy epoch \
-    --fp16 --fp16_opt_level O2 \
+Multi-GPU mp
+python -m paddle.distributed.launch --gpus "0,1,2,3" benchmark.py --model_name_or_path bigscience/bloomz-7b1-mt \
+    --num_train_epochs 1 --per_device_train_batch_size 8 \
+    --evaluation_strategy no --save_strategy no \
+    --fp16 --fp16_opt_level O2 --tensor_parallel_degree 4 \
     --logging_steps 50 --output_dir outputs

-Multi-GPU sharding 3
-python -m paddle.distributed.launch --gpus "0,1,2,3" train_nl2sql.py --model_name_or_path bigscience/bloomz-7b1-mt \
-    --train_file nl2sql/dev.jsonl --validation_file nl2sql/dev.jsonl \
+Multi-GPU sharding 3
+python -m paddle.distributed.launch --gpus "0,1,2,3" benchmark.py --model_name_or_path bigscience/bloomz-7b1-mt \
     --num_train_epochs 1 --per_device_train_batch_size 4 \
-    --evaluation_strategy epoch --save_strategy epoch \
+    --evaluation_strategy no --save_strategy no \
     --fp16 --fp16_opt_level O2 \
     --sharding "stage3" --sharding_parallel_degree 4 \
     --logging_steps 50 --output_dir outputs
@@ -60,19 +58,9 @@ class ModelArguments:
     lora: Optional[bool] = field(default=False, metadata={"help": "whether to use LoRA"})


-@dataclass
-class DataTrainingArguments:
-    """
-    Arguments pertaining to what data we are going to input our model for training and eval.
-    """
-
-    train_file: str = field(default=None, metadata={"help": "The input training data file (a text file)."})
-    validation_file: str = field(default=None, metadata={"help": "The input evaluation data file (a text file)."})
-
-
 def main():
-    parser = PdArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
-    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+    parser = PdArgumentParser((ModelArguments, TrainingArguments))
+    model_args, training_args = parser.parse_args_into_dataclasses()

     # Set the dtype for loading model
     dtype = None
@@ -83,10 +71,13 @@ def main():
             dtype = "bfloat16"

     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+    if "llama" in model_args.model_name_or_path:
+        tokenizer.pad_token = tokenizer.unk_token
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         load_state_as_np=True,
         low_cpu_mem_usage=True,
+        # use_flash_attention=True,
         dtype=dtype,
         tensor_parallel_degree=training_args.tensor_parallel_degree,
         tensor_parallel_rank=training_args.tensor_parallel_rank,
@@ -105,9 +96,9 @@ def main():
         model.mark_only_lora_as_trainable()
         model.print_trainable_parameters()

-    def preprocess_function(example, max_src_length=512, max_tgt_length=256):
-        inputs = example["src"][0]
-        targets = example["tgt"][0]
+    def preprocess_function(example, max_src_length=512, max_tgt_length=512):
+        inputs = example["instruction"]
+        targets = example["output"]
         model_inputs = tokenizer(inputs, max_length=max_src_length, truncation=True, return_attention_mask=False)
         labels = tokenizer(targets, max_length=max_tgt_length, truncation=True, return_attention_mask=False)
         labels_input_ids = labels["input_ids"] + [tokenizer.eos_token_id]
@@ -116,17 +107,25 @@ def preprocess_function(example, max_src_length=512, max_tgt_length=256):

         return model_inputs

-    dataset = load_dataset("json", data_files={"train": data_args.train_file, "dev": data_args.validation_file})
-    dataset = dataset.map(lambda example: preprocess_function(example))
+    dataset = load_dataset("Chinese-Vicuna/guanaco_belle_merge_v1.0")
+    # select first 10k examples for benchmarking
+    dataset = dataset["train"].select(range(10000))
+    dataset = dataset.map(
+        lambda example: preprocess_function(example), remove_columns=["instruction", "input", "output"]
+    )
+    total_effective_tokens = sum([len(i["input_ids"]) for i in dataset]) * training_args.num_train_epochs

-    trainer = Trainer(
+    trainer = CustomTrainer(
         model=model,
-        train_dataset=dataset["train"],
-        eval_dataset=dataset["dev"],
+        train_dataset=dataset,
         args=training_args,
         data_collator=DataCollatorForSeq2Seq(return_tensors="pd", tokenizer=tokenizer),
     )
-    trainer.train()
+    train_metrics = trainer.train()
+    tokens_per_second = trainer.total_observed_tokens / train_metrics.metrics["train_runtime"]
+    effective_tokens_per_second = total_effective_tokens / train_metrics.metrics["train_runtime"]
+    print(f"Tokens per second: {tokens_per_second:.2f}")
+    print(f"Effective Tokens per second: {effective_tokens_per_second:.2f}")


 if __name__ == "__main__":
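The script imports CustomTrainer from utils and later reads trainer.total_observed_tokens, but utils.py is not part of this diff. A minimal sketch of what such a subclass could look like, assuming it only accumulates the padded token count of each training batch (the class body below is an illustration, not the actual utils.py):

from paddlenlp.trainer import Trainer


class CustomTrainer(Trainer):
    # Hypothetical helper: counts every (padded) token fed to the model so the
    # benchmark can report raw tokens per second after training.
    total_observed_tokens = 0.0

    def training_step(self, model, inputs):
        # input_ids arrives from DataCollatorForSeq2Seq as a [batch_size, seq_len]
        # tensor, so batch_size * seq_len is the number of tokens seen this step.
        input_ids = inputs["input_ids"]
        self.total_observed_tokens += float(input_ids.shape[0] * input_ids.shape[1])
        return super().training_step(model, inputs)

Counted this way, "Tokens per second" includes padding added by the collator, while "Effective Tokens per second" divides only the real dataset tokens (total_effective_tokens, summed before training) by the same train_runtime.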