
Commit 84dd156

Update generation utilities (#172)
* feat(base_trainer): enable sweeping over a single `gen_kwargs` value
* refactor(base_trainer): rename relevant variables
* fix(base_trainer): initialize `gen_sweep_arg` regardless
* feat(base_trainer): change `reward_fn`'s signature to accept kwargs
* merge(base_trainer): refactor to reflect main
* feat(*_trainer): add `stop_word`
* refactor(base_trainer): remove `seq2seq` if-case
* refactor(base_trainer): clean up logging of samples
* fix(base_trainer): remove inconsistencies
* fix(ppo_orchestrator): consistent padding and gpu device
* feat(base_trainer): add `rich` as dependency
* chore(examples): update signatures
* fix(ppo_orchestrator): logprob gather indexing
* docs(trlx): update `train`'s signature
* fix(base_trainer): disable `save_best` when training with deepspeed
* merge(base): complete merge
* feat(base_trainer): rework `stop_word` -> `stop_sequences`
* docs(base_trainer): update `decode`'s signature
* chore(base_trainer): `print` -> `print_rank_0`
* feat(base_trainer): clean up table's output
* feat(base_trainer): add number of gpus to the run's name
* style(trlx): satisfy black
* style(wandb): satisfy isort
1 parent 400dcfd commit 84dd156

16 files changed (+297 -176 lines)

examples/architext.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from trlx.data.configs import TRLConfig


-def reward_fn(samples):
+def reward_fn(samples, **kwargs):
     "Gives a negative count of rooms for each sample"
     return [-sample.count(":") for sample in samples]
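This signature change (repeated across the examples below) lets reward and metric functions ignore any extra keyword arguments the trainer now passes along. A minimal sketch of how such a function might be invoked, assuming the samples/prompts/outputs keywords that the PPO orchestrator passes later in this commit; the strings are hypothetical:

# Sketch: **kwargs lets reward_fn accept extra keyword arguments without breaking.
def reward_fn(samples, **kwargs):
    "Gives a negative count of rooms for each sample"
    return [-sample.count(":") for sample in samples]

# Hypothetical call, mirroring how the PPO orchestrator invokes it below:
scores = reward_fn(
    samples=["[prompt] kitchen: 3x4, bedroom: 4x4"],
    prompts=["[prompt]"],
    outputs=[" kitchen: 3x4, bedroom: 4x4"],
)
print(scores)  # [-2]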

examples/ilql_sentiments.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def main(hparams={}):
         device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1,
     )

-    def metric_fn(samples: List[str]) -> Dict[str, List[float]]:
+    def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
         sentiments = list(map(get_positive_score, sentiment_fn(samples)))
         return {"sentiments": sentiments}

examples/ppo_sentiments.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def main(hparams={}):
         device=device,
     )

-    def reward_fn(samples: List[str]) -> List[float]:
+    def reward_fn(samples: List[str], **kwargs) -> List[float]:
         sentiments = list(map(get_positive_score, sentiment_fn(samples)))
         return sentiments

examples/randomwalks/configs/ilql_randomwalks.yml

Lines changed: 2 additions & 2 deletions
@@ -44,6 +44,6 @@ method:
   two_qs: true
   gen_kwargs:
     max_new_tokens: 9
-    top_k: 1
-    beta: 100
+    top_k: 10
+    beta: [0, 1, 100]
     temperature: 1.0
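Per the commit message ("enable sweeping over a single `gen_kwargs` value"), giving exactly one gen_kwargs entry a list of values, as with beta: [0, 1, 100] above, makes evaluation run once per value. A rough sketch of that idea only, not the trainer's actual code; gen_sweep_arg and gen_sweep_value echo names from the commit message, and the evaluate() call is a stand-in:

# Rough sketch (assumed behaviour): detect the single list-valued entry and sweep it.
gen_kwargs = {"max_new_tokens": 9, "top_k": 10, "beta": [0, 1, 100], "temperature": 1.0}

sweep_keys = [k for k, v in gen_kwargs.items() if isinstance(v, list)]
if sweep_keys:
    gen_sweep_arg = sweep_keys[0]  # "beta"
    for gen_sweep_value in gen_kwargs[gen_sweep_arg]:
        eval_kwargs = {**gen_kwargs, gen_sweep_arg: gen_sweep_value}
        print(f"evaluating with {gen_sweep_arg}={gen_sweep_value}")
        # evaluate(model, eval_prompts, **eval_kwargs) would run here, once per value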

examples/randomwalks/ilql_randomwalks.py

Lines changed: 2 additions & 1 deletion
@@ -23,8 +23,9 @@ def main(hparams={}):
         GPT2Config(n_layer=6, n_embd=144, vocab_size=23),
         dataset=(walks, rewards),
         eval_prompts=eval_prompts,
-        metric_fn=metric_fn,
+        metric_fn=lambda samples, **kwargs: metric_fn(samples),
         config=config,
+        stop_sequences=["|"],
     )
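The new stop_sequences argument asks the trainer to cut each generated continuation at the first occurrence of a stop sequence ("|" here, the delimiter used in the random-walks data). A minimal sketch of that truncation under this assumption; truncate_at_stop and the example string are hypothetical, not the trainer's code:

# Minimal sketch of the assumed stop_sequences behaviour.
def truncate_at_stop(text: str, stop_sequences: list) -> str:
    for stop in stop_sequences:
        if stop in text:
            text = text[: text.index(stop)]
    return text

print(truncate_at_stop("3,2,1|4,5", stop_sequences=["|"]))  # "3,2,1"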

examples/randomwalks/ppo_randomwalks.py

Lines changed: 2 additions & 2 deletions
@@ -17,10 +17,10 @@ def main(hparams={}):

     trlx.train(
         "CarperAI/randomwalks",
-        reward_fn=lambda walks: metric_fn(walks)["optimality"],
+        reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"],
         prompts=prompts,
         eval_prompts=prompts,
-        metric_fn=metric_fn,
+        metric_fn=lambda samples, **kwargs: metric_fn(samples),
         config=config,
     )

examples/summarize_daily_cnn/t5_summarize_daily_cnn.py

Lines changed: 6 additions & 8 deletions
@@ -25,16 +25,14 @@

 if __name__ == "__main__":

-    def reward_fn(samples: List[str]):
-        sep_token = tokenizer.sep_token
-        articles = [sample.split(sep_token)[0].strip() for sample in samples]
-        predicted_summaries = [sample.split(sep_token)[1].strip() for sample in samples]
-        labels = [prompt_label[sample] for sample in articles]
+    def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
+        original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]
         scores = [
-            meteor.compute(predictions=[summary], references=[label])
-            for (summary, label) in zip(predicted_summaries, labels)
+            meteor.compute(predictions=[output.strip()], references=[original])[
+                "meteor"
+            ]
+            for (original, output) in zip(original_summaries, outputs)
         ]
-        scores = [score["meteor"] for score in scores]
         return scores

     dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data")
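With prompts and outputs passed in directly, the reward function no longer has to split samples on tokenizer.sep_token. A small standalone check of the per-pair METEOR scoring, assuming meteor comes from evaluate.load("meteor") in the Hugging Face evaluate library; the article and summary strings are made up:

# Standalone check of the per-pair scoring used above (toy strings).
import evaluate

meteor = evaluate.load("meteor")
prompt_label = {"Article about a coastal storm...": "A storm hit the coast on Monday."}
prompts = ["Article about a coastal storm..."]
outputs = [" A storm struck the coast Monday."]

original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]
scores = [
    meteor.compute(predictions=[output.strip()], references=[original])["meteor"]
    for (original, output) in zip(original_summaries, outputs)
]
print(scores)  # one float per (reference, output) pair, roughly in [0, 1]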

setup.cfg

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ install_requires =
     torchtyping
     transformers>=4.21.2
     tqdm
+    rich
     wandb>=0.13.5
     ray>=2.0.1
     tabulate>=0.9.0

trlx/orchestrator/ppo_orchestrator.py

Lines changed: 32 additions & 32 deletions
@@ -1,8 +1,8 @@
 from time import time
-from typing import Callable, Optional

 import ray
 import torch
+import torch.nn.functional as F

 from trlx.data.accelerate_base_datatypes import PromptBatch
 from trlx.data.ppo_types import PPORLElement

@@ -24,8 +24,6 @@ def __init__(
         self,
         trainer: BaseRLTrainer,
         pipeline: BasePipeline,
-        reward_fn: Callable,
-        metric_fn: Optional[Callable] = None,
         chunk_size: int = 512,
     ):
         self.pipeline = pipeline

@@ -43,8 +41,6 @@ def __init__(
         self.ref_model.to(self.trainer.accelerator.device)

         self.trainer.orch = self
-        self.trainer.reward_fn = reward_fn
-        self.trainer.metric_fn = metric_fn

         self.running = RunningMoments()
         self.ref_mean = self.trainer.config.method.ref_mean

@@ -65,9 +61,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
         stats = {}
         clock = Clock()
         while len(ppo_rl_elements) < num_rollouts:
-            if self.trainer.accelerator.is_main_process:
-                print(f"Making experience {len(ppo_rl_elements)} / {num_rollouts}")
-
             # Get next batch in prompt dataset and refresh if exhausted
             try:
                 batch: PromptBatch = next(self.pipeline_iterator)

@@ -79,30 +72,38 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             samples = self.trainer.generate(**batch)
             stats["time/exp_generate"] = time() - exp_generate_time

-            if self.trainer.config.model.model_arch_type == "seq2seq":
-                response_tensors = samples
-            else:
-                query_tensors = batch.input_ids
-                response_tensors = samples[:, query_tensors.shape[1] :]
-
-            texts = self.trainer.tokenizer.batch_decode(
-                samples, skip_special_tokens=True
+            query_tensors = batch.input_ids
+            device = samples.device
+            str_samples, str_prompts, str_outputs = self.trainer.decode(
+                query_tensors, samples
             )

-            if self.trainer.config.model.model_arch_type == "seq2seq":
-                articles = self.trainer.tokenizer.batch_decode(
-                    batch.input_ids, skip_special_tokens=True
+            # Convert trimmed samples back into tensors for another head pass
+            # This can be defered, instead letting the pass to made over the original samples
+            # after unbinding and truncating operations lower are fixed
+            outputs = self.trainer.tokenizer(str_outputs).input_ids
+            outputs = list(map(torch.LongTensor, outputs))
+            maxsize = max(map(len, outputs))
+            outputs = [
+                F.pad(
+                    output,
+                    (0, maxsize - len(output)),
+                    value=self.trainer.tokenizer.pad_token_id,
                 )
-                sep_token = self.trainer.tokenizer.sep_token
-                texts = [
-                    f"{article}{sep_token}{response}"
-                    for article, response in zip(articles, texts)
-                ]
+                for output in outputs
+            ]
+            response_tensors = torch.vstack(outputs).to(device)

             exp_score_time = time()
+
             scores = torch.tensor(
-                self.score(texts), device=samples.device, dtype=torch.float
-            )
+                self.trainer.reward_fn(
+                    samples=str_samples,
+                    prompts=str_prompts,
+                    outputs=str_outputs,
+                ),
+                dtype=float,
+            ).to(device)
             stats["time/exp_score"] = time() - exp_score_time

             # store statistics of the initial rollout as reference

@@ -125,9 +126,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq

             # Precompute logprobs, values
             if self.trainer.config.model.model_arch_type == "seq2seq":
-                response_tensors = response_tensors
-                attention_mask = batch.attention_mask.to(response_tensors.device)
-                query_tensors = batch.input_ids.to(response_tensors.device)
+                attention_mask = batch.attention_mask.to(device)
+                query_tensors = batch.input_ids.to(device)
                 with torch.no_grad():
                     outputs = self.trainer.model(
                         input_ids=query_tensors,

@@ -150,12 +150,12 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                 ).logits
             else:
                 all_tokens = torch.cat(
-                    (query_tensors.to(response_tensors.device), response_tensors), dim=1
+                    (query_tensors.to(device), response_tensors), dim=1
                 )
                 attention_mask = (
                     all_tokens.not_equal(self.trainer.tokenizer.pad_token_id)
                     .long()
-                    .to(all_tokens.device)
+                    .to(device)
                 )
                 with torch.no_grad():
                     logits, *_, values = self.trainer.model(

@@ -175,7 +175,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
                     attention_mask=attention_mask,
                     return_dict=False,
                 )
-            ref_logits = ref_logits.to(self.trainer.accelerator.device)
+            ref_logits = ref_logits.to(device)

             if self.trainer.config.model.model_arch_type == "seq2seq":
                 logprobs = logprobs_from_logits(
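The reworked rollout path above re-tokenizes the decoded outputs and right-pads them to a common length before stacking them into response_tensors. A compact, self-contained illustration of that padding idiom with toy token ids; the pad token id of 0 is an assumption for the example:

# Toy illustration of right-padding variable-length outputs and stacking them.
import torch
import torch.nn.functional as F

outputs = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8])]
pad_token_id = 0  # assumed pad id for the illustration

maxsize = max(map(len, outputs))
padded = [F.pad(o, (0, maxsize - len(o)), value=pad_token_id) for o in outputs]
response_tensors = torch.vstack(padded)
print(response_tensors)
# tensor([[5, 6, 7],
#         [8, 0, 0]])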

trlx/ray_tune/wandb.py

Lines changed: 2 additions & 4 deletions
@@ -7,6 +7,8 @@

 import wandb

+from trlx.utils import significant
+
 import wandb.apis.reports as wb  # isort: skip


@@ -39,10 +41,6 @@ def parse_result(result):
     return out


-def significant(x):
-    return round(x, 1 - int(math.floor(math.log10(x))))
-
-
 def log_trials(trial_path: str, project_name: str):
     trial_path = Path(trial_path)
     files = os.listdir(trial_path)
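The significant helper deleted here is now imported from trlx.utils. As defined above, it rounds a positive number to two significant digits, because 1 - floor(log10(x)) is the decimal position of the second significant digit. A quick check of that behaviour:

# Quick check of significant() as defined in the removed lines (two significant digits).
import math

def significant(x):
    return round(x, 1 - int(math.floor(math.log10(x))))

print(significant(0.02345))  # 0.023
print(significant(7.89))     # 7.9
print(significant(1234))     # 1200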
