Skip to content

Commit 8ffa750

Browse files
authored
FEAT: Decoupled CLIP ratio (DAPO Trick-I) (#285)
* FEAT: add CLIP_higher (DAPO Trick-I) * Change default value for eps_clip_higher * rewrite logic in functional (CLIP higher) * try formatting * try formatting * modify formula
1 parent d1b297a commit 8ffa750

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

areal/api/cli_args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,12 @@ class PPOActorConfig(TrainEngineConfig):
258258
eps_clip: float = field(
259259
default=0.2, metadata={"help": "Clipping factor for policy ratio"}
260260
)
261+
eps_clip_higher: Optional[float] = field(
262+
default=None,
263+
metadata={
264+
"help": "Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value."
265+
},
266+
)
261267
c_clip: Optional[float] = field(
262268
default=None,
263269
metadata={

areal/engine/ppo/actor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def ppo_update(self, data: TensorDict) -> List[Dict[str, float]]:
226226
grpo_loss_fn,
227227
temperature=self.temperature,
228228
eps_clip=self.config.eps_clip,
229+
eps_clip_higher=self.config.eps_clip_higher,
229230
c_clip=self.config.c_clip,
230231
behav_imp_weight_cap=self.config.behav_imp_weight_cap,
231232
),
@@ -262,6 +263,7 @@ def grpo_loss_fn(
262263
input_data: Dict,
263264
temperature: float,
264265
eps_clip: float,
266+
eps_clip_higher: float | None,
265267
c_clip: float | None,
266268
behav_imp_weight_cap: float | None,
267269
):
@@ -282,6 +284,7 @@ def grpo_loss_fn(
282284
old_logprobs=old_logp,
283285
advantages=advantages,
284286
eps_clip=eps_clip,
287+
eps_clip_higher=eps_clip_higher,
285288
loss_mask=loss_mask,
286289
c_clip=c_clip,
287290
proximal_logprobs=prox_logp,

areal/utils/functional.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ def ppo_actor_loss_fn(
126126
advantages: torch.Tensor,
127127
eps_clip: float,
128128
loss_mask: torch.Tensor,
129+
eps_clip_higher: Optional[float] = None,
129130
c_clip: Optional[float] = None,
130131
behav_imp_weight_cap: Optional[float] = None,
131132
) -> Tuple[torch.Tensor, Dict]:
@@ -139,7 +140,13 @@ def ppo_actor_loss_fn(
139140
"""
140141
loss_mask_count = loss_mask.count_nonzero() or 1
141142
ratio = torch.where(loss_mask, torch.exp(logprobs - proximal_logprobs), 0)
142-
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
143+
144+
clipped_ratio = torch.clamp(
145+
ratio,
146+
1.0 - eps_clip,
147+
1.0 + (eps_clip if eps_clip_higher is None else eps_clip_higher),
148+
)
149+
143150
pg_loss1 = -advantages * ratio
144151
pg_loss2 = -advantages * clipped_ratio
145152
clip_mask = pg_loss1.detach() < pg_loss2.detach()

0 commit comments

Comments
 (0)