Merged
Commits
24 commits
dc6acb7
feat(ppo_models): context truncating generation
maxreciprocate Dec 8, 2022
ed4f6f3
feat(base_model): segragate prompts and responses in logging
maxreciprocate Dec 8, 2022
4674e3e
fix(ppo_model): truncate left padded tokens
maxreciprocate Dec 8, 2022
3b886df
revert(ppo_models): remove context overflowing generate()
maxreciprocate Dec 9, 2022
a8b6eaf
feat(configs): add max_new_tokens
maxreciprocate Dec 9, 2022
31e5bb5
fix(pipeline): truncate prompts
maxreciprocate Dec 9, 2022
59176ac
chore(base_model): remove whole samples from logging
maxreciprocate Dec 9, 2022
48899e0
chore(configs): update the rest of configs
maxreciprocate Dec 9, 2022
27bd892
fix(configs): update program synthesis config
maxreciprocate Dec 12, 2022
d6be611
fix(base_model): prompts sizes
maxreciprocate Dec 12, 2022
74e758a
revert(config): emulate old ppo_sentiments behavior
maxreciprocate Dec 12, 2022
7366ab6
fix(ppo): unequal generation lengths
maxreciprocate Dec 14, 2022
1b4d5db
chore(ppo): put indexing on cpu
maxreciprocate Dec 14, 2022
dd8b21a
revert(configl): old ilql_sentiments behavior
maxreciprocate Dec 14, 2022
d1a9c38
fix(ppo): unzero clipfrac
maxreciprocate Dec 14, 2022
97bdd72
Merge branch 'main' into fix-context-overflow
maxreciprocate Dec 14, 2022
f394897
merge(configs): delete old options
maxreciprocate Dec 14, 2022
a960ada
revert(config): old ppo_sentiment behavior
maxreciprocate Dec 14, 2022
a290dd8
refactor(ppo_orchestrator): remove unused ref_logprobs
maxreciprocate Dec 14, 2022
441dd23
fix(base_model): pin rewards to single precision
maxreciprocate Dec 15, 2022
330dde3
refactor(ppo): rename padding percentage
maxreciprocate Dec 15, 2022
6df0195
feat(wandb): add git branch name to tags
maxreciprocate Dec 15, 2022
005b348
refactor(wandb): logging name hierarchy
maxreciprocate Dec 15, 2022
cacebc3
revert(wandb): merge tags into a single string
maxreciprocate Dec 15, 2022
2 changes: 1 addition & 1 deletion configs/ppo_config.yml
@@ -5,7 +5,7 @@ model:
num_layers_unfrozen: 2 # Number of bottom layers to freeze during training

train:
seq_length: 48 # Size of LM context
seq_length: 1024 # Size of LM context
epochs: 100 # Train for max(epochs, total_steps)
total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 128 # batch size
15 changes: 11 additions & 4 deletions trlx/model/accelerate_ppo_model.py
@@ -81,23 +81,30 @@ def loss(self, batch: PPORLBatch):
old_values = batch.values.to(self.accelerator.device)
old_rewards = batch.rewards.to(self.accelerator.device)

response_length = response_tensors.shape[-1]
response_length = old_rewards.shape[1]

advantages, returns = self.config.method.get_advantages_and_returns(
old_values, old_rewards, response_length
)

tokens, attention_mask, position_ids = self.get_model_inputs(
query_tensors, response_tensors
)

logits, *_, values_pred = self.model(
tokens, attention_mask=attention_mask, position_ids=position_ids
)
values_pred = values_pred[:, :-1]
logprobs = logprobs_from_logits(logits[:, :-1, :], tokens[:, 1:])
attention_mask = attention_mask[:, :-1]

# Only the response part of the values/logprobs is needed
start = query_tensors.shape[1] - 1
end = start + response_length
logprobs, values_pred, mask = (
logprobs[:, -response_length:],
values_pred[:, -response_length:],
attention_mask[:, -response_length:],
logprobs[:, start:end],
values_pred[:, start:end],
attention_mask[:, start:end],
)

loss, stats = self.config.method.loss(
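The start/end arithmetic in the hunk above is easy to misread; the following is a minimal standalone sketch (toy sizes, not part of the diff) of why the response slice begins at query_tensors.shape[1] - 1: after the one-token shift in logprobs_from_logits, position i of logprobs scores tokens[:, i + 1], so the first response token is scored at index query_length - 1.

import torch

query_length, response_length = 4, 3                                  # toy sizes
tokens = torch.arange(query_length + response_length).unsqueeze(0)    # [q0 q1 q2 q3 r0 r1 r2]

# logprobs = logprobs_from_logits(logits[:, :-1, :], tokens[:, 1:]) means
# logprobs[:, i] is the score of tokens[:, i + 1] given tokens[:, : i + 1]
start = query_length - 1                      # position that scores the first response token
end = start + response_length                 # one past the position that scores the last one

scored = tokens[:, 1:][:, start:end]          # tokens covered by the selected logprobs
assert torch.equal(scored, tokens[:, query_length:])   # exactly the response tokens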
24 changes: 14 additions & 10 deletions trlx/model/nn/ppo_models.py
@@ -22,6 +22,7 @@
hf_get_num_hidden_layers,
make_head,
whiten,
get_tensor_stats,
)


@@ -162,10 +163,12 @@ def loss(
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
n = mask.sum()
Collaborator:
I'm not sure, but shouldn't n be vector-valued, where each component is the size of the i-th generation (before endoftext)?

maxreciprocate (Collaborator, Author), Dec 14, 2022:
It is used only in reductions to scalars, so no.

(A toy illustration of this point follows this file's diff.)


vf_loss1 = (values - returns) ** 2
vf_loss2 = (values_clipped - returns) ** 2
vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
vf_clipfrac = torch.mean((vf_loss2 > vf_loss1).float())
vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / n
vf_clipfrac = torch.sum((vf_loss2 > vf_loss1).float() * mask) / n

log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio)
@@ -179,8 +182,8 @@
1.0 - self.cliprange,
1.0 + self.cliprange,
)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
pg_clipfrac = torch.mean((pg_loss2 > pg_loss1).float())
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / n
pg_clipfrac = torch.sum((pg_loss2 > pg_loss1).float() * mask) / n

loss = pg_loss + self.vf_coef * vf_loss

@@ -191,16 +194,17 @@
value_loss=vf_loss.item(),
),
values=dict(
mean_old_values=torch.mean(old_values),
var_old_values=torch.var(old_values),
mean_values=torch.mean(values),
values_error=torch.mean((values - returns) ** 2),
get_tensor_stats(values, mask, n),
values_error=torch.sum(((values - returns) * mask) ** 2) / n,
clipfrac=vf_clipfrac,
),
old_values=get_tensor_stats(old_values, mask, n),
returns=get_tensor_stats(returns, mask, n),
policy=dict(approx_kl=approx_kl.item(), clipfrac=pg_clipfrac.item()),
returns=dict(mean=torch.mean(returns), var=torch.var(returns)),
ratio=(ratio * mask).sum() / mask.sum(),
ratio=(ratio * mask).sum() / n,
perc_padding=n / mask.numel(),
)

return loss, flatten_dict(stats)


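On the review thread above about n = mask.sum(): a toy illustration (not from the PR) of why a scalar is enough. Every loss and statistic here is reduced to a single number over the whole batch, and sum(x * mask) / n is the mean over all unmasked response tokens, i.e. the length-weighted average of the per-sample means, so unequal generation lengths are already accounted for.

import torch

# two responses of lengths 3 and 1, right-padded to the same width
x    = torch.tensor([[1.0, 2.0, 3.0],
                     [4.0, 0.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0],
                     [1.0, 0.0, 0.0]])

n = mask.sum()                                 # scalar: 4 unmasked tokens in total
masked_mean = (x * mask).sum() / n             # (1 + 2 + 3 + 4) / 4 = 2.5

# the same value from per-sample lengths, weighting each sample by its length
lengths = mask.sum(dim=1)                      # tensor([3., 1.])
per_sample = (x * mask).sum(dim=1) / lengths   # tensor([2., 4.])
assert torch.isclose(masked_mean, (per_sample * lengths).sum() / lengths.sum())

Note also that values=dict(get_tensor_stats(values, mask, n), values_error=..., clipfrac=...) in the diff relies on Python's dict(mapping, **kwargs) form: the mapping returned by the helper is copied first and the keyword entries are added on top, so the helper's mean/min/max/std land alongside values_error and clipfrac.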
48 changes: 25 additions & 23 deletions trlx/orchestrator/ppo_orchestrator.py
@@ -79,7 +79,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
samples, skip_special_tokens=True
)
exp_score_time = time()
scores = torch.as_tensor(self.score(texts), device=samples.device)
scores = torch.tensor(self.score(texts), device=samples.device)
stats["exp_score_time"] = time() - exp_score_time

# store statistics of the initial rollout as reference
@@ -105,7 +105,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
query_tensors.to(response_tensors.device), response_tensors
)
with torch.no_grad():
logits, *_, v = self.rl_model.model(
logits, *_, values = self.rl_model.model(
all_tokens, attention_mask=attention_mask, position_ids=position_ids
)
# TODO(dahoas): When hydra model works need to also support generation on hydra head
@@ -122,43 +122,45 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
attention_mask=attention_mask.cpu(),
position_ids=position_ids.cpu(),
)
ref_logits = ref_logits.to(self.rl_model.accelerator.device)

ref_logits = ref_logits.to(self.rl_model.accelerator.device)
logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])
ref_logprobs = logprobs_from_logits(
ref_logits[:, :-1, :], all_tokens[:, 1:]
)
start = query_tensors.size()[1] - 1
end = query_tensors.size()[1] + response_tensors.size()[1] - 1
all_values = v[:, start:end]
all_logprobs = logprobs[:, start:end]
all_ref_logprobs = ref_logprobs[:, start:end]
values = values[:, :-1]

n = samples.shape[0]
start = query_tensors.shape[1] - 1
ends = start + attention_mask[:, start:].sum(1)
all_values = [values[ix, start : ends[ix]] for ix in range(n)]
all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n)]
all_ref_logprobs = [ref_logprobs[ix, start : ends[ix]] for ix in range(n)]

# Compute rewards
kls = all_logprobs - all_ref_logprobs
non_score_rewards = -self.rl_model.kl_ctl.value * kls
all_rewards = non_score_rewards.clone()
all_rewards[:, -1] += scores.to(self.rl_model.accelerator.device)
rewards = -self.rl_model.kl_ctl.value * (logprobs - ref_logprobs)
all_rewards = [None] * n
for ix in range(n):
rs = rewards[ix][start : ends[ix]]
rs[-1] = scores[ix]
all_rewards[ix] = rs

query_tensors = query_tensors.cpu()
response_tensors = response_tensors.cpu()
all_logprobs = all_logprobs.cpu()
all_values = all_values.cpu()
all_rewards = all_rewards.cpu()

exp_time = clock.tick()

new_ppo_rl_elements = [
PPORLElement(
query_tensor=query_tensors[i, :],
response_tensor=response_tensors[i, :],
logprobs=all_logprobs[i, :],
values=all_values[i, :],
rewards=all_rewards[i, :],
query_tensor=query_tensors[i],
response_tensor=response_tensors[i],
logprobs=all_logprobs[i],
values=all_values[i],
rewards=all_rewards[i],
)
for i in range(query_tensors.size()[0])
for i in range(n)
]

ppo_rl_elements += new_ppo_rl_elements
exp_time = clock.tick()

stats["kl_ctl_value"] = self.rl_model.kl_ctl.value
stats["exp_time"] = exp_time
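A condensed sketch (toy numbers, names mirroring the diff; the query portion is omitted so start is effectively 0) of how the per-sample rewards above are assembled: the KL penalty fills every generated position, ends is read off the attention mask so unequal generation lengths are handled, and the environment score is written onto each sample's last generated token.

import torch

kl_ctl = 0.1                                          # toy KL coefficient
logprobs     = torch.tensor([[-1.0, -2.0, -3.0],
                             [-1.5, -2.5, -9.0]])     # last entry of row 2 is padding
ref_logprobs = torch.tensor([[-1.1, -2.2, -3.3],
                             [-1.4, -2.4, -9.0]])
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]])            # second sample generated only 2 tokens
scores = torch.tensor([0.7, -0.3])                    # one scalar reward per sample

ends = attention_mask.sum(1)                          # per-sample generation lengths: [3, 2]
rewards = -kl_ctl * (logprobs - ref_logprobs)         # per-token KL penalty
all_rewards = []
for ix in range(rewards.shape[0]):
    rs = rewards[ix, : ends[ix]].clone()              # keep only this sample's generated tokens
    rs[-1] = scores[ix]                               # score goes on the last generated token
    all_rewards.append(rs)                            # ragged list: lengths 3 and 2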
15 changes: 7 additions & 8 deletions trlx/utils/modeling.py
@@ -7,6 +7,7 @@
import torch.distributed as dist
import transformers
from typing import Tuple
import numpy as np


def make_head(n_embd: int, out: int) -> nn.Sequential:
@@ -198,15 +199,13 @@ def flatten_dict(
return dict(items)


def log_stat(stats: dict, name: str, xs: torch.Tensor, mask: torch.Tensor, n: int):
def get_tensor_stats(xs: torch.Tensor, mask: torch.Tensor, n: int):
mean = (xs * mask).sum() / n
stats.update(
{
f"{name}/mean": mean,
f"{name}/min": torch.where(mask.bool(), xs, np.inf).min(),
f"{name}/max": torch.where(mask.bool(), xs, -np.inf).max(),
f"{name}/std": torch.sqrt(((xs - mean) * mask).pow(2).sum() / n),
}
return dict(
mean=mean,
min=torch.where(mask.bool(), xs, np.inf).min(),
max=torch.where(mask.bool(), xs, -np.inf).max(),
std=torch.sqrt(((xs - mean) * mask).pow(2).sum() / n),
)


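A short usage sketch for the refactored helper (toy values; the import path follows from the file this diff touches, trlx/utils/modeling.py). The masked min/max use plus and minus infinity as fill values, so padding never leaks into the extrema, and the returned dict drops straight into the nested stats dictionaries in ppo_models.py.

import torch
from trlx.utils.modeling import get_tensor_stats      # available after this change

values = torch.tensor([[0.5, 1.5, 2.5],
                       [3.0, 0.0, 0.0]])               # zeros are padding
mask   = torch.tensor([[1.0, 1.0, 1.0],
                       [1.0, 0.0, 0.0]])
n = mask.sum()

stats = get_tensor_stats(values, mask, n)
# stats["mean"] == 1.875, stats["min"] == 0.5, stats["max"] == 3.0; padding is ignored throughout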