File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -125,11 +125,11 @@ def main():
125125 recompute_granularity = model_args .recompute_granularity ,
126126 use_flash_attention = model_args .use_flash_attention ,
127127 tensor_parallel_output = model_args .tensor_parallel_output ,
128- dpo_config = dpo_config ,
129128 )
130129
131130 if training_args .pipeline_parallel_degree > 1 :
132131 model_class = AutoModelForCausalLMPipe
132+ model_kwargs ["dpo_config" ] = dpo_config
133133 else :
134134 model_class = AutoModelForCausalLM
135135 if not training_args .autotuner_benchmark or model_args .weight_quantize_algo is not None :
@@ -148,7 +148,8 @@ def main():
148148 ref_model = model_class .from_config (config , dtype = dtype )
149149 else :
150150 ref_model = None
151- model .config .dpo_config = None
151+ if training_args .pipeline_parallel_degree > 1 :
152+ model .config .dpo_config = None
152153 if model_args .flash_mask and not model .config .use_flash_attention :
153154 logger .warning ("`flash_mask` must use with zero padding and flash attention." )
154155 model .config .use_flash_attention = True
You can’t perform that action at this time.
0 commit comments