|
from typing import Callable, Dict, List, Optional, Tuple

import networkx as nx
import numpy as np
import torch
|
4 | 6 |
|
5 | 7 |
|
6 |
def generate_rand_int_excluding(
    rng: np.random.RandomState, max: int, exclude: int
) -> int:
    """Draw a random integer from ``[0, max)`` that is not ``exclude``.

    Values are drawn by rejection sampling, so the numbers consumed from
    ``rng`` are exactly those of repeated ``rng.randint(max)`` calls until an
    acceptable value appears. This keeps seeded runs reproducible.

    Args:
        rng: Numpy random number generator.
        max: Exclusive upper bound of the range to draw from. (The name
            shadows the builtin; it is kept so keyword callers keep working.)
        exclude: Value that must never be returned.

    Returns:
        Random integer in ``[0, max)`` that differs from ``exclude``.

    Raises:
        ValueError: If ``[0, max)`` contains no integer other than
            ``exclude``, in which case rejection sampling could never
            succeed.
    """
    # Guard against the inputs for which the loop below could never return:
    # an empty range, or a range whose single member (0) is the excluded one.
    if max < 1 or (max == 1 and exclude == 0):
        raise ValueError(
            f"No integer in [0, {max}) is different from exclude={exclude}"
        )

    while True:
        # Draw a candidate from [0, max)
        x = rng.randint(max)

        # Accept the candidate unless it is the excluded value, otherwise
        # draw again
        if x != exclude:
            return x
|
11 | 29 |
|
12 | 30 |
|
def generate_random_walks(  # noqa: max-complexity
    n_nodes: int = 21,
    max_length: int = 10,
    n_walks: int = 1000,
    p_edge: float = 0.1,
    seed: int = 1002,
    gpt2_tokenizer: bool = False,
) -> Tuple[
    Callable[[List[str]], Dict[str, List[float]]],
    List[str],
    List[str],
    torch.Tensor,
]:
    """Generate random walks on a random directed graph.

    A random directed graph is sampled, node 0 is made an absorbing goal
    node, and ``n_walks`` random walks are sampled from random non-goal start
    nodes. Each node is written as a single lowercase letter.

    Args:
        n_nodes: Number of nodes (at least 2). This should not be more than
            26, as we use single letters to represent each node.
        max_length: Maximum number of nodes in each random walk
        n_walks: Number of random walks (samples) to create
        p_edge: Probability that any source node connects to any other
            destination node. Must be greater than 0.
        seed: Random seed
        gpt2_tokenizer: True if GPT2's tokenizer is being used. Node letters
            are then separated by "|" in the walks and prompts.

    Returns:
        Tuple of:

        * metric function, mapping a batch of walk strings to a dict with
          the keys "lengths" and "optimality" (one value per sample)
        * evaluation prompts: the sorted unique start letters of the sampled
          walks, each followed by the delimiter
        * the sampled walks as strings
        * the adjacency matrix as a boolean tensor (logit mask)

    Raises:
        ValueError: If ``n_nodes < 2`` or ``p_edge <= 0``. No graph in which
            every node has an outgoing edge can be sampled in those cases.
    """
    # Reject inputs for which the adjacency matrix sampling loop below could
    # never produce a usable graph (it would loop forever or fail on the goal
    # node).
    if n_nodes < 2:
        raise ValueError(f"n_nodes must be at least 2, got {n_nodes}")
    if p_edge <= 0:
        raise ValueError(f"p_edge must be greater than 0, got {p_edge}")

    # Initialise a random state with the seed
    rng = np.random.RandomState(seed)

    # Create the adjacency matrix
    # https://en.wikipedia.org/wiki/Adjacency_matrix
    # This is a 2d matrix, where the rows represent the source nodes and the
    # columns represent the destination nodes. If a cell (i,j) is True, then
    # there is a directional edge from the source node (i) to the destination
    # node (j). If it is False there is no connection.
    while True:
        # Connect each node to each other node with probability p_edge
        adjacency_matrix: np.ndarray = rng.rand(n_nodes, n_nodes) > (1 - p_edge)

        # Nodes can't be connected to themselves, so the diagonal values must
        # all be False
        np.fill_diagonal(adjacency_matrix, 0)

        # Each source node (row) must have at least one outgoing edge, so that
        # a random walk can always take a next step. This checks that there
        # is a True value in every row. If not, we generate a new adjacency
        # matrix from scratch.
        if np.all(adjacency_matrix.sum(1)):
            break

    # Set the goal node as 0
    goal: int = 0

    # The goal node is the terminal state, so we make sure that it doesn't
    # have a directional edge going to any other node (i.e. it can only be
    # reached from other nodes). Its only edge is the one to itself.
    adjacency_matrix[goal, :] = 0
    adjacency_matrix[goal, goal] = 1

    # Create dicts for converting nodes into characters and vice versa
    char_to_node: Dict[str, int] = {chr(ix + ord("a")): ix for ix in range(n_nodes)}
    node_to_char: Dict[int, str] = {ix: chr(ix + ord("a")) for ix in range(n_nodes)}

    # Initialise a list of sample walks
    sample_walks: List[str] = []

    # String delimiter (to force the tokenizer to keep all nodes as separate
    # tokens)
    delimiter: str = "|" if gpt2_tokenizer else ""

    # Create n_walks samples
    for _ in range(n_walks):
        # Pick a random starting node (that isn't already the goal node)
        node: int = generate_rand_int_excluding(rng, n_nodes, goal)

        # Initialise the list of nodes that we visit
        walk_nodes: List[int] = [node]

        # Take steps until we hit the maximum walk length or the goal node
        # (whichever comes first)
        for _step in range(max_length - 1):
            # From the current node, get all the nodes we can move to. Pick
            # one of these at random, and add it to the list of visited nodes
            node = rng.choice(np.nonzero(adjacency_matrix[node])[0])
            walk_nodes.append(node)

            # If we're at the goal node, stop
            if node == goal:
                break

        # Convert the nodes visited to letters (not integers)
        walk: List[str] = [node_to_char[ix] for ix in walk_nodes]

        # Concatenate into a journey, with each node letter separated by the
        # delimiter.
        sample_walks.append(delimiter.join(walk))

    # Create a directional graph from the adjacency matrix
    directional_graph = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph)

    # For each node (except for the goal node), find the length of the
    # shortest path to the goal. Keyed by start node so that the lookup in
    # metric_fn does not depend on any iteration order.
    shortest_lengths: Dict[int, int] = {}
    for start in range(n_nodes):
        if start == goal:
            continue
        try:
            # Find the shortest path (truncated to max_length nodes)
            shortest_path = nx.shortest_path(directional_graph, start, goal)[
                :max_length
            ]
            shortest_lengths[start] = len(shortest_path)
        except nx.NetworkXNoPath:
            # If there is no path, use the maximum length instead
            shortest_lengths[start] = max_length

    def metric_fn(
        samples: List[str],
    ) -> Dict[str, List[float]]:
        """Score a batch of walks by length and optimality.

        Args:
            samples: Batch of walk strings (node letters, separated by "|"
                when the GPT2 tokenizer is used)

        Returns:
            Dict of metrics, each with a key of the metric name and value as a
            list of metric values for each batch item. "lengths" is the
            number of nodes up to and including the goal, or 100 for a walk
            that is invalid or never reaches the goal. "optimality" is
            ``(max_length - length) / (max_length - shortest length)``, with
            invalid walks counted as ``max_length``. It is not finite for
            start nodes whose reference length equals ``max_length``, since
            the denominator is then zero.
        """
        # Length to set if the path is invalid
        invalid_path_length: int = 100

        # Initialise batch lengths & reference lengths (the optimal length
        # starting from each batch item's specific start node)
        lengths: List[float] = []
        sample_optimal_lengths: List[int] = []

        for sample_str in samples:
            # Remove GPT2 specific tokenizer delimiter
            cleaned_str = sample_str.replace("|", "") if gpt2_tokenizer else sample_str

            # Convert the sample into a list of nodes (default to an unused
            # integer if the node is not found)
            sample: List[int] = [char_to_node.get(c, 1000) for c in cleaned_str]

            # Initialise the specific sample length
            length: Optional[float] = None

            for step, current in enumerate(sample):
                # If an unknown node or a non-existent edge is used, set the
                # length to the invalid path score
                if current >= n_nodes or (
                    step > 0 and not adjacency_matrix[sample[step - 1], current]
                ):
                    length = invalid_path_length
                    break

                # If the goal node is reached, the length is the number of
                # nodes visited so far (including the goal)
                if current == goal:
                    length = step + 1
                    break

            # Catch the case where the goal was never reached (including an
            # empty sample)
            if length is None:
                length = invalid_path_length

            # Store the batch item length & the optimal length starting from
            # its start node. Samples that are empty, start at an unknown
            # node, or start at the goal have no stored reference; use 1 (the
            # shortest possible walk) so they are scored instead of raising.
            lengths.append(float(length))
            start_node: Optional[int] = sample[0] if sample else None
            sample_optimal_lengths.append(shortest_lengths.get(start_node, 1))

        # Cap invalid walks at max_length before comparing with the shortest
        # path
        lengths_tensor = torch.tensor(lengths, dtype=torch.float)
        bound_lengths: torch.Tensor = torch.where(
            lengths_tensor.eq(invalid_path_length), max_length, lengths_tensor
        ).abs()
        optimal_lengths = torch.as_tensor(sample_optimal_lengths)

        # Optimality scores as compared to the shortest path (1 means the
        # walk is as short as the shortest path, 0 means it is invalid)
        optimality = (max_length - bound_lengths) / (max_length - optimal_lengths)

        return {
            "lengths": lengths,
            "optimality": optimality.tolist(),
        }

    logit_mask = torch.tensor(adjacency_matrix)

    # Set the evaluation prompts as the sorted unique start points (first
    # character) of the random walk samples, each followed by the delimiter.
    eval_prompts = list(sorted(set(w[0] for w in sample_walks)))
    eval_prompts = [prompt + delimiter for prompt in eval_prompts]

    return (metric_fn, eval_prompts, sample_walks, logit_mask)
0 commit comments