@@ -137,6 +137,43 @@ def main():
         weight_double_quant_block_size=model_args.weight_double_quant_block_size,
     )

+    model_config = AutoConfig.from_pretrained(
+        model_args.model_name_or_path,
+        tensor_parallel_output=training_args.tensor_parallel_output,
+        tensor_parallel_degree=training_args.tensor_parallel_degree,
+        tensor_parallel_rank=training_args.tensor_parallel_rank,
+        dtype=dtype,
+        from_aistudio=model_args.from_aistudio,
+        quantization_config=quantization_config,
+    )
+    if hasattr(model_config, "use_flash_attention"):
+        model_config.use_flash_attention = model_args.use_flash_attention
+
+    model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
+    model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
+    model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
+    model_config.recompute_granularity = model_args.recompute_granularity
+    model_config.virtual_pp_degree = model_args.virtual_pp_degree
+    model_config.sequence_parallel = model_args.sequence_parallel
+    model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
+    model_config.use_fused_rope = model_args.use_fused_rope
+
+    model_config.no_recompute_layers = model_args.no_recompute_layers
+    model_config.pp_recompute_interval = model_args.pp_recompute_interval
+    model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
+    model_config.use_recompute = training_args.recompute
+
+    model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
+    model_config.tensor_parallel_rank = training_args.tensor_parallel_rank
+
+    # Config for model using dropout, such as GPT.
+    model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
+    model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
+
+    model_config.sep_parallel_degree = training_args.sep_parallel_degree
+    model_config.tensor_parallel_output = True
+    model_config.seq_length = data_args.max_length
+
     if training_args.pipeline_parallel_degree > 1:
         if data_args.eval_with_do_generation and training_args.do_eval:
             raise ValueError("Plese set eval_with_do_generation to false in pipeline parallel mode.")
@@ -145,63 +182,13 @@ def main():
         if not training_args.autotuner_benchmark:
             model = AutoModelForCausalLMPipe.from_pretrained(
                 model_args.model_name_or_path,
-                tensor_parallel_output=training_args.tensor_parallel_output,
-                tensor_parallel_degree=training_args.tensor_parallel_degree,
-                tensor_parallel_rank=training_args.tensor_parallel_rank,
-                use_flash_attention=model_args.use_flash_attention,
-                dtype=dtype,
+                config=model_config,
                 from_aistudio=model_args.from_aistudio,
-                quantization_config=quantization_config,
             )
         else:
             # NOTE(gongenlei): new add autotuner_benchmark
-            model_config = AutoConfig.from_pretrained(
-                model_args.model_name_or_path,
-                tensor_parallel_output=training_args.tensor_parallel_output,
-                tensor_parallel_degree=training_args.tensor_parallel_degree,
-                tensor_parallel_rank=training_args.tensor_parallel_rank,
-                dtype=dtype,
-                from_aistudio=model_args.from_aistudio,
-                quantization_config=quantization_config,
-            )
             model = AutoModelForCausalLMPipe.from_config(model_config, dtype=dtype)
     else:
-        model_config = AutoConfig.from_pretrained(
-            model_args.model_name_or_path,
-            tensor_parallel_output=training_args.tensor_parallel_output,
-            tensor_parallel_degree=training_args.tensor_parallel_degree,
-            tensor_parallel_rank=training_args.tensor_parallel_rank,
-            dtype=dtype,
-            from_aistudio=model_args.from_aistudio,
-            quantization_config=quantization_config,
-        )
-        if hasattr(model_config, "use_flash_attention"):
-            model_config.use_flash_attention = model_args.use_flash_attention
-
-        model_config.use_fused_rms_norm = model_args.use_fused_rms_norm
-        model_config.fuse_attention_qkv = model_args.fuse_attention_qkv
-        model_config.fuse_attention_ffn = model_args.fuse_attention_ffn
-        model_config.recompute_granularity = model_args.recompute_granularity
-        model_config.virtual_pp_degree = model_args.virtual_pp_degree
-        model_config.sequence_parallel = model_args.sequence_parallel
-        model_config.fuse_sequence_parallel_allreduce = model_args.fuse_sequence_parallel_allreduce
-        model_config.use_fused_rope = model_args.use_fused_rope
-
-        model_config.no_recompute_layers = model_args.no_recompute_layers
-        model_config.pp_recompute_interval = model_args.pp_recompute_interval
-        model_config.recompute_use_reentrant = model_args.recompute_use_reentrant
-        model_config.use_recompute = training_args.recompute
-
-        model_config.tensor_parallel_degree = training_args.tensor_parallel_degree
-        model_config.tensor_parallel_rank = training_args.tensor_parallel_rank
-
-        # Config for model using dropout, such as GPT.
-        model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
-        model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
-
-        model_config.sep_parallel_degree = training_args.sep_parallel_degree
-        model_config.tensor_parallel_output = True
-        model_config.seq_length = data_args.max_length
         if not training_args.autotuner_benchmark:
             model = AutoModelForCausalLM.from_pretrained(
                 model_args.model_name_or_path,
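The net effect of this change is that the duplicated AutoConfig.from_pretrained call sites are collapsed into a single model_config, built once before the pipeline-parallel branch, patched with the CLI arguments, and then handed to every model constructor through config= (or consumed directly by from_config on the autotuner-benchmark path). The sketch below condenses that pattern; the helper names build_model_config and build_model are illustrative, the list of field overrides is abridged, and the from_config call for the non-pipeline benchmark branch is assumed, since that branch is truncated in this hunk.

# Sketch only: condenses the pattern introduced by this diff. The helper names
# and the trimmed override list are illustrative, not part of the actual script.
from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoModelForCausalLMPipe


def build_model_config(model_args, training_args, data_args, dtype, quantization_config):
    # One AutoConfig call replaces the previously duplicated call sites.
    model_config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        tensor_parallel_output=training_args.tensor_parallel_output,
        tensor_parallel_degree=training_args.tensor_parallel_degree,
        tensor_parallel_rank=training_args.tensor_parallel_rank,
        dtype=dtype,
        from_aistudio=model_args.from_aistudio,
        quantization_config=quantization_config,
    )
    # Patch the loaded config with the CLI overrides (abridged; the diff sets many more).
    if hasattr(model_config, "use_flash_attention"):
        model_config.use_flash_attention = model_args.use_flash_attention
    model_config.use_recompute = training_args.recompute
    model_config.seq_length = data_args.max_length
    model_config.tensor_parallel_output = True
    return model_config


def build_model(model_args, training_args, model_config, dtype):
    # Both branches now consume the same pre-built config.
    model_cls = (
        AutoModelForCausalLMPipe
        if training_args.pipeline_parallel_degree > 1
        else AutoModelForCausalLM
    )
    if training_args.autotuner_benchmark:
        # Benchmark path: build the model from the config alone (the non-pipeline
        # variant of this call is assumed; it is outside the visible hunk).
        return model_cls.from_config(model_config, dtype=dtype)
    return model_cls.from_pretrained(
        model_args.model_name_or_path,
        config=model_config,
        from_aistudio=model_args.from_aistudio,
    )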