Commit 94c4866

add dpo pp
1 parent cef3165 · commit 94c4866

File tree

1 file changed: +3 −2 lines changed


llm/alignment/dpo/run_dpo.py

Lines changed: 3 additions & 2 deletions
@@ -125,11 +125,11 @@ def main():
         recompute_granularity=model_args.recompute_granularity,
         use_flash_attention=model_args.use_flash_attention,
         tensor_parallel_output=model_args.tensor_parallel_output,
-        dpo_config=dpo_config,
     )

     if training_args.pipeline_parallel_degree > 1:
         model_class = AutoModelForCausalLMPipe
+        model_kwargs["dpo_config"] = dpo_config
     else:
         model_class = AutoModelForCausalLM
     if not training_args.autotuner_benchmark or model_args.weight_quantize_algo is not None:
@@ -148,7 +148,8 @@ def main():
         ref_model = model_class.from_config(config, dtype=dtype)
     else:
         ref_model = None
-    model.config.dpo_config = None
+    if training_args.pipeline_parallel_degree > 1:
+        model.config.dpo_config = None
     if model_args.flash_mask and not model.config.use_flash_attention:
         logger.warning("`flash_mask` must use with zero padding and flash attention.")
         model.config.use_flash_attention = True
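Taken together, the two hunks move dpo_config out of the shared model-construction kwargs: it is now injected into model_kwargs only when training_args.pipeline_parallel_degree > 1 (the AutoModelForCausalLMPipe path), and model.config.dpo_config is reset to None only in that same case instead of unconditionally. Below is a minimal sketch of the resulting control flow; the stub classes stand in for PaddleNLP's AutoModelForCausalLM / AutoModelForCausalLMPipe, and build_policy_and_ref is a hypothetical helper — only the branch structure comes from the diff.

class _Config:
    def __init__(self, dpo_config=None):
        self.dpo_config = dpo_config


class CausalLM:  # stand-in for AutoModelForCausalLM
    def __init__(self, **kwargs):
        self.config = _Config(kwargs.get("dpo_config"))

    @classmethod
    def from_config(cls, config):
        inst = cls()
        inst.config = config
        return inst


class CausalLMPipe(CausalLM):  # stand-in for AutoModelForCausalLMPipe
    pass


def build_policy_and_ref(pipeline_parallel_degree, dpo_config):
    model_kwargs = {}  # shared kwargs (dtype, flash attention, ...) elided
    if pipeline_parallel_degree > 1:
        model_class = CausalLMPipe
        # New in this commit: only the pipeline model receives dpo_config.
        model_kwargs["dpo_config"] = dpo_config
    else:
        model_class = CausalLM
    model = model_class(**model_kwargs)

    # The reference model is built from the same config, as in run_dpo.py.
    ref_model = model_class.from_config(model.config)

    if pipeline_parallel_degree > 1:
        # Also new: dpo_config is cleared only on the pipeline path,
        # where it was needed just for construction.
        model.config.dpo_config = None
    return model, ref_model


policy, ref = build_policy_and_ref(2, {"beta": 0.1})  # "beta" is illustrative
print(policy.config.dpo_config)  # None: cleared after construction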
