
Commit dcbf7b0

Improve documentation/comments on the random walk example (#208)
* Improve documentation for the Random walk example
* Add additional notes on PPO random walks
* Add image for documentation
* Fix requested changes
* Fix remaining issues
1 parent 4a62f04 commit dcbf7b0

File tree

5 files changed: +225 -72 lines

examples/randomwalks/README.md

Lines changed: 30 additions & 14 deletions
@@ -1,14 +1,30 @@
-Toy problem similar to the one described in [Decision Transformer (Lili Chen et al. 2021)](https://arxiv.org/abs/2106.01345) [1]:
-finding graph's shortest paths by learning from a dataset of sampled random
-walks.
-
-In this implementation there are not environment dynamics – impossible and
-incorrect paths are penalized the same way by a single reward which is given at
-the end of the trajectory, measuring how optimal the path is compared to the
-shortest possible (bounded in [0, 1]). Paths are represented as strings of
-letters, with each letter corresponding to a node in a graph. PPO example uses a
-pretrained model for starting transition probabilities, ILQL learns them from
-the samples directly.
-
-[1] code for which is not present in the official repo, see issue
-https://github.com/kzl/decision-transformer/issues/48
+# Random Walks: Decision Tree Example
+
+This example uses the toy problem described in [Decision Transformer (Lili Chen
+et al. 2021)](https://arxiv.org/abs/2106.01345).
+
+## Game Description
+
+The task is to find the shortest path on a directed graph. The reward is based
+on how optimal the path is compared to the shortest possible (bounded in [0,
+1]).
+
+Note this is different from the paper, which gave rewards of -1 for every
+turn not at the goal state, and 0 at the goal state. Here the model instead
+receives its reward at the end of the full trajectory, based on how optimal it
+is compared to the minimum number of steps to reach the goal state (bounded in
+[0, 1]).
+
+Paths are represented as strings of letters, with each letter corresponding to a
+node in the graph.
+
+## Training
+
+![Graph Example](graph-example.png)
+Source: Decision Transformer (Lili Chen et al. 2021)
+
+For PPO, a language model was fine-tuned to predict the next token in a sequence
+of returns-to-go (sum of future rewards), states and actions. It was trained
+only on random walk data.
+
+ILQL by contrast learns from the samples directly.
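
To make the reward concrete, here is a minimal sketch of the optimality score described in the new README (the function name and the worked numbers are illustrative, not part of the example's code):

```python
def optimality(path_length: int, shortest_length: int, max_length: int = 10) -> float:
    """Score a walk in [0, 1] against the shortest possible path.

    Assumes shortest_length < max_length; invalid or unfinished walks are
    treated as taking max_length steps.
    """
    bounded = min(path_length, max_length)
    return (max_length - bounded) / (max_length - shortest_length)


# A 4-step walk whose start node is 2 steps from the goal:
print(optimality(path_length=4, shortest_length=2))  # (10 - 4) / (10 - 2) = 0.75
```
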
examples/randomwalks/graph-example.png

70.2 KB (new binary file: the graph figure referenced in the README above)

examples/randomwalks/ppo_randomwalks.py

Lines changed: 6 additions & 2 deletions
@@ -18,10 +18,14 @@ def main(hparams={}):
 
     trlx.train(
         "CarperAI/randomwalks",
-        reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"],
+        # An "optimality" reward function is used, with scores in [0, 1]
+        # depending on how close the path is to the shortest possible path.
+        reward_fn=lambda samples, prompts, outputs: metric_fn(samples)["optimality"],
+        # The prompts are simply the first nodes (represented as letters) to
+        # start from.
         prompts=prompts,
         eval_prompts=prompts,
-        metric_fn=lambda samples, **kwargs: metric_fn(samples),
+        metric_fn=lambda samples, prompts, outputs: metric_fn(samples),
         config=config,
     )
 
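The lambdas above can equivalently be written as named functions, which makes the call signature explicit; a minimal sketch, assuming `metric_fn` is the function returned by `generate_random_walks` as elsewhere in this example:

```python
def reward_fn(samples, prompts, outputs):
    # Only the full sample strings are needed here; prompts and outputs are
    # accepted to match the signature trlx now uses when calling reward_fn.
    return metric_fn(samples)["optimality"]
```
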
examples/randomwalks/randomwalks.py

Lines changed: 184 additions & 55 deletions
@@ -1,106 +1,235 @@
+from typing import Callable, Dict, List, Optional, Tuple
+
 import networkx as nx
 import numpy as np
 import torch
 
 
-def randexclude(rng: np.random.RandomState, n: int, exclude: int) -> int:
+def generate_rand_int_excluding(
+    rng: np.random.RandomState, max: int, exclude: int
+) -> int:
+    """Random integer generator, excluding a specific number
+
+    Args:
+        rng: Numpy random number generator
+        max: Max number (exclusive upper bound)
+        exclude: Number to exclude
+
+    Returns:
+        Random integer in [0, max), excluding the `exclude` integer.
+    """
     while True:
-        x = rng.randint(n)
+        # Create the random integer
+        x = rng.randint(max)
+
+        # Return the random integer if it isn't the exclude value, otherwise
+        # try again
         if x != exclude:
             return x
 
 
 def generate_random_walks(  # noqa: max-complexity
-    n_nodes=21, max_length=10, n_walks=1000, p_edge=0.1, seed=1002, gpt2_tokenizer=False
-):
+    n_nodes: int = 21,
+    max_length: int = 10,
+    n_walks: int = 1000,
+    p_edge: float = 0.1,
+    seed: int = 1002,
+    gpt2_tokenizer: bool = False,
+) -> Tuple[
+    Callable[[List[str]], Dict[str, List[float]]],
+    List[str],
+    List[str],
+    torch.Tensor,
+]:
+    """Generate random walks
+
+    Args:
+        n_nodes: Number of nodes. This should not be more than 26, as we use
+            single letters to represent each node.
+        max_length: Maximum number of steps in each random walk
+        n_walks: Number of random walks (samples) to create
+        p_edge: Probability that any source node connects to any other
+            destination node
+        seed: Random seed
+        gpt2_tokenizer: True if GPT2's tokenizer is being used
+
+    Returns:
+        Tuple of the metric function, eval prompts, sample walks and logit
+        mask.
+    """
+    # Initialise a random state with the seed
     rng = np.random.RandomState(seed)
 
+    # Create the adjacency matrix
+    # https://en.wikipedia.org/wiki/Adjacency_matrix
+    # This is a 2d matrix, where the rows represent the source nodes and the
+    # columns represent the destination nodes. If a cell (i,j) is True, then
+    # there is a directional edge from the source node (i) to the destination
+    # node (j). If it is False there is no connection.
     while True:
-        adj = rng.rand(n_nodes, n_nodes) > (1 - p_edge)
-        np.fill_diagonal(adj, 0)
-        if np.all(adj.sum(1)):
+        # Create the adjacency matrix, where each node is connected to each
+        # other node with probability p_edge
+        adjacency_matrix: np.ndarray = rng.rand(n_nodes, n_nodes) > (1 - p_edge)
+
+        # Nodes can't be connected to themselves, so the diagonal values must
+        # all be False
+        np.fill_diagonal(adjacency_matrix, 0)
+
+        # Each source node (row) must have at least one outgoing edge, so that
+        # every walk always has a move available. This checks there is a True
+        # value in every row; if not, we generate a new adjacency matrix from
+        # scratch (in the while loop).
+        if np.all(adjacency_matrix.sum(1)):
             break
 
-    # terminal state
-    adj[0, :] = 0
-    adj[0, 0] = 1
+    # Set the goal node as 0
+    goal: int = 0
 
-    char_to_node = {chr(ix + ord("a")): ix for ix in range(n_nodes)}
-    node_to_char = {ix: chr(ix + ord("a")) for ix in range(n_nodes)}
+    # The goal node is the terminal state, so we make sure that it doesn't
+    # have a directional edge going to any other nodes (i.e. it can only be
+    # connected to from previous nodes). We also set the connection to itself
+    # as True.
+    adjacency_matrix[goal, :] = 0
+    adjacency_matrix[goal, goal] = 1
 
-    goal = 0
-    sample_walks = []
-    delimiter = "|" if gpt2_tokenizer else ""
+    # Create dicts for converting nodes into characters and vice versa.
+    # Nodes are converted into characters as these (when split by the
+    # delimiter) are guaranteed to be tokenized as individual tokens.
+    char_to_node: Dict[str, int] = {chr(ix + ord("a")): ix for ix in range(n_nodes)}
+    node_to_char: Dict[int, str] = {ix: chr(ix + ord("a")) for ix in range(n_nodes)}
 
+    # Initialise a list of sample walks
+    sample_walks: List[str] = []
+
+    # String delimiter (to force the tokenizer to keep all nodes as separate
+    # tokens)
+    delimiter: str = "|" if gpt2_tokenizer else ""
+
+    # Create n_walks samples
     for _ in range(n_walks):
-        node = randexclude(rng, n_nodes, goal)
-        walk = [node]
 
-        for istep in range(max_length - 1):
-            node = rng.choice(np.nonzero(adj[node])[0])
-            walk.append(node)
+        # Create a random starting node (that isn't already at the goal state)
+        node: int = generate_rand_int_excluding(rng, n_nodes, goal)
+
+        # Initialise the list of nodes that we visit
+        walk_nodes: List[int] = [node]
+
+        # Do a series of steps, until we hit the maximum number of steps or
+        # the goal state (whichever comes first)
+        for _step in range(max_length - 1):
+
+            # From the current node, get all the nodes we can move to. Pick
+            # one of these at random, and add it to the list of visited nodes
+            node = rng.choice(np.nonzero(adjacency_matrix[node])[0])
+            walk_nodes.append(node)
+
+            # If we're at the goal state, stop
             if node == goal:
                 break
 
-        # code each node by a letter
-        # for bpe tokenizer join them over | for a guaranteed split
-        walk = [node_to_char[ix] for ix in walk]
+        # Convert the nodes visited to letters (not integers)
+        walk: List[str] = [node_to_char[ix] for ix in walk_nodes]
 
+        # Concatenate into a journey, with each node letter separated by the
+        # delimiter.
         sample_walks.append(delimiter.join(walk))
 
-    # calculate the shortest paths for comparison
-    shortest_lengths = []
-    g = nx.from_numpy_array(adj, create_using=nx.DiGraph)
+    # Initialise list of shortest lengths for each node (to the goal node)
+    shortest_lengths: List[int] = []
+
+    # Create a directional graph from the adjacency matrix
+    directional_graph = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph)
+
+    # For each node (except for the goal node), find the shortest path
     for start in set(range(n_nodes)) - {goal}:
         try:
-            shortest_path = nx.shortest_path(g, start, goal)[:max_length]
+            # Find the shortest path (up to the max_length)
+            shortest_path = nx.shortest_path(directional_graph, start, goal)[
+                :max_length
+            ]
             shortest_lengths.append(len(shortest_path))
         except Exception:
+            # If there is no path, use the maximum length instead
            shortest_lengths.append(max_length)
 
-    shortest_lengths = torch.tensor(shortest_lengths)
+    def metric_fn(
+        samples: List[str],
+    ) -> Dict[str, List[float]]:
+        """Metric Function
 
-    def metric_fn(samples):
-        # a measure for an invalid or a not found path
-        infty = 100
-        lengths = []
-        ref_lengths = []
+        Args:
+            samples: Batch of samples
 
-        for s in samples:
+        Returns:
+            Dict of metrics, each with a key of the metric name and value as
+            a list of metric values for each batch item.
+        """
+        # Length to set if the path is invalid
+        invalid_path_length: int = 100
+
+        # Initialise batch lengths & reference lengths (the optimal length
+        # starting from each batch item's specific start node)
+        lengths: List[float] = []
+        sample_optimal_lengths: List[int] = []
+
+        for sample_str in samples:
+            # Remove the GPT2-specific tokenizer delimiter
             if gpt2_tokenizer:
-                s = s.replace("|", "")
-
-            s = [char_to_node.get(c, 1000) for c in s]
-            length = None
-            for ix in range(len(s)):
-                # a nonexisting path is taken
-                if s[ix] >= n_nodes or ix > 0 and not adj[s[ix - 1], s[ix]]:
-                    length = infty
+                sample_str = sample_str.replace("|", "")
+
+            # Convert the sample into a list of nodes (default to an unused
+            # integer if the node is not found)
+            sample: List[int] = [char_to_node.get(c, 1000) for c in sample_str]
+
+            # Initialise the specific sample length
+            length: Optional[float] = None
+
+            for node in range(len(sample)):
+                # If an invalid path is taken, set the length to the invalid
+                # path score
+                if (
+                    sample[node] >= n_nodes
+                    or node > 0
+                    and not adjacency_matrix[sample[node - 1], sample[node]]
+                ):
+                    length = invalid_path_length
                     break
-                elif s[ix] == 0:
-                    length = ix + 1
+
+                # If we've reached the goal node, the length is the number of
+                # nodes visited so far (including the start node)
+                elif sample[node] == 0:
+                    length = node + 1
                     break
 
+            # Catch the case where the goal was never reached
             if length is None:
-                length = infty
+                length = invalid_path_length
+
+            # Store the batch item length & optimal length starting from the
+            # start node
+            lengths.append(float(length))
+            sample_optimal_lengths.append(shortest_lengths[sample[0] - 1])
 
-            lengths.append(length)
-            # allows for inorder checking of % optimality
-            ref_lengths.append(shortest_lengths[s[0] - 1])
+        # Bound the path lengths, treating invalid paths as taking max_length
+        # steps
+        lengths_tensor = torch.tensor(lengths, dtype=torch.float)
+        bound_lengths: torch.Tensor = torch.where(
+            lengths_tensor.eq(invalid_path_length), max_length, lengths_tensor
+        ).abs()
+        optimal_lengths = torch.as_tensor(sample_optimal_lengths)
 
-        lengths = torch.tensor(lengths, dtype=torch.float)
-        bound_lengths = torch.where(lengths.eq(infty), max_length, lengths).abs()
-        ref_lengths = torch.as_tensor(ref_lengths)
+        # Optimality scores, in [0, 1], as compared to the shortest path
+        optimality = (max_length - bound_lengths) / (max_length - optimal_lengths)
 
         return {
             "lengths": lengths,
-            # percentage-optimal \in (0, 1) when compared to the shortest path
-            "optimality": (max_length - bound_lengths) / (max_length - ref_lengths),
+            "optimality": optimality.tolist(),
         }
 
-    logit_mask = torch.tensor(adj)
+    logit_mask = torch.tensor(adjacency_matrix)
 
+    # Set the evaluation prompts as a list of unique random walk samples,
+    # using just the start point (first character) from each sample.
     eval_prompts = list(sorted(set(w[0] for w in sample_walks)))
     eval_prompts = [prompt + delimiter for prompt in eval_prompts]
 
-    return metric_fn, eval_prompts, sample_walks, logit_mask
+    return (metric_fn, eval_prompts, sample_walks, logit_mask)
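
As a usage sketch for the module above (printed values are illustrative, not from a real run): with `max_length=10`, a walk that takes 4 steps from a start node whose shortest path is 2 steps scores `(10 - 4) / (10 - 2) = 0.75`.

```python
# Hypothetical usage of generate_random_walks; the outputs shown in the
# comments are illustrative.
metric_fn, eval_prompts, sample_walks, logit_mask = generate_random_walks(
    n_nodes=21, max_length=10, n_walks=1000, seed=1002
)

print(sample_walks[0])        # e.g. "fca" - a walk ending at the goal node "a"
print(eval_prompts[:3])       # e.g. ["b", "c", "d"] - one start node each

metrics = metric_fn(sample_walks[:2])
print(metrics["lengths"])     # nodes visited per walk (100.0 if invalid)
print(metrics["optimality"])  # scores in [0, 1]; 1.0 means a shortest path
```
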

trlx/trainer/accelerate_base_trainer.py

Lines changed: 5 additions & 1 deletion
@@ -383,7 +383,11 @@ def evaluate(self):  # noqa: C901
         # additionally log any other metrics
         if self.metric_fn:
             metric_time = time()
-            metrics = self.metric_fn(str_samples)
+            metrics = self.metric_fn(
+                samples=str_samples,
+                prompts=str_prompts,
+                outputs=str_outputs,
+            )
             stats["time/metric"] = time() - metric_time
 
             mean_metrics = {
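
Any user-supplied `metric_fn` must therefore accept these keyword arguments. A minimal compatible sketch (the metric itself is a placeholder, not taken from the repo):

```python
def metric_fn(samples, prompts, outputs):
    # Placeholder metric: length of each generated continuation
    return {"output_length": [len(output) for output in outputs]}
```
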

0 commit comments