
Commit 3db86ca

Add optional normalization (cont.) (#98)
* fix(ppo): optional reward scaling and minibatch advantage whitening
* feat(ppo): add optional reward clipping
* chore(ppo): add tests, comments
* fix(github): rename master to main for build
* feat(ppo): add manual reward scaling
1 parent aafcae9 commit 3db86ca

File tree: 7 files changed (+46, -15 lines)


.github/workflows/build.yml

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@ name: Build
 
 on:
   push:
-    branches: [ master ]
+    branches: [ main ]
   pull_request:
-    branches: [ master ]
+    branches: [ main ]
 
 jobs:
   build:

configs/ppo_config.yml

Lines changed: 4 additions & 2 deletions
@@ -36,8 +36,10 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 2.3 # value term weight
-  scale_reward: True
-  clip_reward: 10
+  scale_reward: "running" # False | "ref" | "running" estimate against which to scale rewards
+  ref_mean: null
+  ref_std: null # rescale rewards with this deviation
+  cliprange_reward: 10
   gen_kwargs:
     max_length: 48 # LM max sample gen length
     min_length: 48 # LM min sample gen length
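For illustration only: a minimal sketch of how the three scale_reward settings in this config might be interpreted when rescaling a batch of reward scores. The helper name pick_reward_scale and the standalone framing are hypothetical; the actual dispatch lives in trlx/orchestrator/ppo_orchestrator.py (see that file's diff below).

import torch

def pick_reward_scale(scale_reward, running_std, ref_std):
    # hypothetical helper: choose the std estimate used to rescale rewards
    # "running" -> streaming estimate over all rollouts so far
    # "ref"     -> fixed reference std (ref_std from the config, or the first rollout)
    # False     -> no scaling
    if scale_reward == "running":
        return running_std
    elif scale_reward == "ref":
        return ref_std
    return None

scores = torch.tensor([0.3, -1.2, 2.5])
divisor = pick_reward_scale("running", running_std=scores.std(), ref_std=None)
if divisor is not None:
    scores = scores / divisor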

configs/test_config.yml

Lines changed: 4 additions & 0 deletions
@@ -35,6 +35,10 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 1.0 # value term weight
+  scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards
+  cliprange_reward: 10
+  ref_mean: null
+  ref_std: null
   gen_kwargs:
     max_length: 48 # LM max sample gen length
     min_length: 48 # LM min sample gen length

tests/test_ppo.py

Lines changed: 20 additions & 0 deletions
@@ -1,6 +1,7 @@
 import unittest
 from trlx.data.configs import TRLConfig
 from trlx.model.nn.ppo_models import GPTHydraHeadWithValueModel
+from trlx.utils.modeling import RunningMoments
 from transformers import AutoTokenizer
 import torch
 
@@ -44,3 +45,22 @@ def test_forward(self):
         logits_diff = torch.sum(unfrozen_logits - frozen_logits).item()
         self.assertEqual(hs_diff, 0)
         self.assertEqual(logits_diff, 0)
+
+class TestStatistics(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.m = RunningMoments()
+        cls.a1 = torch.arange(100, dtype=float)
+        cls.a2 = torch.ones(100, dtype=float)
+        cls.a3 = torch.exp(torch.arange(10, dtype=float))
+        cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float)
+
+    def test_running_moments(self):
+        assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6)
+
+        a = torch.hstack((self.a1, self.a2, self.a3, self.a4))
+        assert torch.isclose(self.m.mean, a.mean(), atol=1e-6)
+        assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6)
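As a usage note (not part of the commit): the test above feeds RunningMoments one batch at a time. The assertions imply that update(xs) returns the incoming batch's own (mean, std), while .mean and .std track the aggregate over everything seen so far. A short sketch of that pattern outside unittest, under the same assumption:

import torch
from trlx.utils.modeling import RunningMoments

running = RunningMoments()
for batch in (torch.randn(64), torch.randn(32) * 3 + 1):
    batch_mean, batch_std = running.update(batch)  # moments of this batch alone
    print(batch_mean, batch_std)

# aggregate moments over the whole stream so far
print(running.mean, running.std)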

trlx/model/nn/ppo_models.py

Lines changed: 4 additions & 2 deletions
@@ -111,8 +111,10 @@ class PPOConfig(MethodConfig):
     cliprange: float
     cliprange_value: float
     vf_coef: float
-    scale_reward: bool
-    clip_reward: float
+    scale_reward: str
+    ref_mean: Optional[float]
+    ref_std: Optional[float]
+    cliprange_reward: float
     gen_kwargs: dict
 
     def get_advantages_and_returns(
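For context, a minimal sketch of how the new fields could be declared and populated. The @dataclass framing and the default values are assumptions made for illustration; only the field names and types come from the diff above.

from dataclasses import dataclass
from typing import Optional

@dataclass
class PPOConfigSketch:
    # hypothetical stand-in for trlx's PPOConfig; fields mirror the diff
    scale_reward: str = "running"      # False | "ref" | "running"
    ref_mean: Optional[float] = None   # used as the fixed reference mean
    ref_std: Optional[float] = None    # manual reward scaling: divide rewards by this
    cliprange_reward: float = 10.0     # rewards are clipped to [-10, 10]

cfg = PPOConfigSketch(scale_reward="ref", ref_mean=0.0, ref_std=2.0)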

trlx/orchestrator/ppo_orchestrator.py

Lines changed: 7 additions & 5 deletions
@@ -47,8 +47,8 @@ def __init__(
         self.rl_model.metric_fn = metric_fn
 
         self.running = RunningMoments()
-        self.ref_mean = None
-        self.ref_std = None
+        self.ref_mean = self.rl_model.config.method.ref_mean
+        self.ref_std = self.rl_model.config.method.ref_std
 
     def score(self, samples):
         """
@@ -84,19 +84,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
         scores = torch.as_tensor(self.score(texts), device=samples.device)
         stats["exp_score_time"] = time() - exp_score_time
 
+        # store statistics of the initial rollout as reference
         if self.ref_mean is None:
             self.ref_mean, self.ref_std = scores.mean(), scores.std()
         all_scores_mean, all_scores_std = self.running.update(scores)
-
         stats["exp_scores_mean"] = all_scores_mean
         stats["exp_scores_std"] = all_scores_std
         stats["running_mean"] = self.running.mean
         stats["running_std"] = self.running.std
 
-        if self.rl_model.config.method.scale_reward:
+        if self.rl_model.config.method.scale_reward == "running":
             scores /= self.running.std
+        elif self.rl_model.config.method.scale_reward == "ref":
+            scores /= self.ref_std
 
-        clip_reward = self.rl_model.config.method.clip_reward
+        clip_reward = self.rl_model.config.method.cliprange_reward
         if clip_reward:
             scores = torch.clip(scores, -clip_reward, clip_reward)
 
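Put together, the scoring path now scales rewards and then clips them. A self-contained sketch of that sequence under the same config semantics; the function name and the standalone framing are illustrative, the logic mirrors the hunk above.

import torch

def scale_and_clip(scores, scale_reward, running_std, ref_std, cliprange_reward):
    # illustrative mirror of the make_experience logic above
    if scale_reward == "running":
        scores = scores / running_std   # divide by the streaming std estimate
    elif scale_reward == "ref":
        scores = scores / ref_std       # divide by a fixed reference std
    if cliprange_reward:
        scores = torch.clip(scores, -cliprange_reward, cliprange_reward)
    return scores

rewards = torch.tensor([0.5, -3.0, 12.0])
print(scale_and_clip(rewards, "ref", running_std=None, ref_std=2.0, cliprange_reward=10))
# tensor([ 0.2500, -1.5000,  6.0000])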

trlx/utils/modeling.py

Lines changed: 5 additions & 4 deletions
@@ -91,12 +91,13 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
         delta = xs_mean - self.mean
         tot_count = self.count + xs_count
 
-        m_a = self.var * self.count
-        m_b = xs_var * xs_count
-        m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count
+        new_sum = xs_var * xs_count
+        # correct old_sum deviation accounting for the new mean
+        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
+        tot_sum = old_sum + new_sum
 
         self.mean += delta * xs_count / tot_count
-        self.var = m_2 / tot_count
+        self.var = tot_sum / tot_count
         self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
         self.count = tot_count
 
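The rewrite above is the standard parallel combination of sample variances (Chan et al.), renamed for readability. A self-contained sketch of the whole class for reference; the initial state is an assumption (an effectively empty stream), and the exact starting values in trlx/utils/modeling.py may differ.

import torch
from typing import Tuple

class RunningMomentsSketch:
    """Illustrative re-implementation of the update rule shown in the diff."""

    def __init__(self):
        self.mean = 0.0
        self.std = 1.0
        self.var = 1.0
        self.count = 1e-24  # assumed near-zero starting count

    def update(self, xs: torch.Tensor) -> Tuple[float, float]:
        """Blend a batch into the running moments; return the batch's own mean and std."""
        xs_count = xs.numel()
        xs_mean, xs_var = xs.mean(), xs.var(unbiased=False)

        delta = xs_mean - self.mean
        tot_count = self.count + xs_count

        new_sum = xs_var * xs_count
        # correct old_sum deviation accounting for the new mean
        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
        tot_sum = old_sum + new_sum

        self.mean += delta * xs_count / tot_count
        self.var = tot_sum / tot_count
        self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
        self.count = tot_count

        return xs_mean.item(), xs.std(unbiased=True).item()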

0 commit comments

Comments
 (0)