35 changes: 35 additions & 0 deletions configs/ilql_randomwalks.yml
@@ -0,0 +1,35 @@
model:
model_path: "randomwalks/1M"
tokenizer_path: "randomwalks/1M"
model_type: "AccelerateILQLModel"
num_layers_unfrozen: -1

train:
seq_length: 10
batch_size: 100
epochs: 20
total_steps: 1000

lr_init: 2.0e-4
lr_target: 2.0e-4
opt_betas: [0.9, 0.95]
opt_eps: 1.0e-8
weight_decay: 1.0e-6

checkpoint_interval: 100000
eval_interval: 16

pipeline: "PromptPipeline"
orchestrator: "OfflineOrchestrator"
seed: 1000

method:
name: "ilqlconfig"
tau: 0.8
gamma: 0.99
cql_scale: 0.1
awac_scale: 1
alpha: 0.1
steps_for_target_q_sync: 5
betas: [100]
two_qs: true
47 changes: 47 additions & 0 deletions configs/ppo_randomwalks.yml
@@ -0,0 +1,47 @@
model:
model_path: "randomwalks/1M"
tokenizer_path: "randomwalks/1M"
model_type: "AcceleratePPOModel"
num_layers_unfrozen: -1

train:
seq_length: 10
batch_size: 100
epochs: 20
total_steps: 1000

lr_init: 4.0e-4
lr_target: 4.0e-4
opt_betas: [0.9, 0.95]
opt_eps: 1.0e-8
weight_decay: 1.0e-6

checkpoint_interval: 10000
eval_interval: 20

pipeline: "PromptPipeline"
orchestrator: "PPOOrchestrator"

method:
name: 'ppoconfig'
num_rollouts: 128
chunk_size: 128
ppo_epochs: 4
init_kl_coef: 0.05
target: 6
horizon: 10000
gamma: 1
lam: 0.95
cliprange: 0.2
cliprange_value: 0.2
vf_coef: 1.2
scale_reward: False
ref_mean: null
ref_std: null
cliprange_reward: 1
gen_kwargs:
max_length: 10
min_length: 2
top_k: 0.0
top_p: 1.0
do_sample: True
109 changes: 0 additions & 109 deletions examples/randomwalks.py

This file was deleted.

12 changes: 12 additions & 0 deletions examples/randomwalks/README.md
@@ -0,0 +1,12 @@
A toy problem similar to the one described in [Decision Transformer (Lili Chen et al. 2021)](https://arxiv.org/abs/2106.01345) [1]:
finding a graph's shortest paths by learning from a dataset of sampled random
walks.

In this implementation there are no environment dynamics: impossible and
incorrect paths are penalized in the same way, by a single reward given at the
end of the trajectory that measures how close the path is to the shortest
possible one (bounded in [0, 1]). The PPO example uses a pretrained model for
the starting transition probabilities, while ILQL learns them directly from the
samples.

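A minimal sketch of how the dataset and the reward are produced by the helpers in `randomwalks.py` (exact values depend on the seeded graph):

```python
from examples.randomwalks import generate_random_walks

# sample a random graph and a dataset of walks over it, encoded as strings of node letters
metric_fn, eval_prompts, walks, logit_mask = generate_random_walks(seed=1002)

# score each walk: 1.0 for a shortest path to the goal node "a",
# lower scores for longer, invalid or unfinished walks
rewards = metric_fn(walks)["optimality"]
print(walks[0], float(rewards[0]))
```
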
[1] The code for this problem is not present in the official repo; see issue
https://github.com/kzl/decision-transformer/issues/48
1 change: 1 addition & 0 deletions examples/randomwalks/__init__.py
@@ -0,0 +1 @@
from .randomwalks import generate_random_walks
27 changes: 27 additions & 0 deletions examples/randomwalks/ilql_randomwalks.py
@@ -0,0 +1,27 @@
from examples.randomwalks import generate_random_walks

from transformers import GPT2Config
import trlx
from trlx.data.configs import TRLConfig
import yaml

default_config = yaml.safe_load(open("configs/ilql_randomwalks.yml"))


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

metric_fn, eval_prompts, walks, _ = generate_random_walks(seed=config.train.seed)
rewards = metric_fn(walks)["optimality"]

trlx.train(
GPT2Config(n_layer=6, n_embd=144, vocab_size=23),
dataset=(walks, rewards),
eval_prompts=eval_prompts,
metric_fn=metric_fn,
config=config,
)


if __name__ == "__main__":
main()
26 changes: 26 additions & 0 deletions examples/randomwalks/ppo_randomwalks.py
@@ -0,0 +1,26 @@
from examples.randomwalks import generate_random_walks

import yaml
import trlx
from trlx.data.configs import TRLConfig

default_config = yaml.safe_load(open("configs/ppo_randomwalks.yml"))


def main(hparams={}):
config = TRLConfig.update(default_config, hparams)

metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed)

trlx.train(
"randomwalks/1M",
reward_fn=lambda walks: metric_fn(walks)["optimality"],
prompts=prompts,
eval_prompts=prompts,
metric_fn=metric_fn,
config=config,
)


if __name__ == "__main__":
main()
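Unlike the ILQL example, which fits an offline dataset of walks and their rewards, this script scores newly generated walks online with `reward_fn`. A minimal sketch of what that reward call reduces to for a batch of generated strings (the walks below are made up; only the second is guaranteed to be invalid):

```python
from examples.randomwalks import generate_random_walks

metric_fn, prompts, *_ = generate_random_walks(seed=1000)
reward_fn = lambda walks: metric_fn(walks)["optimality"]

# "dba" earns a high reward only if the edges d->b and b->a exist in the
# seeded graph; "dz" is always treated as invalid, since "z" is not one of
# the 21 node letters (a..u)
print(reward_fn(["dba", "dz"]))
```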
105 changes: 105 additions & 0 deletions examples/randomwalks/randomwalks.py
@@ -0,0 +1,105 @@
import networkx as nx
import numpy as np
import torch


def randexclude(rng: np.random.RandomState, n: int, exclude: int) -> int:
while True:
x = rng.randint(n)
if x != exclude:
return x


def generate_random_walks(
n_nodes=21, max_length=10, n_walks=1000, p_edge=0.1, seed=1002, gpt2_tokenizer=False
):
rng = np.random.RandomState(seed)

while True:
adj = rng.rand(n_nodes, n_nodes) > (1 - p_edge)
np.fill_diagonal(adj, 0)
if np.all(adj.sum(1)):
break

# node 0 is the terminal (goal) state: its only outgoing edge is a self-loop
adj[0, :] = 0
adj[0, 0] = 1

char_to_node = {chr(ix + ord("a")): ix for ix in range(n_nodes)}
node_to_char = {ix: chr(ix + ord("a")) for ix in range(n_nodes)}

goal = 0
sample_walks = []
for _ in range(n_walks):
node = randexclude(rng, n_nodes, goal)
walk = [node]

for istep in range(max_length - 1):
node = rng.choice(np.nonzero(adj[node])[0])
walk.append(node)
if node == goal:
break

# encode each node as a single letter
# when using a BPE (GPT-2) tokenizer, join the letters with "|" to guarantee a per-node split
walk = [node_to_char[ix] for ix in walk]
delimiter = "|" if gpt2_tokenizer else ""

sample_walks.append(delimiter.join(walk))

# calculate the shortest paths for comparison
shortest_lengths = []
g = nx.from_numpy_array(adj, create_using=nx.DiGraph)
for start in set(range(n_nodes)) - {goal}:
try:
shortest_path = nx.shortest_path(g, start, goal)[:max_length]
shortest_lengths.append(len(shortest_path))
except Exception:
shortest_lengths.append(max_length)

shortest_lengths = torch.tensor(shortest_lengths)

def metric_fn(samples):
# placeholder length for an invalid path or one that never reaches the goal
infty = 100
lengths = []
ref_lengths = []

for s in samples:
if gpt2_tokenizer:
s = s.replace("|", "")

s = [char_to_node.get(c, 1000) for c in s]
length = None
for ix in range(len(s)):
# the walk steps onto an unknown node or uses a non-existent edge
if s[ix] >= n_nodes or ix > 0 and not adj[s[ix - 1], s[ix]]:
length = infty
break
elif s[ix] == 0:
length = ix + 1
break

if length is None:
length = infty

lengths.append(length)
# shortest-path length from this walk's start node, used as the optimality reference
ref_lengths.append(shortest_lengths[s[0] - 1])

lengths = torch.tensor(lengths, dtype=torch.float)
bound_lengths = torch.where(lengths.eq(infty), max_length, lengths).abs()
ref_lengths = torch.as_tensor(ref_lengths)

return {
"lengths": lengths,
# fraction of optimality in [0, 1] relative to the shortest path
"optimality": (max_length - bound_lengths) / (max_length - ref_lengths),
}

logit_mask = torch.tensor(adj)

eval_prompts = list(sorted(set(w[0] for w in sample_walks)))
eval_prompts = [prompt + delimiter for prompt in eval_prompts]

return metric_fn, eval_prompts, sample_walks, logit_mask
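A short sketch of the helper's outputs and of the edge cases `metric_fn` handles (a walk that never reaches the goal and a walk through an unknown node both get the placeholder length of 100):

```python
from examples.randomwalks import generate_random_walks

metric_fn, eval_prompts, sample_walks, logit_mask = generate_random_walks()

# "bcd" never reaches the goal "a"; "bz" steps onto a letter outside the graph
print(metric_fn(["bcd", "bz"])["lengths"])  # tensor([100., 100.])

# one single-letter prompt per start node seen in the dataset,
# plus the graph's adjacency matrix exposed as a logit mask
print(len(eval_prompts), logit_mask.shape)  # e.g. 20, torch.Size([21, 21])
```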
3 changes: 1 addition & 2 deletions setup.cfg
@@ -1,7 +1,7 @@
[metadata]
name = trlx
author = Alex Havrilla
-version = 0.2.0
+version = 0.3.0
url = https://github.com/CarperAI/trlx
description = A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
long_description = file: README.md
@@ -34,6 +34,5 @@ dev =

[options.packages.find]
exclude =
-examples*
docs*
tests*