1717from functools import partial
1818
1919import paddle
20- from data import convert_example , read_local_dataset
20+ from data import convert_example , custom_instruction_convert_example , read_local_dataset
2121from utils import ChatGLMTrainer
2222
2323from paddlenlp .data import DataCollatorWithPadding
3636
3737@dataclass
3838class DataArgument :
39- task_path : str = field (default = "./data/" , metadata = {"help" : "Path to data" })
39+ task_name_or_path : str = field (default = "./data/" , metadata = {"help" : "Path to data" })
4040 src_length : int = field (default = 128 , metadata = {"help" : "The max length of source text." })
4141 tgt_length : int = field (default = 180 , metadata = {"help" : "The max length of target text." })
4242 num_beams : int = field (default = 5 , metadata = {"help" : "The number of beams." })
@@ -113,8 +113,8 @@ def main():
113113 if model_args .lora :
114114 lora_config = LoRAConfig (
115115 target_modules = [".*query_key_value.*" ],
116- r = 4 ,
117- lora_alpha = 8 ,
116+ r = 8 ,
117+ lora_alpha = 16 ,
118118 merge_weights = True ,
119119 enable_lora_list = [[True , False , True ]],
120120 tensor_parallel_degree = training_args .tensor_parallel_degree ,
@@ -126,9 +126,20 @@ def main():
126126 tokenizer = ChatGLMTokenizer .from_pretrained (model_args .model_name_or_path )
127127
128128 # Load the dataset.
129- train_ds = load_dataset (read_local_dataset , path = os .path .join (data_args .task_path , "train.json" ), lazy = False )
130- dev_ds = load_dataset (read_local_dataset , path = os .path .join (data_args .task_path , "dev.json" ), lazy = False )
131- trans_func = partial (convert_example , tokenizer = tokenizer , data_args = data_args )
129+ if os .path .exists (os .path .join (data_args .task_name_or_path , "train.json" )) and os .path .exists (
130+ os .path .join (data_args .task_name_or_path , "dev.json" )
131+ ):
132+ train_ds = load_dataset (
133+ read_local_dataset , path = os .path .join (data_args .task_name_or_path , "train.json" ), lazy = False
134+ )
135+ dev_ds = load_dataset (
136+ read_local_dataset , path = os .path .join (data_args .task_name_or_path , "dev.json" ), lazy = False
137+ )
138+ trans_func = partial (convert_example , tokenizer = tokenizer , data_args = data_args )
139+ else :
140+ train_ds , dev_ds = load_dataset ("bellegroup" , data_args .task_name_or_path , splits = ["train" , "dev" ])
141+ trans_func = partial (custom_instruction_convert_example , tokenizer = tokenizer , data_args = data_args )
142+
132143 train_ds = train_ds .map (partial (trans_func , is_test = False ))
133144 test_ds = dev_ds .map (trans_func )
134145
0 commit comments