
Commit 7f3a4ca

fix(ppo): generalize and stage for api
1 parent 8498a87 commit 7f3a4ca

File tree

10 files changed, +161 -185 lines changed


examples/ppo_sentiments.py

Lines changed: 18 additions & 34 deletions
@@ -1,41 +1,25 @@
-from typing import List
+import trlx

-import torch
+from datasets import load_dataset
 from transformers import pipeline

-import wandb
-from trlx.data.configs import TRLConfig
-from trlx.model.accelerate_ppo_model import AcceleratePPOModel
-from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator
-from trlx.pipeline.ppo_pipeline import PPOPipeline
-from trlx.utils.loading import get_model, get_orchestrator, get_pipeline
-
 if __name__ == "__main__":
-    cfg = TRLConfig.load_yaml("configs/ppo_config.yml")
-
-    sentiment_pipe = pipeline(
-        "sentiment-analysis", "lvwerra/distilbert-imdb", device=-1
+    sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb")
+
+    def reward_fn(samples):
+        outputs = sentiment_fn(samples, return_all_scores=True)
+        sentiments = [output[1]["score"] for output in outputs]
+        return sentiments
+
+    # Take few words off of movies reviews as prompts
+    imdb = load_dataset("imdb", split="train+test")
+    prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
+
+    model = trlx.train(
+        model_path="lvwerra/gpt2-imdb",
+        reward_fn=reward_fn,
+        prompts=prompts,
+        eval_prompts=["I don't know much about Hungarian underground"] * 64 + ["<|endoftext|>"] * 64
     )

-    def reward_fn(samples: List[str]):
-        sent_kwargs = {
-            "return_all_scores": True,
-            "function_to_apply": None,
-            "batch_size": cfg.method.chunk_size,
-        }
-        pipe_outputs = sentiment_pipe(samples, **sent_kwargs)
-        scores = torch.tensor([output[1]["score"] for output in pipe_outputs])
-        return scores
-
-    model: AcceleratePPOModel = get_model(cfg.model.model_type)(cfg)
-    if model.accelerator.is_main_process:
-        wandb.watch(model.model)
-
-    pipeline: PPOPipeline = get_pipeline(cfg.train.pipeline)(model.tokenizer, cfg)
-    orch: PPOOrchestrator = get_orchestrator(cfg.train.orchestrator)(
-        model, pipeline, reward_fn=reward_fn, chunk_size=cfg.method.chunk_size
-    )
-    orch.make_experience(cfg.method.num_rollouts)
-    model.learn()

-    print("DONE!")

trlx/model/accelerate_base_model.py

Lines changed: 18 additions & 31 deletions
@@ -26,11 +26,16 @@ class AccelerateRLModel(BaseRLModel):
     def __init__(self, config, train_mode=True):
         super().__init__(config, train_mode)

-        self.store = rollout_storage  # Need to pass in rollout_storage to be loaded into accelerate object
+        self.accelerator = Accelerator(log_with="wandb")
+
+        if int(os.environ.get("WORLD_SIZE", 1)) > 1:
+            torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])
+        else:
+            torch.random.manual_seed(1000)
+
+        # Retrieves model equipped for ppo, ilql, etc
+        self.model = self.get_arch(self.config)

-        self.model = self.get_arch(
-            self.config
-        )  # Retrieves model equipped for ppo, ilql, etc
         if self.config.model.num_layers_unfrozen > 0:
             for block in self.model.gpt.transformer.h[:-self.config.model.num_layers_unfrozen]:
                 for parameter in block.parameters():
@@ -43,19 +48,13 @@ def __init__(self, config, train_mode=True):
         else:
             self.tokenizer = None

-        self.max_length = config.train.gen_size
         config_dict = self.config.to_dict()
         if self.config.train.accelerate_config_path != "":
             with open(self.config.train.accelerate_config_path, mode="r") as file:
                 accelerate_config = yaml.safe_load(file)
             config_dict.update(accelerate_config)

-        self.accelerator = Accelerator(log_with="wandb")
-
-        if int(os.environ.get("WORLD_SIZE", 1)) > 1:
-            torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))])
-        else:
-            torch.random.manual_seed(1000)
+        self.max_length = config.train.gen_size

         if self.accelerator.is_main_process:
             self.accelerator.init_trackers(
@@ -97,33 +96,21 @@ def tokenize(self, text: Iterable[str]):
         )

     def act(
-        self, data: PromptBatch
+        self, prompts
     ) -> Tuple[
         TensorType["chunk_size", "input_length"],
         TensorType["chunk_size", "gen_size"],
         Iterable[str],
     ]:
-        query_tensors = data.tokens.to(
-            self.accelerator.device
-        )  # [B, N] #TODO(dahoas): This may need to be changed
         with torch.no_grad():
-            # TODO(dahoas): swap this out for custom generate to if this fixes issue
-            _ = self.model(
-                self.dummy_input.to(self.accelerator.device)
-            )  # Dummy pass to make things play nice with accelerate
-            # Removed synced gpus
-            response = self.model.generate(
-                query_tensors,
-                pad_token_id=self.tokenizer.eos_token_id,
-                **self.config.method.gen_kwargs,
+            samples = self.model.generate(
+                **prompts,
+                pad_token_id=self.tokenizer.pad_token_id,
+                **self.config.method.gen_kwargs
             )
-            response_tensors = response[
-                :,
-                query_tensors.size()[1] : query_tensors.size()[1]
-                + self.config.train.gen_size,
-            ]
-            response_text = self.tokenizer.batch_decode(response_tensors)
-            return query_tensors, response_tensors, response_text
+
+        texts = self.tokenizer.batch_decode(samples, skip_special_tokens=True)
+        return prompts.input_ids, samples[:, prompts.input_ids.shape[1]:], texts

     @torch.inference_mode()
     def sample(self, prompts: PromptBatch, gen_kwargs: dict) -> Iterable[str]:
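
Note on the new act() signature above: it now takes an already-tokenized prompt batch (a BatchEncoding with input_ids and attention_mask), calls model.generate(**prompts, ...), and slices the prompt length off the generated sequence so only the continuation is returned. A standalone sketch of that pattern with plain transformers (illustrative only; it uses gpt2 directly rather than trlx's wrapped model):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # left-pad so generation starts right after each prompt
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompts = tokenizer(["Hello, my name is", "The movie was"], return_tensors="pt", padding=True)
with torch.no_grad():
    samples = model.generate(**prompts, pad_token_id=tokenizer.pad_token_id, max_new_tokens=16)

# keep only the generated continuation, as the rewritten act() does
responses = samples[:, prompts.input_ids.shape[1]:]
texts = tokenizer.batch_decode(samples, skip_special_tokens=True)
print(texts)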

trlx/model/accelerate_ilql_model.py

Lines changed: 1 addition & 18 deletions
@@ -12,7 +12,7 @@
 import wandb
 from trlx.model import BaseRLModel, register_model
 from trlx.model.nn.ilql_models import CausalLMWithValueHeads
-from trlx.pipeline.offline_pipeline import (OfflinePipeline,
+from trlx.pipeline.offline_pipeline import (PromptPipeline,
                                             OfflineRolloutStorage)

 from .accelerate_base_model import AccelerateRLModel
@@ -111,23 +111,6 @@ def learn(self):
                     columns=["samples", *metrics.keys()], rows=rows
                 )

-                metric_time = time()
-                metrics = self.metric_fn(samples)
-                metric_time = time() - metric_time
-                logs.update({"metric_time": metric_time})
-
-                mean_metrics = {
-                    f"metrics/{k}/{beta}": torch.as_tensor(xs).mean(-1)
-                    for k, xs in metrics.items()
-                }
-                logs.update(tensor_stats)
-                logs.update(mean_metrics)
-
-                rows = list(zip(samples, *metrics.values()))
-                logs[f"samples/{beta}"] = wandb.Table(
-                    columns=["samples", *metrics.keys()], rows=rows
-                )
-
                 print(rows[0])
                 print(mean_metrics)

trlx/model/accelerate_ppo_model.py

Lines changed: 45 additions & 38 deletions
@@ -33,8 +33,6 @@ def update(self, current, n_steps):
         mult = 1 + proportional_error * n_steps / self.horizon
         self.value *= mult

-# Cell
-
 class FixedKLController:
     """Fixed KL controller."""
     def __init__(self, kl_coef):
@@ -48,20 +46,21 @@ class AcceleratePPOModel(AccelerateRLModel):
     def __init__(self, config, train_mode=True):
         super().__init__(config, train_mode)

-        self.store = PPORolloutStorage()
+        self.store = PPORolloutStorage(self.tokenizer.pad_token_id)

         rollout_loader = self.store.create_loader(
             self.config.train.batch_size, shuffle=True
         )
+
         self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare(
             self.model, self.opt, self.scheduler, rollout_loader
         )
-        self.store.clear_history()

         self.dummy_input = self.tokenize("dummy input")[
             "input_ids"
         ]  # Hack to make acclerate distributed work with model generation

+        self.store.clear_history()
         if config.method.target is not None:
             self.kl_ctl = AdaptiveKLController(
                 config.method.init_kl_coef,
@@ -78,6 +77,7 @@ def get_arch(self, config: TRLConfig):
     def loss(
         self, query_tensors, response_tensors, all_logprobs, all_values, all_rewards
     ):
+
         lastgaelam = 0
         advantages_reversed = []
         gen_len = response_tensors.shape[1]
@@ -99,7 +99,11 @@ def loss(
         advantages = advantages.detach()

         all_tokens = torch.cat((query_tensors, response_tensors), dim=1)
-        logits, _, vpred = self.model(all_tokens)
+        attention_mask = all_tokens.not_equal(self.tokenizer.pad_token_id).long()
+        position_ids = attention_mask.cumsum(-1) - 1
+        position_ids.masked_fill_(attention_mask.eq(0), 0)
+
+        logits, _, vpred = self.model(all_tokens, attention_mask, position_ids=position_ids)
         logprob = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:])

         # only the generation part of the values/logprobs is needed
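
Note on the hunk above: attention_mask and position_ids are derived from the pad token so that padded rollout positions neither attend nor shift token positions. A small self-contained illustration of that arithmetic (the pad id and token values are invented):

import torch

pad_token_id = 0
all_tokens = torch.tensor([[15, 27, 3, pad_token_id, pad_token_id],
                           [8, 4, 19, 22, 7]])

attention_mask = all_tokens.not_equal(pad_token_id).long()
position_ids = attention_mask.cumsum(-1) - 1
position_ids.masked_fill_(attention_mask.eq(0), 0)

print(attention_mask)  # tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
print(position_ids)    # tensor([[0, 1, 2, 0, 0], [0, 1, 2, 3, 4]])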
@@ -111,9 +115,12 @@ def loss(
             all_values + self.config.method.cliprange_value,
         )

+        vf_mask = attention_mask[:, -gen_len-1:-1]
+        pg_mask = attention_mask[:, -gen_len:]
+
         vf_losses1 = (vpred - returns) ** 2
         vf_losses2 = (vpredclipped - returns) ** 2
-        vf_loss = 0.5 * torch.mean(torch.max(vf_losses1, vf_losses2))
+        vf_loss = 0.5 * torch.sum(torch.max(vf_losses1, vf_losses2) * vf_mask) / vf_mask.sum()

         kl = logprob - all_logprobs
         # Record mean_kl for kl coef adjustment
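
Note on the masking above (and the pg_loss change in the next hunk): the value and policy losses switch from torch.mean over every position to a masked mean over real generation tokens only, so loss computed on padding cannot dilute the objective. A toy comparison (all numbers are made up):

import torch

def masked_mean(values, mask):
    # average only where mask == 1, i.e. over real (non-pad) tokens
    return torch.sum(values * mask) / mask.sum()

losses = torch.tensor([[0.5, 0.2, 0.9],
                       [0.1, 0.4, 0.0]])
mask = torch.tensor([[1., 1., 0.],   # last position of sample 0 is padding
                     [1., 1., 1.]])

print(torch.mean(losses))         # 0.35 -- the padding position contributes
print(masked_mean(losses, mask))  # 0.24 -- averaged over the 5 real tokens only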
@@ -127,57 +134,58 @@ def loss(
             1.0 + self.config.method.cliprange,
         )

-        pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))
+        pg_loss = torch.sum(torch.max(pg_losses, pg_losses2) * pg_mask) / pg_mask.sum()

         model_loss = pg_loss + self.config.method.vf_coef * vf_loss
         return model_loss, pg_loss, vf_loss

     def post_epoch_callback(self):
-        # TODO(dahoas): are experiences being made for dataloaders on each process or same dataloader
         self.epoch += 1
         self.store.clear_history()
         self.orch.make_experience(
             self.config.method.num_rollouts, self.iter_count
         )  # Collect more rollouts for training

     def post_backward_callback(self):
-        batch = self.logs["batch"]
         # Update kl_coefficient
         self.kl_ctl.update(self.mean_kl ,self.config.train.batch_size)
-        # Run evaluation
+
+        all_samples = []
+        for prompts in self.eval_dataloader:
+            query, response, _ = self.act(prompts)
+            pad_token = self.tokenizer.eos_token_id if self.tokenizer else 0
+            samples = torch.hstack((query, response))
+            all_samples.append(F.pad(samples, (0, self.max_length-samples.shape[1]), value=pad_token))
+
+        samples = self.accelerator.gather(torch.vstack(all_samples))
+
         if self.accelerator.is_main_process:
-            if (
-                self.iter_count % self.config.train.eval_interval == 0
-                or self.iter_count <= self.config.method.ppo_epochs
-            ):
-                text = self.tokenizer.batch_decode(batch.query_tensors)
-                eval_batch: PromptBatch = PromptBatch(
-                    text=text, tokens=batch.query_tensors
-                )
-                query_tensors, response_tensors, response_text = self.act(eval_batch)
-                gen_texts = [q + r for q, r in zip(eval_batch.text, response_text)]
-                scores = self.orch.score(gen_texts)
-                mean_score = torch.mean(scores).item()
-                rows = list(zip(gen_texts, scores.tolist()))
-                stats = {
-                    "mean_score": mean_score,
-                    "responses": wandb.Table(columns=["response", "score"], rows=rows),
-                    "pg_loss": self.logs["pg_loss"],
-                    "vf_loss": self.logs["vf_loss"],
-                    "kl_coef": self.kl_ctl.value,
-                }
-                self.accelerator.log(stats, step=self.iter_count)
-                self.accelerator.print(
-                    "Step: {}, Mean score: {}, pg_loss: {}, vf_loss: {}, kl_coef: {}".format(
-                        self.iter_count, mean_score, stats["pg_loss"], stats["vf_loss"], self.kl_ctl.value,
-                    )
+            samples = self.tokenizer.batch_decode(samples, skip_special_tokens=True)
+            scores = self.orch.score(samples)
+            mean_score = torch.mean(torch.as_tensor(scores)).item()
+            rows = list(zip(samples, scores))
+            stats = {
+                "mean_score": mean_score,
+                "responses": wandb.Table(columns=["response", "score"], rows=rows),
+                "pg_loss": self.logs["pg_loss"],
+                "vf_loss": self.logs["vf_loss"],
+                "kl_coef": self.kl_ctl.value,
+            }
+
+            self.accelerator.log(stats, step=self.iter_count)
+            self.accelerator.print(
+                "Step: {}, Mean score: {}, pg_loss: {}, vf_loss: {}, kl_coef: {}".format(
+                    self.iter_count, mean_score, stats["pg_loss"], stats["vf_loss"], self.kl_ctl.value,
                 )
+            )

     def learn(self):
+        self.eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size)
+
         rollout_loader = self.store.create_loader(
             self.config.train.batch_size, shuffle=True
         )
-        rollout_loader = self.accelerator.prepare(rollout_loader)
+        rollout_loader, self.eval_dataloader = self.accelerator.prepare(rollout_loader, self.eval_dataloader)

         self.iter_count = 0
         self.epoch = 0
@@ -204,8 +212,7 @@ def learn(self):
                 "batch": batch,
                 "rewards": rewards,
             }
-            # self.post_backward_callback()
-            # exit()
+
             self.opt.zero_grad()
             self.accelerator.backward(loss)
             self.opt.step()
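
Note on the rewritten post_backward_callback above: every process runs the eval dataloader through act(), right-pads each batch of (prompt, response) tokens to self.max_length, and only then calls accelerator.gather, which requires identically shaped tensors on every rank. A minimal sketch of the padding step alone (token values, lengths, and the pad id are invented):

import torch
import torch.nn.functional as F

max_length = 8
pad_token = 50256  # e.g. a GPT-2-style eos id reused as padding

batches = [torch.tensor([[11, 22, 33, 44]]),             # shorter batch
           torch.tensor([[5, 6, 7, 8, 9, 10, 11, 12]])]  # already at max_length

all_samples = []
for samples in batches:
    # right-pad the token dimension so every row has the same width
    all_samples.append(F.pad(samples, (0, max_length - samples.shape[1]), value=pad_token))

stacked = torch.vstack(all_samples)
print(stacked.shape)  # torch.Size([2, 8])
# in the commit, accelerator.gather(stacked) then concatenates these across processes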

trlx/model/nn/ilql_models.py

Lines changed: 6 additions & 4 deletions
@@ -137,14 +137,16 @@ def loss(self, batch):

         _V = vs[:, :-1].squeeze()
         V = vs[:, 1:].squeeze() * dones[:, 1:]
-        Q_ = rewards + self.gamma * V
+        Q_ = rewards + self.gamma * V.detach()

         if self.two_qs:
-            loss_q1 = ((Q1 - Q_.detach()) * terminal_mask).pow(2).sum() / n_nonterminal
-            loss_q2 = ((Q2 - Q_.detach()) * terminal_mask).pow(2).sum() / n_nonterminal
+            loss_q1 = ((Q1 - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
+            loss_q2 = ((Q2 - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
             loss_q = loss_q1 + loss_q2
         else:
-            loss_q = ((Q - Q_.detach()) * terminal_mask).pow(2).sum() / n_nonterminal
+            loss_q = ((Q - Q_) * terminal_mask).pow(2).sum() / n_nonterminal
+
+        targetQ = targetQ.detach()

         loss_v = (
             (
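
Note on the ilql_models.py change above: the stop-gradient moves from the regression term onto the bootstrapped next-state value itself (V.detach()), and targetQ is detached before the value loss, so the TD target never back-propagates into the estimate it was built from. A toy illustration of the detached-target pattern (tensors are placeholders, not the model's real outputs):

import torch

gamma = 0.99
rewards = torch.tensor([1.0, 0.0])
V = torch.tensor([0.5, 0.2], requires_grad=True)  # stand-in for next-state values
Q = torch.tensor([0.9, 0.1], requires_grad=True)  # stand-in for predicted Q-values

Q_ = rewards + gamma * V.detach()   # target carries no gradient back into V
loss_q = (Q - Q_).pow(2).mean()
loss_q.backward()

print(Q.grad)  # populated: Q is pulled toward the target
print(V.grad)  # None: no gradient flows through the bootstrapped target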
