|
22 | 22 | from paddlenlp.data import DataCollatorForSeq2Seq |
23 | 23 | from paddlenlp.peft import LoRAConfig, LoRAModel |
24 | 24 | from paddlenlp.trainer import PdArgumentParser, TrainingArguments |
25 | | -from paddlenlp.transformers import ( |
26 | | - AutoModelForCausalLM, |
27 | | - AutoTokenizer, |
28 | | - ChatGLMForConditionalGeneration, |
29 | | -) |
| 25 | +from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer |
30 | 26 |
|
31 | 27 | """ |
32 | 28 | Single GPU |
@@ -83,29 +79,16 @@ def main(): |
83 | 79 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) |
84 | 80 | if "llama" in model_args.model_name_or_path: |
85 | 81 | tokenizer.pad_token = tokenizer.unk_token |
86 | | - if "chatglm" in model_args.model_name_or_path: |
87 | | - model = ChatGLMForConditionalGeneration.from_pretrained( |
88 | | - model_args.model_name_or_path, |
89 | | - load_state_as_np=True, |
90 | | - low_cpu_mem_usage=True, |
91 | | - # use_flash_attention=True, |
92 | | - dtype=dtype, |
93 | | - tensor_parallel_degree=training_args.tensor_parallel_degree, |
94 | | - tensor_parallel_rank=training_args.tensor_parallel_rank, |
95 | | - recompute=training_args.recompute, |
96 | | - ) |
97 | | - |
98 | | - else: |
99 | | - model = AutoModelForCausalLM.from_pretrained( |
100 | | - model_args.model_name_or_path, |
101 | | - load_state_as_np=True, |
102 | | - low_cpu_mem_usage=True, |
103 | | - # use_flash_attention=True, |
104 | | - dtype=dtype, |
105 | | - tensor_parallel_degree=training_args.tensor_parallel_degree, |
106 | | - tensor_parallel_rank=training_args.tensor_parallel_rank, |
107 | | - use_recompute=training_args.recompute, |
108 | | - ) |
| 82 | + model = AutoModelForCausalLM.from_pretrained( |
| 83 | + model_args.model_name_or_path, |
| 84 | + load_state_as_np=True, |
| 85 | + low_cpu_mem_usage=True, |
| 86 | + # use_flash_attention=True, |
| 87 | + dtype=dtype, |
| 88 | + tensor_parallel_degree=training_args.tensor_parallel_degree, |
| 89 | + tensor_parallel_rank=training_args.tensor_parallel_rank, |
| 90 | + use_recompute=training_args.recompute, |
| 91 | + ) |
109 | 92 |
|
110 | 93 | if model_args.lora: |
111 | 94 | if "llama" in model_args.model_name_or_path: |
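
A minimal sketch of the consolidated loading path this change relies on: `AutoModelForCausalLM.from_pretrained` resolves the matching architecture from the checkpoint (ChatGLM included), so the dedicated `ChatGLMForConditionalGeneration` branch removed above becomes redundant. The checkpoint name and `dtype` below are placeholder assumptions, not values taken from this script; the keyword arguments mirror the ones used in the diff.

```python
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical checkpoint name; any causal-LM checkpoint that AutoModelForCausalLM
# can resolve (including ChatGLM) now goes through this single code path.
model_name_or_path = "THUDM/chatglm-6b"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    load_state_as_np=True,   # load weights as numpy arrays to reduce peak host memory
    low_cpu_mem_usage=True,
    dtype="float16",         # assumed half precision; the script derives `dtype` from its args
)
```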