Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
ec00c54
add t5 to trlx
Dec 20, 2022
dacb652
add t5 examples for sentiment
Dec 20, 2022
56a0a3c
add eval for t5
Dec 20, 2022
5feff9f
fix eval
Dec 20, 2022
ccfabde
remove old files
Dec 21, 2022
2674d24
remove bad files
Dec 21, 2022
6e43ea1
remove bad files
Dec 21, 2022
59c2cf5
fix incompatible with gpt model, add summarization code base
Dec 21, 2022
2c133b0
freeze frozen branch
Dec 21, 2022
c9ddfcf
Merge branch 'main' into add_t5
PhungVanDuy Dec 21, 2022
5f38a81
fix evaluation bug t5, add summarization cnn/daily mail example
Dec 25, 2022
17be682
update sentiment example
Dec 27, 2022
2d1a4dc
stable config sentiment
Dec 27, 2022
f9f85ba
add attention mask decoder
Dec 29, 2022
500099f
setting worked - flant5 two unfrozen small rollouts
Dec 31, 2022
b55a4e8
merge newest code from main
Jan 1, 2023
36a74e6
fix head nn, config cnn daily mail, remove sent examples
Jan 2, 2023
6baee0b
fix style, change model_arch_type, truncated tokenizer fixed
Jan 6, 2023
d2082a7
fix style
Jan 6, 2023
d2f6a1d
precommit changes
Jan 6, 2023
eaf9c94
fix ppo state values for t5
Jan 7, 2023
c03313a
Merge branch 'main' into add_t5
PhungVanDuy Jan 7, 2023
93cf3cc
fix style
Jan 7, 2023
8ac399b
remove sentiment example
Jan 7, 2023
fefa62b
fix typo
Jan 7, 2023
5ae1188
fix ppo for causal models, add save best, seperate rollouts/eval args
Jan 7, 2023
ea10837
add ppo sentiment
Jan 7, 2023
84f8b7b
fix rewards typo
Jan 8, 2023
03cc954
Merge branch 'main' into add_t5
PhungVanDuy Jan 8, 2023
347e314
merging with main
Jan 8, 2023
220c8f3
fix style
Jan 8, 2023
a0a43f8
add docstring for gen_kwargs_inference, save best
Jan 9, 2023
b86e3d4
add gen kwargs support for rollouts sampling
Jan 9, 2023
eb0b0cc
Make summarization example self-contained
jon-tow Jan 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions examples/configs/ppo_config_cnn_daily.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# PPO configuration for the CNN/DailyMail summarization example
# (examples/trlx_t5_summ_daily_cnn.py).
train:
  # The example script truncates prompts to
  # seq_length - method.gen_kwargs.max_new_tokens tokens.
  seq_length: 612
  epochs: 100
  total_steps: 100000
  batch_size: 12

  checkpoint_interval: 10000
  eval_interval: 500
  save_best: False

  pipeline: "PromptPipeline"
  orchestrator: "PPOOrchestrator"
  trainer: "AcceleratePPOTrainer"

model:
  model_path: "google/flan-t5-large"
  model_arch_type: "seq2seq"  # either "causal" or "seq2seq"
  tokenizer_path: "google/flan-t5-large"
  num_layers_unfrozen: 2  # -1 would unfreeze all layers

optimizer:
  name: "adamw"
  kwargs:
    lr: 1.0e-5
    betas: [0.9, 0.999]
    eps: 1.0e-8
    weight_decay: 1.0e-6

scheduler:
  name: "cosine_annealing"
  kwargs:
    T_max: 10000
    eta_min: 1.0e-6

method:
  name: "ppoconfig"
  num_rollouts: 512
  chunk_size: 12
  ppo_epochs: 4
  init_kl_coef: 0.05
  target: 6
  horizon: 10000
  gamma: 0.99
  lam: 0.95
  cliprange: 0.2
  cliprange_value: 0.2
  vf_coef: 1.0
  scale_reward: False
  ref_mean: null
  ref_std: null
  cliprange_reward: 10
  # Generation settings used when sampling rollouts.
  gen_kwargs:
    max_new_tokens: 100
    # top_k: 50
    # top_p: 0.95
    # do_sample: True
  # Generation settings used at inference/evaluation time.
  gen_inference_kwargs:
    max_new_tokens: 100
58 changes: 0 additions & 58 deletions examples/ppo_sentiments.py

This file was deleted.

72 changes: 72 additions & 0 deletions examples/trlx_t5_summ_daily_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import List

import evaluate
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer

import trlx
from trlx.data.configs import TRLConfig

meteor = evaluate.load("meteor") # use meteor as the reward function

if __name__ == "__main__":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is an example, we should probably have lots of comments.


def reward_fn(samples: List[str]):
    """Score each generated summary against its reference with METEOR.

    Every sample is split on the tokenizer's separator token: the first piece
    is the article prompt and the second piece is the generated summary. The
    stripped article text is the key used to look up the reference summary in
    ``prompt_label``.

    :param samples: Strings containing an article and a generated summary
        joined by ``tokenizer.sep_token``.
    :type samples: List[str]

    :return: One METEOR score per sample, in the same order as ``samples``.
    """
    separator = tokenizer.sep_token

    # Split every sample once and keep both halves.
    articles = []
    generated = []
    for sample in samples:
        pieces = sample.split(separator)
        articles.append(pieces[0].strip())
        generated.append(pieces[1].strip())

    # Resolve every reference before scoring anything.
    references = [prompt_label[article] for article in articles]

    # METEOR is computed one pair at a time so each sample gets its own score.
    rewards = []
    for summary, reference in zip(generated, references):
        result = meteor.compute(predictions=[summary], references=[reference])
        rewards.append(result["meteor"])
    return rewards

# Load the PPO hyperparameters for this example.
config = TRLConfig.load_yaml("configs/ppo_config_cnn_daily.yml")

# Use the first 20000 training examples as prompts for training.
dataset = load_dataset("cnn_dailymail", "3.0.0", split="train", cache_dir="data")
prompts = dataset["article"][0:20000]
summaries = dataset["highlights"][0:20000]
prompts = ["Summarize: " + prompt for prompt in prompts]

# Use the first 1000 validation examples as prompts for evaluation.
val_dataset = load_dataset(
    "cnn_dailymail", "3.0.0", split="validation", cache_dir="data"
)
val_prompts = ["Summarize: " + prompt for prompt in val_dataset["article"][0:1000]]
val_summaries = val_dataset["highlights"][0:1000]

# Build a dictionary from prompt text to reference summary; reward_fn looks
# up the reference for each sample by its (truncated) prompt text.
tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)
tokenizer.padding_side = "left"
tokenizer.truncation_side = "right"
tokenizer.sep_token = "<sep>"
prompt_label = {}
# Prompt budget: the sequence length minus the tokens reserved for generation.
max_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]

# Keys are produced by tokenizing with truncation and decoding again, so they
# are the truncated prompt text rather than the full article.
for prompt, summary in tqdm(zip(prompts, summaries), total=len(prompts)):
    key = tokenizer.decode(
        tokenizer(prompt, truncation=True, max_length=max_length)["input_ids"],
        skip_special_tokens=True,
    )  # get prompt like trlx's prompt
    prompt_label[key.strip()] = summary

for prompt, summary in tqdm(zip(val_prompts, val_summaries), total=len(val_prompts)):
    key = tokenizer.decode(
        tokenizer(prompt, truncation=True, max_length=max_length)["input_ids"],
        skip_special_tokens=True,
    )  # get prompt like trlx's prompt
    prompt_label[key.strip()] = summary

# Run PPO training with the METEOR-based reward function.
model = trlx.train(
    config.model.model_path,
    reward_fn=reward_fn,
    prompts=prompts,
    eval_prompts=val_prompts,
    config=config,
)
5 changes: 5 additions & 0 deletions trlx/data/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@ class ModelConfig:
:param tokenizer_path: Path or name of the tokenizer (local or on huggingface hub)
:type tokenizer_path: str

:param model_arch_type: Type of model architecture. Either "causal" or "seq2seq"
:type model_arch_type: str

:param num_layers_unfrozen: Number of layers to unfreeze for fine-tuning.
-1 means all layers are unfrozen.
:type num_layers_unfrozen: int
"""

model_path: str
tokenizer_path: str
model_arch_type: str = "causal"
num_layers_unfrozen: int = -1

@classmethod
Expand Down Expand Up @@ -151,6 +155,7 @@ class TrainConfig:

checkpoint_dir: str = "ckpts"
rollout_logging_dir: Optional[str] = None
save_best: bool = True

trackers: Tuple[str] = ("wandb",)
seed: int = 1000
Expand Down
Loading