
Commit dc0e060

Fix context overflow (#131)
* feat(ppo_models): context truncating generation
* feat(base_model): segregate prompts and responses in logging
* fix(ppo_model): truncate left padded tokens
* revert(ppo_models): remove context overflowing generate()
* feat(configs): add max_new_tokens
* fix(pipeline): truncate prompts
* chore(base_model): remove whole samples from logging
* chore(configs): update the rest of configs
* fix(configs): update program synthesis config
* fix(base_model): prompts sizes
* revert(config): emulate old ppo_sentiments behavior
* fix(ppo): unequal generation lengths
* chore(ppo): put indexing on cpu
* revert(config): old ilql_sentiments behavior
* fix(ppo): unzero clipfrac
* merge(configs): delete old options
* revert(config): old ppo_sentiment behavior
* refactor(ppo_orchestrator): remove unused ref_logprobs
* fix(base_model): pin rewards to single precision
* refactor(ppo): rename padding percentage
* feat(wandb): add git branch name to tags
* refactor(wandb): logging name hierarchy
* revert(wandb): merge tags into a single string
1 parent 247eb8f commit dc0e060

File tree

16 files changed: +159 -100 lines changed

configs/ilql_config.yml

Lines changed: 5 additions & 1 deletion
@@ -40,5 +40,9 @@ method:
   awac_scale: 1
   alpha: 0.001
   steps_for_target_q_sync: 5
-  betas: [4]
   two_qs: true
+  gen_kwargs:
+    max_new_tokens: 56
+    top_k: 20
+    beta: 4
+    temperature: 1.0
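
The deleted per-method `betas: [4]` list is folded into the new `gen_kwargs` mapping as a single `beta`, next to the sampling settings. A minimal sketch of reading the two layouts back from YAML (plain `yaml` here is illustrative; trlx's own config loader is not shown in this diff):

import yaml

# Illustrative: load the method section of the updated ILQL config.
with open("configs/ilql_config.yml") as f:
    method = yaml.safe_load(f)["method"]

# Old layout (removed by this commit): a list of betas on the method itself.
# beta = method["betas"][0]

# New layout: one beta plus the sampling options, grouped under gen_kwargs.
gen_kwargs = method["gen_kwargs"]
print(gen_kwargs["beta"], gen_kwargs["max_new_tokens"], gen_kwargs["top_k"])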

configs/ppo_config.yml

Lines changed: 3 additions & 4 deletions
@@ -1,5 +1,5 @@
 train:
-  seq_length: 48
+  seq_length: 1024
   epochs: 100
   total_steps: 10000
   batch_size: 128
@@ -48,8 +48,7 @@ method:
   ref_std: null
   cliprange_reward: 10
   gen_kwargs:
-    max_length: 48
-    min_length: 48
-    top_k: 0.0
+    max_new_tokens: 40
+    top_k: 0
     top_p: 1.0
     do_sample: True
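
Replacing `max_length`/`min_length` with `max_new_tokens` matches how Hugging Face `generate()` counts tokens: `max_length` bounds prompt plus continuation, so a long prompt can swallow the whole budget, while `max_new_tokens` bounds only the continuation. A minimal sketch with a stand-in model (gpt2 is illustrative, not the model from these configs):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The movie was", return_tensors="pt")

# max_length counts the prompt too: a 40-token prompt under max_length=48
# leaves at most 8 tokens of actual continuation.
out_a = model.generate(**inputs, max_length=48)

# max_new_tokens bounds only the continuation, regardless of prompt size.
out_b = model.generate(**inputs, max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True)
print(tokenizer.decode(out_b[0], skip_special_tokens=True))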

configs/ppo_gptj.yml

Lines changed: 1 addition & 2 deletions
@@ -48,8 +48,7 @@ method:
   ref_std: null
   cliprange_reward: 10
   gen_kwargs:
-    max_length: 48
-    min_length: 48
+    max_new_tokens: 48
     top_k: 0.0
     top_p: 0.7
     do_sample: True

examples/experiments/grounded_program_synthesis/configs/trlx_ppo_config.yml

Lines changed: 2 additions & 3 deletions
@@ -48,9 +48,8 @@ method:
   ref_mean: null
   ref_std: null
   gen_kwargs:
-    max_length: 256
-    min_length: 48
-    top_k: 0.0
+    max_new_tokens: 256
+    top_k: 0
     top_p: 0.7
     do_sample: True
     temperature: 0.5

examples/randomwalks/configs/ilql_randomwalks.yml

Lines changed: 5 additions & 1 deletion
@@ -40,5 +40,9 @@ method:
   awac_scale: 1
   alpha: 0.1
   steps_for_target_q_sync: 5
-  betas: [100]
   two_qs: true
+  gen_kwargs:
+    max_new_tokens: 9
+    top_k: 1
+    beta: 100
+    temperature: 1.0

examples/randomwalks/configs/ppo_randomwalks.yml

Lines changed: 1 addition & 2 deletions
@@ -48,8 +48,7 @@ method:
   ref_std: null
   cliprange_reward: 1
   gen_kwargs:
-    max_length: 10
-    min_length: 2
+    max_new_tokens: 9
     top_k: 0.0
     top_p: 1.0
     do_sample: True

trlx/model/accelerate_base_model.py

Lines changed: 32 additions & 11 deletions
@@ -147,6 +147,7 @@ def evaluate(self):
         """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided"""
         stats = {}
         all_samples = []
+        prompts_sizes = []
         generate_time = time()
         for prompts in self.eval_dataloader:
             if isinstance(prompts, torch.Tensor):
@@ -165,34 +166,54 @@ def evaluate(self):
                     value=pad_token,
                 )
             )
-        stats["generate_time"] = time() - generate_time
+            sizes = torch.tensor(prompts.input_ids.shape[1]).repeat(
+                len(prompts.input_ids)
+            )
+            prompts_sizes.append(sizes.to(samples.device))
+
+        stats["time/generate"] = time() - generate_time

         samples = self.accelerator.gather(torch.vstack(all_samples))
+        prompts_sizes = self.accelerator.gather(torch.hstack(prompts_sizes))

         if self.accelerator.is_main_process:
             if self.tokenizer:
-                samples = self.tokenizer.batch_decode(samples, skip_special_tokens=True)
+                str_samples = self.tokenizer.batch_decode(
+                    samples, skip_special_tokens=True
+                )
+
+                prompts, responses = [], []
+                for sample, prompt_size in zip(samples, prompts_sizes):
+                    prompts.append(sample[:prompt_size])
+                    responses.append(sample[prompt_size:])
+
+                str_prompts = self.tokenizer.batch_decode(
+                    prompts, skip_special_tokens=True
+                )
+                str_responses = self.tokenizer.batch_decode(
+                    responses, skip_special_tokens=True
+                )

-            if isinstance(samples[0], str):
-                columns_data = [samples]
+            if isinstance(str_samples[0], str):
+                columns_data = [str_prompts, str_responses]
             else:
                 columns_data = [samples.tolist()]
-            columns = ["samples"]
+            columns = ["prompt", "response"]

             # in online setting, compute the reward for validation
             if self.reward_fn:
-                rewards = torch.as_tensor(self.reward_fn(samples), dtype=torch.float)
+                rewards = torch.tensor(self.reward_fn(str_samples), dtype=torch.float)
                 mean_reward = rewards.mean()
                 columns.append("reward")
                 columns_data.append(rewards)
-                stats["mean_reward"] = mean_reward
+                stats["reward/mean"] = mean_reward
                 print(f"{mean_reward=}")

             # additionally log any other metrics
             if self.metric_fn:
                 metric_time = time()
-                metrics = self.metric_fn(samples)
-                stats["metric_time"] = time() - metric_time
+                metrics = self.metric_fn(str_samples)
+                stats["time/metric"] = time() - metric_time

                 mean_metrics = {
                     f"metrics/{k}": torch.as_tensor(xs).mean(-1)
@@ -258,8 +279,8 @@ def learn(self):
                 if self.iter_count % self.config.train.checkpoint_interval == 0:
                     self.save()

-                stats["forward_time"] = forward_time
-                stats["backward_time"] = backward_time
+                stats["time/forward"] = forward_time
+                stats["time/backward"] = backward_time

                 if self.iter_count % self.config.train.eval_interval == 0:
                     results = self.evaluate()
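
Per the `segregate prompts and responses in logging` change above, evaluation now records each batch's prompt length, gathers the sizes across processes, and slices every generated sequence into a prompt part and a response part before decoding; `reward_fn` and `metric_fn` still receive the full decoded samples. A rough sketch of how the resulting `columns`/`columns_data` could land in a logging table (the wandb calls are an assumption about the surrounding logging code, which this diff does not show):

import wandb

# Shapes as built in evaluate() above:
#   columns      = ["prompt", "response", "reward"]
#   columns_data = [str_prompts, str_responses, rewards]
columns = ["prompt", "response", "reward"]
columns_data = [["The movie was"], [" a lot of fun."], [0.87]]

wandb.init(project="trlx-eval-demo", mode="offline")  # illustrative project name
# One row per evaluated sample: zip pairs the i-th prompt, response and reward.
table = wandb.Table(columns=columns, data=[list(row) for row in zip(*columns_data)])
wandb.log({"samples": table})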

trlx/model/accelerate_ilql_model.py

Lines changed: 8 additions & 8 deletions
@@ -32,6 +32,14 @@ def __init__(

         self.ilql: ILQLConfig = cast(ILQLConfig, config.method)

+        self.generate_kwargs = dict(
+            config.method.gen_kwargs,
+            max_length=self.max_length,
+            logit_mask=self.logit_mask,
+            eos_token_id=self.tokenizer.eos_token_id if self.tokenizer else 0,
+            pad_token_id=self.tokenizer.pad_token_id if self.tokenizer else 0,
+        )
+
     def get_arch(self, config):
         return CausalLMWithValueHeads(
             config.model.model_path,
@@ -87,11 +95,3 @@ def prepare_learning(self):
         self.n_updates_per_batch = 1
         self.total_steps = self.config.train.epochs * len(train_dataloader)
         self.total_steps = min(self.total_steps, self.config.train.total_steps)
-
-        self.generate_kwargs = {
-            "beta": self.config.method.betas[0],
-            "max_length": self.max_length,
-            "logit_mask": self.logit_mask,
-            "eos_token_id": self.tokenizer.eos_token_id if self.tokenizer else 0,
-            "pad_token_id": self.tokenizer.pad_token_id if self.tokenizer else 0,
-        }
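
`generate_kwargs` is now assembled once in `__init__` from `config.method.gen_kwargs` instead of being hard-coded in `prepare_learning`, and `dict(mapping, **overrides)` keeps the YAML values while forcing the runtime-only fields. A small sketch of that merge (values are illustrative):

# Values as they would come from the gen_kwargs section of ilql_config.yml.
gen_kwargs = {"max_new_tokens": 56, "top_k": 20, "beta": 4, "temperature": 1.0}

# dict(mapping, **overrides): start from the config values, then add or
# override fields only known at runtime (context length, tokenizer ids).
generate_kwargs = dict(
    gen_kwargs,
    max_length=1024,     # stands in for self.max_length
    eos_token_id=50256,  # stands in for self.tokenizer.eos_token_id
    pad_token_id=50256,  # stands in for self.tokenizer.pad_token_id
)
print(generate_kwargs["beta"], generate_kwargs["max_length"])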

trlx/model/accelerate_ppo_model.py

Lines changed: 14 additions & 5 deletions
@@ -62,7 +62,9 @@ def get_model_inputs(
         query_tensors: TensorType["batch_size", "query_size"],
         response_tensors: TensorType["batch_size", "response_size"],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        tokens = torch.cat((query_tensors, response_tensors), dim=1)
+        tokens = torch.cat((query_tensors, response_tensors), dim=1)[
+            :, -self.max_length :
+        ]
         attention_mask = (
             tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device)
         )
@@ -79,23 +81,30 @@ def loss(self, batch: PPORLBatch):
         old_values = batch.values.to(self.accelerator.device)
         old_rewards = batch.rewards.to(self.accelerator.device)

-        response_length = response_tensors.shape[-1]
+        response_length = old_rewards.shape[1]
+
         advantages, returns = self.config.method.get_advantages_and_returns(
             old_values, old_rewards, response_length
         )

         tokens, attention_mask, position_ids = self.get_model_inputs(
             query_tensors, response_tensors
         )
+
         logits, *_, values_pred = self.model(
             tokens, attention_mask=attention_mask, position_ids=position_ids
         )
+        values_pred = values_pred[:, :-1]
         logprobs = logprobs_from_logits(logits[:, :-1, :], tokens[:, 1:])
+        attention_mask = attention_mask[:, :-1]
+
         # Only the response part of the values/logprobs is needed
+        start = query_tensors.shape[1] - 1
+        end = start + response_length
         logprobs, values_pred, mask = (
-            logprobs[:, -response_length:],
-            values_pred[:, -response_length:],
-            attention_mask[:, -response_length:],
+            logprobs[:, start:end],
+            values_pred[:, start:end],
+            attention_mask[:, start:end],
         )

         loss, stats = self.config.method.loss(
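
Because `logprobs_from_logits` pairs `logits[:, :-1]` with `tokens[:, 1:]`, position i of the result scores the token at position i + 1, so the log-prob of the first response token sits at index `query_length - 1`; slicing with `-response_length:` breaks once generations of unequal length get padded. A toy check of the indexing with made-up sizes (the gather below mimics that shift, it is not trlx's helper):

import torch

batch, query_len, response_len, vocab = 2, 5, 3, 11
tokens = torch.randint(vocab, (batch, query_len + response_len))
logits = torch.randn(batch, query_len + response_len, vocab)

# Same shift as logits[:, :-1] vs tokens[:, 1:]: position i scores token i + 1.
logprobs = torch.log_softmax(logits[:, :-1, :], dim=-1)
logprobs = torch.gather(logprobs, 2, tokens[:, 1:, None]).squeeze(-1)

start = query_len - 1          # log-prob of the first response token
end = start + response_len
response_logprobs = logprobs[:, start:end]
assert response_logprobs.shape == (batch, response_len)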

trlx/model/nn/ilql_models.py

Lines changed: 5 additions & 6 deletions
@@ -42,8 +42,8 @@ class ILQLConfig(MethodConfig):
     awac_scale: float
     alpha: float
     steps_for_target_q_sync: float
-    betas: Sequence[float]
     two_qs: bool
+    gen_kwargs: dict

     def heads(self, hidden_size: int, vocab_size: int):
         return ILQLHeads(self, hidden_size, vocab_size)
@@ -140,7 +140,6 @@ def forward(
         states_ixs: torch.Tensor = None,
         actions_ixs: torch.Tensor = None,
     ):
-
         if states_ixs is not None:
             states_hs = hs.gather(
                 dim=1, index=states_ixs.unsqueeze(-1).repeat(1, 1, hs.shape[-1])
@@ -260,7 +259,8 @@ def generate(
         position_ids=None,
         past_key_values=None,
         beta=1,
-        max_length=32,
+        max_new_tokens=32,
+        max_length=1024,
         temperature=1,
         top_k=20,
         logit_mask=None,
@@ -278,13 +278,12 @@ def generate(
             position_ids.masked_fill_(attention_mask.eq(0), 0)

         samples = input_ids.clone()
-        tensors = defaultdict(list)
-        n_new_tokens = max_length - input_ids.shape[1]
+        max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1])

         finished = torch.zeros(
             input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device
         )
-        for _ in range(n_new_tokens):
+        for _ in range(max_new_tokens):
             out = self.forward(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
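
ILQL's `generate()` now takes `max_new_tokens` for the number of sampled tokens and keeps `max_length` only as a context-window ceiling, clipping the former so prompt plus continuation never overflows the model. The guard in isolation, with illustrative numbers:

max_length = 1024     # context-window ceiling
max_new_tokens = 56   # requested continuation length
prompt_len = 1000     # tokens already in input_ids

# With a 1000-token prompt only 24 new tokens fit, so the request is clipped.
max_new_tokens = min(max_new_tokens, max_length - prompt_len)
assert max_new_tokens == 24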
