Merged

Commits (24)
dc6acb7
feat(ppo_models): context truncating generation
maxreciprocate Dec 8, 2022
ed4f6f3
feat(base_model): segregate prompts and responses in logging
maxreciprocate Dec 8, 2022
4674e3e
fix(ppo_model): truncate left padded tokens
maxreciprocate Dec 8, 2022
3b886df
revert(ppo_models): remove context overflowing generate()
maxreciprocate Dec 9, 2022
a8b6eaf
feat(configs): add max_new_tokens
maxreciprocate Dec 9, 2022
31e5bb5
fix(pipeline): truncate prompts
maxreciprocate Dec 9, 2022
59176ac
chore(base_model): remove whole samples from logging
maxreciprocate Dec 9, 2022
48899e0
chore(configs): update the rest of configs
maxreciprocate Dec 9, 2022
27bd892
fix(configs): update program synthesis config
maxreciprocate Dec 12, 2022
d6be611
fix(base_model): prompts sizes
maxreciprocate Dec 12, 2022
74e758a
revert(config): emulate old ppo_sentiments behavior
maxreciprocate Dec 12, 2022
7366ab6
fix(ppo): unequal generation lengths
maxreciprocate Dec 14, 2022
1b4d5db
chore(ppo): put indexing on cpu
maxreciprocate Dec 14, 2022
dd8b21a
revert(config): old ilql_sentiments behavior
maxreciprocate Dec 14, 2022
d1a9c38
fix(ppo): unzero clipfrac
maxreciprocate Dec 14, 2022
97bdd72
Merge branch 'main' into fix-context-overflow
maxreciprocate Dec 14, 2022
f394897
merge(configs): delete old options
maxreciprocate Dec 14, 2022
a960ada
revert(config): old ppo_sentiment behavior
maxreciprocate Dec 14, 2022
a290dd8
refactor(ppo_orchestrator): remove unused ref_logprobs
maxreciprocate Dec 14, 2022
441dd23
fix(base_model): pin rewards to single precision
maxreciprocate Dec 15, 2022
330dde3
refactor(ppo): rename padding percentage
maxreciprocate Dec 15, 2022
6df0195
feat(wandb): add git branch name to tags
maxreciprocate Dec 15, 2022
005b348
refactor(wandb): logging name hierarchy
maxreciprocate Dec 15, 2022
cacebc3
revert(wandb): merge tags into a single string
maxreciprocate Dec 15, 2022
6 changes: 5 additions & 1 deletion configs/ilql_config.yml
@@ -31,5 +31,9 @@ method:
awac_scale: 1
alpha: 0.001
steps_for_target_q_sync: 5
betas: [4]
two_qs: true
gen_kwargs:
max_new_tokens: 24
top_k: 20
beta: 4
temperature: 1.0
9 changes: 4 additions & 5 deletions configs/ppo_config.yml
@@ -5,7 +5,7 @@ model:
num_layers_unfrozen: 2 # Number of bottom layers to freeze during training

train:
seq_length: 48 # Size of LM context
seq_length: 64 # Size of LM context
epochs: 100 # Train for max(epochs, total_steps)
total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 128 # batch size
@@ -38,10 +38,9 @@ method:
scale_reward: False # False | "ref" | "running" estimate against which to scale rewards
ref_mean: null
ref_std: null # rescale rewards with this deviation
cliprange_reward: 10
cliprange_reward: 10 # clip reward into (-clip, clip) range
gen_kwargs:
max_length: 48 # LM max sample gen length
min_length: 48 # LM min sample gen length
top_k: 0.0 # top k
max_new_tokens: 24 # model will generate at most this many tokens during rollouts
top_k: 0 # top k
top_p: 1.0 # top p
do_sample: True # sample
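
For reference, a minimal sketch (not part of this diff) of how a gen_kwargs block like the one above maps onto a Hugging Face generate call during rollouts; the gpt2 checkpoint and the prompt text are illustrative stand-ins:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The movie was", return_tensors="pt")
# max_new_tokens bounds only the continuation, so the prompt length no longer
# has to be folded into a fixed max_length/min_length pair as in the old config
samples = model.generate(
    **inputs,
    max_new_tokens=24,
    top_k=0,
    top_p=1.0,
    do_sample=True,
)
print(tokenizer.decode(samples[0], skip_special_tokens=True))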
3 changes: 1 addition & 2 deletions configs/ppo_gptj.yml
@@ -40,8 +40,7 @@ method:
ref_std: null # rescale rewards with this deviation
cliprange_reward: 10
gen_kwargs:
max_length: 48 # LM max sample gen length
min_length: 48 # LM min sample gen length
max_new_tokens: 24
top_k: 0.0 # top k
top_p: 0.7 # top p
do_sample: True # sample
@@ -40,8 +40,7 @@ method:
ref_mean: null
ref_std: null
gen_kwargs:
max_length: 256 # LM max sample gen length
min_length: 48 # LM min sample gen length
max_new_tokens: 256 # model will generate at most this many tokens during rollouts
top_k: 0.0 # top k
top_p: 0.7 # top p
do_sample: True # sample
6 changes: 5 additions & 1 deletion examples/randomwalks/configs/ilql_randomwalks.yml
@@ -31,5 +31,9 @@ method:
awac_scale: 1
alpha: 0.1
steps_for_target_q_sync: 5
betas: [100]
two_qs: true
gen_kwargs:
max_new_tokens: 9
top_k: 1
beta: 100
temperature: 1.0
3 changes: 1 addition & 2 deletions examples/randomwalks/configs/ppo_randomwalks.yml
@@ -40,8 +40,7 @@ method:
ref_std: null
cliprange_reward: 1
gen_kwargs:
max_length: 10
min_length: 2
max_new_tokens: 9
top_k: 0.0
top_p: 1.0
do_sample: True
35 changes: 29 additions & 6 deletions trlx/model/accelerate_base_model.py
@@ -148,6 +148,7 @@ def evaluate(self):
"""Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided"""
stats = {}
all_samples = []
prompts_sizes = []
generate_time = time()
for prompts in self.eval_dataloader:
if isinstance(prompts, torch.Tensor):
@@ -166,23 +167,45 @@ def evaluate(self):
value=pad_token,
)
)
sizes = torch.tensor(prompts.input_ids.shape[1]).repeat(
len(prompts.input_ids)
)
prompts_sizes.append(sizes.to(samples.device))

stats["generate_time"] = time() - generate_time

samples = self.accelerator.gather(torch.vstack(all_samples))
prompts_sizes = self.accelerator.gather(torch.hstack(prompts_sizes))

if self.accelerator.is_main_process:
if self.tokenizer:
samples = self.tokenizer.batch_decode(samples, skip_special_tokens=True)
str_samples = self.tokenizer.batch_decode(
samples, skip_special_tokens=True
)

prompts, responses = [], []
for sample, prompt_size in zip(samples, prompts_sizes):
prompts.append(sample[:prompt_size])
responses.append(sample[prompt_size:])

if isinstance(samples[0], str):
columns_data = [samples]
str_prompts = self.tokenizer.batch_decode(
prompts, skip_special_tokens=True
)
str_responses = self.tokenizer.batch_decode(
responses, skip_special_tokens=True
)

if isinstance(str_samples[0], str):
columns_data = [str_prompts, str_responses]
else:
columns_data = [samples.tolist()]
columns = ["samples"]
columns = ["prompt", "response"]

# in online setting, compute the reward for validation
if self.reward_fn:
rewards = torch.as_tensor(self.reward_fn(samples), dtype=torch.float)
rewards = torch.as_tensor(
self.reward_fn(str_samples), dtype=torch.float
)
mean_reward = rewards.mean()
columns.append("reward")
columns_data.append(rewards)
@@ -192,7 +215,7 @@ def evaluate(self):
# additionally log any other metrics
if self.metric_fn:
metric_time = time()
metrics = self.metric_fn(samples)
metrics = self.metric_fn(str_samples)
stats["metric_time"] = time() - metric_time

mean_metrics = {
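
A self-contained sketch (toy tensors, not part of this diff) of the prompt/response split that evaluate now performs: every gathered sample is cut at its recorded prompt size, so prompts and responses can be logged as separate columns:

import torch

# toy stand-ins for the gathered `samples` and `prompts_sizes` tensors
samples = torch.tensor([[101, 102, 103, 7, 8, 9],
                        [201, 202, 203, 17, 18, 19]])
prompts_sizes = torch.tensor([3, 3])  # prompts in a batch share one padded length

prompts, responses = [], []
for sample, prompt_size in zip(samples, prompts_sizes):
    prompts.append(sample[:prompt_size])    # left part: the (padded) prompt
    responses.append(sample[prompt_size:])  # right part: the generated continuation

print(prompts[0].tolist())    # [101, 102, 103]
print(responses[0].tolist())  # [7, 8, 9]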
16 changes: 8 additions & 8 deletions trlx/model/accelerate_ilql_model.py
@@ -32,6 +32,14 @@ def __init__(

self.ilql: ILQLConfig = cast(ILQLConfig, config.method)

self.generate_kwargs = dict(
config.method.gen_kwargs,
max_length=self.max_length,
logit_mask=self.logit_mask,
eos_token_id=self.tokenizer.eos_token_id if self.tokenizer else 0,
pad_token_id=self.tokenizer.pad_token_id if self.tokenizer else 0,
)

def get_arch(self, config):
return CausalLMWithValueHeads(
config.model.model_path,
@@ -87,11 +95,3 @@ def prepare_learning(self):
self.n_updates_per_batch = 1
self.total_steps = self.config.train.epochs * len(train_dataloader)
self.total_steps = min(self.total_steps, self.config.train.total_steps)

self.generate_kwargs = {
"beta": self.config.method.betas[0],
"max_length": self.max_length,
"logit_mask": self.logit_mask,
"eos_token_id": self.tokenizer.eos_token_id if self.tokenizer else 0,
"pad_token_id": self.tokenizer.pad_token_id if self.tokenizer else 0,
}
4 changes: 3 additions & 1 deletion trlx/model/accelerate_ppo_model.py
@@ -62,7 +62,9 @@ def get_model_inputs(
query_tensors: TensorType["batch_size", "query_size"],
response_tensors: TensorType["batch_size", "response_size"],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
tokens = torch.cat((query_tensors, response_tensors), dim=1)
tokens = torch.cat((query_tensors, response_tensors), dim=1)[
:, -self.max_length :
]
attention_mask = (
tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device)
)
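
A minimal sketch of the new left-truncation in get_model_inputs (toy values; max_length and pad_token_id are stand-ins): concatenating query and response tensors can overflow the model context, so only the last max_length tokens are kept:

import torch

max_length = 6        # stand-in for self.max_length
pad_token_id = 0      # stand-in for self.tokenizer.pad_token_id

query_tensors = torch.tensor([[0, 0, 5, 6, 7, 8]])   # left-padded query
response_tensors = torch.tensor([[9, 10, 11]])

tokens = torch.cat((query_tensors, response_tensors), dim=1)[:, -max_length:]
attention_mask = tokens.not_equal(pad_token_id).long()

print(tokens)          # tensor([[ 6,  7,  8,  9, 10, 11]]) -- oldest query tokens dropped
print(attention_mask)  # tensor([[1, 1, 1, 1, 1, 1]])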
11 changes: 5 additions & 6 deletions trlx/model/nn/ilql_models.py
@@ -42,8 +42,8 @@ class ILQLConfig(MethodConfig):
awac_scale: float
alpha: float
steps_for_target_q_sync: float
betas: Sequence[float]
two_qs: bool
gen_kwargs: dict

def heads(self, hidden_size: int, vocab_size: int):
return ILQLHeads(self, hidden_size, vocab_size)
@@ -140,7 +140,6 @@ def forward(
states_ixs: torch.Tensor = None,
actions_ixs: torch.Tensor = None,
):

if states_ixs is not None:
states_hs = hs.gather(
dim=1, index=states_ixs.unsqueeze(-1).repeat(1, 1, hs.shape[-1])
@@ -260,7 +259,8 @@ def generate(
position_ids=None,
past_key_values=None,
beta=1,
max_length=32,
max_new_tokens=32,
max_length=1024,
temperature=1,
top_k=20,
logit_mask=None,
@@ -278,13 +278,12 @@
position_ids.masked_fill_(attention_mask.eq(0), 0)

samples = input_ids.clone()
tensors = defaultdict(list)
n_new_tokens = max_length - input_ids.shape[1]
max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1])

finished = torch.zeros(
input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device
)
for _ in range(n_new_tokens):
for _ in range(max_new_tokens):
out = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
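
A worked example (values assumed for illustration) of the clamp that replaces the old fixed generation length in ILQL generate: the requested budget is cut down whenever the prompt leaves less room than max_new_tokens in the context window:

max_length = 1024      # cap on prompt + generation, matching the new default
max_new_tokens = 32    # requested generation budget, matching the new default
prompt_length = 1010   # illustrative input_ids.shape[1] for a long prompt

# never generate past the end of the context window
max_new_tokens = min(max_new_tokens, max_length - prompt_length)
print(max_new_tokens)  # 14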
16 changes: 12 additions & 4 deletions trlx/pipeline/offline_pipeline.py
@@ -1,4 +1,4 @@
from typing import Iterable
from typing import Iterable, List

import torch
from torch.nn.utils.rnn import pad_sequence
@@ -12,13 +12,21 @@
@register_datapipeline
class PromptPipeline(BasePipeline):
"""
Tokenizes texts, and then pads them into batches
Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right
"""

def __init__(self, prompts, tokenizer=None):
def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer=None):
super().__init__()

if tokenizer:
prompts = tokenizer(prompts).input_ids

self.tokenizer = tokenizer
self.prompts = list(map(tokenizer if tokenizer else (lambda x: x), prompts))
self.prompts = [prompt[-max_prompt_length:] for prompt in prompts]
self.prompts = [
{"input_ids": prompt, "attention_mask": [1] * len(prompt)}
for prompt in self.prompts
]

def __getitem__(self, ix: int):
return self.prompts[ix]
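
A usage sketch of the updated PromptPipeline (assuming this branch of trlx and a GPT-2 tokenizer are available; the prompt text is made up): prompts longer than max_prompt_length keep only their last max_prompt_length tokens, and each item comes back as a dict of input_ids and attention_mask:

from transformers import AutoTokenizer

from trlx.pipeline.offline_pipeline import PromptPipeline

tokenizer = AutoTokenizer.from_pretrained("gpt2")
pipeline = PromptPipeline(["a very long prompt " * 20], 8, tokenizer=tokenizer)

item = pipeline[0]
print(len(item["input_ids"]))  # 8 -- only the last 8 prompt tokens survive
print(item["attention_mask"])  # [1, 1, 1, 1, 1, 1, 1, 1]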
15 changes: 12 additions & 3 deletions trlx/trlx.py
@@ -47,13 +47,18 @@ def train(
if eval_prompts is None:
eval_prompts = prompts[:batch_size]

pipeline = get_pipeline(config.train.pipeline)(prompts, model.tokenizer)
max_prompt_length = (
config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]
)
pipeline = get_pipeline(config.train.pipeline)(
prompts, max_prompt_length, model.tokenizer
)
orch = get_orchestrator(config.train.orchestrator)(
model, pipeline, reward_fn=reward_fn, chunk_size=config.method.chunk_size
)
orch.make_experience(config.method.num_rollouts)
eval_pipeline = get_pipeline(config.train.pipeline)(
eval_prompts, model.tokenizer
eval_prompts, max_prompt_length, model.tokenizer
)
model.add_eval_pipeline(eval_pipeline)

@@ -79,10 +84,14 @@
)

batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1))
max_prompt_length = (
config.train.seq_length - config.method.gen_kwargs["max_new_tokens"]
)

if eval_prompts is None:
eval_prompts = [model.tokenizer.bos_token] * batch_size
eval_pipeline = get_pipeline(config.train.pipeline)(
eval_prompts, model.tokenizer
eval_prompts, max_prompt_length, model.tokenizer
)

orch = get_orchestrator(config.train.orchestrator)(
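
The prompt budget introduced here, worked through with the values from configs/ppo_config.yml above (a sketch of the arithmetic, not additional library code):

seq_length = 64       # train.seq_length, the full LM context budget
max_new_tokens = 24   # method.gen_kwargs["max_new_tokens"]

# prompts are truncated so that prompt + rollout always fits into seq_length
max_prompt_length = seq_length - max_new_tokens
print(max_prompt_length)  # 40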