Commit 79695cc

[llm]support dpo pp for qwen & llama (#9695)

* support dpo pp
* add dpo pp
* add dpo pp

1 parent: 691ae01

11 files changed: +312 −11 lines

llm/alignment/dpo/run_dpo.py

Lines changed: 20 additions & 6 deletions
```diff
@@ -42,6 +42,7 @@
     LlamaForCausalLM,
     LlamaForCausalLMPipe,
     Qwen2ForCausalLM,
+    Qwen2ForCausalLMPipe,
     register_sequence_parallel_allreduce_hooks,
 )
 from paddlenlp.trl import (
@@ -53,7 +54,7 @@
 from paddlenlp.trl.llm_utils import get_lora_target_modules
 from paddlenlp.utils.log import logger
 
-flash_mask_support_list = [Qwen2ForCausalLM, LlamaForCausalLM, LlamaForCausalLMPipe]
+flash_mask_support_list = [Qwen2ForCausalLM, Qwen2ForCausalLMPipe, LlamaForCausalLM, LlamaForCausalLMPipe]
 
 
 def main():
@@ -74,7 +75,20 @@ def main():
     if dpo_config.loss_type in ["or", "simpo"] and not dpo_config.reference_free:
         dpo_config.reference_free = True
         logger.warning(f"{dpo_config.loss_type} loss_type only supports reference_free. Set reference_free to True.")
-
+    if training_args.pipeline_parallel_degree > 1:
+        assert (
+            hasattr(training_args, "pipeline_parallel_config")
+            and "enable_clear_every_step_cache" in training_args.pipeline_parallel_config
+        ), "Should set '--pipeline_parallel_config enable_clear_every_step_cache' in bash script for pp."
+    if model_args.sequence_parallel:
+        if training_args.pipeline_parallel_degree > 1:
+            assert (
+                hasattr(training_args, "pipeline_parallel_config")
+                and "disable_partial_send_recv" in training_args.pipeline_parallel_config
+            ), "Should set '--pipeline_parallel_config disable_partial_send_recv' in bash script for pp with sp."
+        if training_args.tensor_parallel_degree <= 1:
+            model_args.sequence_parallel = False
+            logger.info("Tensor_parallel_degree = 1. Set sequence_parallel to False.")
     training_args.print_config(model_args, "Model")
     training_args.print_config(data_args, "Data")
     training_args.print_config(dpo_config, "DPOConfig")
@@ -112,16 +126,15 @@ def main():
         use_flash_attention=model_args.use_flash_attention,
         tensor_parallel_output=model_args.tensor_parallel_output,
     )
-    if training_args.pipeline_parallel_degree > 1:
-        raise ValueError("DPO does not support pipeline parallelism yet.")
+
     if training_args.pipeline_parallel_degree > 1:
         model_class = AutoModelForCausalLMPipe
+        model_kwargs["dpo_config"] = dpo_config
     else:
         model_class = AutoModelForCausalLM
     if not training_args.autotuner_benchmark or model_args.weight_quantize_algo is not None:
         model = model_class.from_pretrained(**model_kwargs)
         # for DPO save
-        model.config.dpo_config = None
     if not dpo_config.reference_free and not dpo_config.lora:
         config = AutoConfig.from_pretrained(**model_kwargs)
         ref_model = model_class.from_config(config, dtype=dtype)
@@ -135,7 +148,8 @@ def main():
         ref_model = model_class.from_config(config, dtype=dtype)
     else:
         ref_model = None
-
+    if training_args.pipeline_parallel_degree > 1:
+        model.config.dpo_config = None
     if model_args.flash_mask and not model.config.use_flash_attention:
         logger.warning("`flash_mask` must use with zero padding and flash attention.")
         model.config.use_flash_attention = True
```
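The new guards fail fast when the launch script is missing the pipeline-parallel flags that pipeline-parallel DPO needs. Below is a minimal standalone sketch of the same checks; `Args` is a stand-in for the parsed PaddleNLP training/model arguments, not the real classes:

```python
from dataclasses import dataclass


@dataclass
class Args:
    pipeline_parallel_degree: int = 1
    tensor_parallel_degree: int = 1
    sequence_parallel: bool = False
    pipeline_parallel_config: str = ""


def check_pp_flags(args: Args) -> Args:
    if args.pipeline_parallel_degree > 1:
        # pp runs must clear the schedule cache every step
        assert "enable_clear_every_step_cache" in args.pipeline_parallel_config, (
            "Should set '--pipeline_parallel_config enable_clear_every_step_cache' in bash script for pp."
        )
    if args.sequence_parallel:
        if args.pipeline_parallel_degree > 1:
            # pp combined with sp additionally needs full (non-partial) send/recv
            assert "disable_partial_send_recv" in args.pipeline_parallel_config, (
                "Should set '--pipeline_parallel_config disable_partial_send_recv' in bash script for pp with sp."
            )
        if args.tensor_parallel_degree <= 1:
            # sequence parallel is a no-op without tensor parallelism
            args.sequence_parallel = False
    return args


# e.g. a pp=2, tp=2, sp run must pass both flags to get through the asserts:
check_pp_flags(
    Args(
        pipeline_parallel_degree=2,
        tensor_parallel_degree=2,
        sequence_parallel=True,
        pipeline_parallel_config="enable_clear_every_step_cache disable_partial_send_recv",
    )
)
```

In practice this means a pp run must pass `--pipeline_parallel_config enable_clear_every_step_cache`, and a pp plus sequence-parallel run must additionally include `disable_partial_send_recv`.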
File renamed without changes.
Lines changed: 262 additions & 0 deletions
```python
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import os

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import ParallelCrossEntropy

from paddlenlp.transformers import (
    AllGatherVarlenOp,
    fused_head_and_loss_fn,
    parallel_linear,
    parallel_matmul,
    sequence_parallel_sparse_mask_labels,
)
from paddlenlp.utils import infohub


class KTOCriterion(nn.Layer):
    """KTO Criterion"""

    def __init__(self, config, kto_config=None, ignore_label=0, use_infohub=False):
        super(KTOCriterion, self).__init__()
        self.config = config
        if kto_config is None:
            if getattr(self.config, "kto_config", None) is None:
                raise ValueError("KTO Criterion requires model_config.kto_config.")
            self.kto_config = copy.deepcopy(config.kto_config)
        else:
            self.kto_config = kto_config
        if self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1:
            self.logprobs = ParallelCrossEntropy()
        else:
            self.logprobs = nn.CrossEntropyLoss(reduction="none")
        self.use_infohub = use_infohub
        self.ignore_label = ignore_label
        # allgather kl in criterion
        topo = fleet.get_hybrid_communicate_group()._topo
        parallel_groups = topo.get_comm_list("pipe")
        ranks = []
        for group in parallel_groups:
            ranks.append(group[-1])
        self.comm_group = paddle.distributed.new_group(ranks=ranks)

    def _nested_gather(self, tensors):
        """
        Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
        concatenating them to `gathered`
        """
        local_rank = -1
        env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))
        if env_local_rank != -1 and env_local_rank != local_rank and paddle.distributed.get_world_size() > 1:
            local_rank = env_local_rank
        if tensors is None:
            return
        if local_rank != -1:
            output_tensors = []
            paddle.distributed.all_gather(
                output_tensors, paddle.tile(tensors, repeat_times=[1, 1]), group=self.comm_group
            )
            tensors = paddle.concat(output_tensors, axis=0)
        return tensors

    def kto_logps(self, logits, response_labels, response_kl_labels, response_indexs):
        """KTO logprobs"""
        labels = response_labels + response_kl_labels
        if self.config.use_fused_head_and_loss_fn:
            hidden_states, weight, bias, transpose_y = logits
        elif self.config.use_sparse_head_and_loss_fn:
            hidden_states, weight, bias = logits
        if self.config.use_sparse_head_and_loss_fn:
            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
                labels, sparse_tgt_idx = sequence_parallel_sparse_mask_labels(labels, self.ignore_label)

                hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx, axis=0)
                hidden_states = AllGatherVarlenOp.apply(hidden_states)
            else:
                labels = labels.flatten()
                sparse_tgt_idx = paddle.nonzero(labels != self.ignore_label).flatten()
                labels = paddle.take_along_axis(labels, sparse_tgt_idx, axis=0)

                hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
                hidden_states = paddle.take_along_axis(hidden_states, sparse_tgt_idx.unsqueeze(-1), axis=0)
        if self.config.use_fused_head_and_loss_fn:
            per_token_logps = -fused_head_and_loss_fn(
                hidden_states,
                weight,
                bias,
                labels,
                None,
                transpose_y,
                self.config.vocab_size,
                self.config.tensor_parallel_degree,
                self.config.tensor_parallel_output,
                self.config.fused_linear,
                getattr(self.config, "chunk_size", 1024),
                return_token_loss=True,
                ignore_index=self.ignore_label,
            )
        elif self.config.use_sparse_head_and_loss_fn:
            if bias is None:
                logits = parallel_matmul(hidden_states, weight, self.config.tensor_parallel_output)
            else:
                logits = parallel_linear(
                    hidden_states,
                    weight,
                    bias,
                    self.config.tensor_parallel_output,
                )
            logits = logits.astype("float32")
            per_token_logps = -self.logprobs(logits, labels)
        else:
            logits = logits.astype("float32")
            if logits.shape[:-1] != labels.shape:
                raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
            # bs, seq
            per_token_logps = -self.logprobs(logits, labels.unsqueeze(2)).squeeze(2)

        if len(response_indexs.shape) == 3:
            response_indexs = response_indexs[0]
        if self.config.use_sparse_head_and_loss_fn:
            chosen_logps_list = [
                (per_token_logps[response_index[1] : response_index[2]]).sum()
                for response_index in response_indexs
                if response_index[4] == 1
            ]
            rejected_logps_list = [
                (per_token_logps[response_index[1] : response_index[2]]).sum()
                for response_index in response_indexs
                if response_index[4] == 0
            ]
            kl_logps_list = [
                (per_token_logps[response_index[2] : response_index[3]]).sum() for response_index in response_indexs
            ]
        else:
            chosen_logps_list = [
                (per_token_logps[response_index[0]][response_index[1] : response_index[2]]).sum()
                for response_index in response_indexs
                if response_index[4] == 1
            ]
            rejected_logps_list = [
                (per_token_logps[response_index[0]][response_index[1] : response_index[2]]).sum()
                for response_index in response_indexs
                if response_index[4] == 0
            ]
            kl_logps_list = [
                (per_token_logps[response_index[0]][response_index[2] : response_index[3]]).sum()
                for response_index in response_indexs
            ]
        if len(chosen_logps_list) == 0:
            chosen_logps = paddle.zeros([0], dtype="float32")
        else:
            chosen_logps = paddle.stack(chosen_logps_list, axis=0)
        if len(rejected_logps_list) == 0:
            rejected_logps = paddle.zeros([0], dtype="float32")
        else:
            rejected_logps = paddle.stack(rejected_logps_list, axis=0)
        kl_logps = paddle.stack(kl_logps_list, axis=0)
        return chosen_logps, rejected_logps, kl_logps

    def kto_loss(
        self,
        policy_chosen_logps,
        policy_rejected_logps,
        policy_kl_logps,
        reference_chosen_logps,
        reference_rejected_logps,
        reference_kl_logps,
    ):
        """KTO Loss"""
        kl = (policy_kl_logps - reference_kl_logps).mean().detach()
        kl = self._nested_gather(paddle.tile(kl, repeat_times=[1, 1])).mean().clip(min=0)
        if policy_chosen_logps.shape[0] == 0 or reference_chosen_logps.shape[0] == 0:
            chosen_losses = paddle.zeros([0])
        else:
            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            chosen_losses = 1 - F.sigmoid(self.kto_config.beta * (chosen_logratios - kl))
        if policy_rejected_logps.shape[0] == 0 or reference_rejected_logps.shape[0] == 0:
            rejected_losses = paddle.zeros([0])
        else:
            rejected_logratios = policy_rejected_logps - reference_rejected_logps
            rejected_losses = 1 - F.sigmoid(self.kto_config.beta * (kl - rejected_logratios))
        losses = paddle.concat(
            (
                self.kto_config.desirable_weight * chosen_losses,
                self.kto_config.undesirable_weight * rejected_losses,
            ),
            0,
        )
        return losses.mean(), kl

    def forward(
        self,
        logits,
        labels,
    ):
        """Forward"""
        (
            response_labels,
            response_kl_labels,
            response_indexs,
            reference_chosen_logps,
            reference_rejected_logps,
            reference_kl_logps,
        ) = labels
        if reference_chosen_logps is None or reference_rejected_logps is None or reference_kl_logps is None:
            (
                reference_chosen_logps,
                reference_rejected_logps,
                reference_kl_logps,
            ) = self.kto_logps(logits, response_labels, response_kl_labels, response_indexs)
            if self.use_infohub:
                infohub.reference_chosen_logps.append(reference_chosen_logps)
                infohub.reference_rejected_logps.append(reference_rejected_logps)
                infohub.reference_kl_logps.append(reference_kl_logps)
                # pipeline mode requires return loss when self._compute_loss is True
                return paddle.zeros([1])
            else:
                return (
                    reference_chosen_logps,
                    reference_rejected_logps,
                    reference_kl_logps,
                )
        policy_chosen_logps, policy_rejected_logps, policy_kl_logps = self.kto_logps(
            logits, response_labels, response_kl_labels, response_indexs
        )
        loss, kl = self.kto_loss(
            policy_chosen_logps,
            policy_rejected_logps,
            policy_kl_logps,
            reference_chosen_logps,
            reference_rejected_logps,
            reference_kl_logps,
        )
        if self.use_infohub:
            infohub.policy_chosen_logps.append(policy_chosen_logps.detach())
            infohub.policy_rejected_logps.append(policy_rejected_logps.detach())
            infohub.policy_kl_logps.append(policy_kl_logps.detach())
            infohub.kl.append(kl.detach())
            return loss
        else:
            return (
                policy_chosen_logps,
                policy_rejected_logps,
                policy_kl_logps,
                loss,
                kl,
            )
```
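Stripped of the distributed machinery (the pipeline-group all-gather in `_nested_gather` and the infohub hand-off), the loss in `kto_loss` reduces to a few lines of tensor arithmetic. A self-contained sketch with made-up log-probabilities; the `beta` and weight values are illustrative, not defaults from this commit:

```python
import paddle
import paddle.nn.functional as F

beta, desirable_weight, undesirable_weight = 0.1, 1.0, 1.0

# per-sequence sums of per-token logps, as produced by kto_logps (made-up values)
policy_chosen = paddle.to_tensor([-12.3, -8.7])
policy_rejected = paddle.to_tensor([-15.1])
policy_kl = paddle.to_tensor([-10.0, -11.0, -9.5])
ref_chosen = paddle.to_tensor([-12.0, -9.0])
ref_rejected = paddle.to_tensor([-14.8])
ref_kl = paddle.to_tensor([-10.2, -10.8, -9.9])

# KL baseline: detached and clipped at zero (here without the cross-stage gather)
kl = (policy_kl - ref_kl).mean().detach().clip(min=0)

# chosen examples are rewarded for beating the baseline, rejected for falling below it
chosen_losses = 1 - F.sigmoid(beta * ((policy_chosen - ref_chosen) - kl))
rejected_losses = 1 - F.sigmoid(beta * (kl - (policy_rejected - ref_rejected)))
loss = paddle.concat(
    [desirable_weight * chosen_losses, undesirable_weight * rejected_losses]
).mean()
print(float(loss), float(kl))
```

Note that, matching the code above, the KL term is detached and clipped at zero before entering either sigmoid, so it acts as a fixed baseline rather than a trainable signal.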

paddlenlp/transformers/llama/configuration.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -158,6 +158,7 @@ def __init__(
         use_flash_attention_for_generation=False,
         use_last_token_for_generation=False,
         immediate_clear_past_key_value=False,
+        dpo_config=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -195,6 +196,7 @@ def __init__(
         self.use_flash_attention_for_generation = use_flash_attention_for_generation
         self.use_last_token_for_generation = use_last_token_for_generation
         self.immediate_clear_past_key_value = immediate_clear_past_key_value
+        self.dpo_config = dpo_config
 
         super().__init__(
             pad_token_id=pad_token_id,
```

paddlenlp/transformers/llama/modeling_pp.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -30,6 +30,7 @@
 )
 from paddlenlp.utils.tools import get_env_device
 
+from ..dpo_criterion import DPOCriterion
 from .modeling import (
     LlamaConfig,
     LlamaDecoderLayer,
@@ -423,4 +424,7 @@ def get_hcg():
         # PipelinePretrainedModel.__init__(self.super(), config=config)
 
     def get_loss_fn(self, config):
-        return LlamaPretrainingCriterion(config)
+        if config.dpo_config is not None:
+            return DPOCriterion(config, use_infohub=True)
+        else:
+            return LlamaPretrainingCriterion(config)
```
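The pipeline model now picks its loss by inspecting the config. `use_infohub=True` is the pipeline-specific part: as in the criterion's `forward` above, the last stage stashes per-example log-probabilities in `infohub` and hands the schedule a plain scalar loss. A minimal sketch of the dispatch pattern, with placeholder classes standing in for `DPOCriterion` and `LlamaPretrainingCriterion`:

```python
# Placeholder criteria; the real classes live in paddlenlp.transformers.
class PretrainingCriterion:
    def __init__(self, config):
        self.config = config


class DPOCriterion:
    def __init__(self, config, use_infohub=False):
        self.config, self.use_infohub = config, use_infohub


def get_loss_fn(config):
    # DPO training flags itself by attaching dpo_config to the model config
    if getattr(config, "dpo_config", None) is not None:
        return DPOCriterion(config, use_infohub=True)
    return PretrainingCriterion(config)


class Cfg:
    dpo_config = {"beta": 0.1}  # any non-None payload switches the criterion


assert isinstance(get_loss_fn(Cfg()), DPOCriterion)
```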

paddlenlp/transformers/qwen/configuration.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -49,6 +49,7 @@ def __init__(
         long_sequence_strategy_name=None,
         long_sequence_init_args=None,
         use_long_sequence_strategies=False,
+        dpo_config=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -74,6 +75,7 @@ def __init__(
         self.long_sequence_strategy_name = long_sequence_strategy_name
         self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
         self.use_long_sequence_strategies = use_long_sequence_strategies
+        self.dpo_config = dpo_config
 
         super().__init__(
             pad_token_id=pad_token_id,
```
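The Llama and Qwen configs gain the same default-`None` field, which is what lets `model_kwargs["dpo_config"] = dpo_config` in run_dpo.py flow through `from_pretrained` into `get_loss_fn`. A quick sketch against the Llama config on a PaddleNLP build containing this commit; the dict payload is illustrative, run_dpo.py passes its `DPOConfig` object:

```python
from paddlenlp.transformers import LlamaConfig

cfg = LlamaConfig()
assert cfg.dpo_config is None  # default: pipeline models keep the pretraining loss

cfg = LlamaConfig(dpo_config={"beta": 0.1})  # illustrative payload, not the real DPOConfig
assert cfg.dpo_config == {"beta": 0.1}       # read by get_loss_fn on the last pp stage
```

Note that for pipeline runs, run_dpo.py resets `model.config.dpo_config = None` after the reference model is built (see the "for DPO save" hunk above), presumably so the training-only payload is not serialized with the saved config.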
