 import numpy as np
 import torch
 
-from .utils import SyntheticData, builder_from_single
+from .utils import DataSegment, builder_from_single
 
 
 def associative_recall(
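
The import swap above is the heart of this change: the data builders now hand back a per-split `DataSegment` instead of a combined `SyntheticData` bundle. `.utils` is not part of this diff, but from the call `DataSegment(inputs, labels, slices=None)` in the last hunk, a minimal sketch of the container might look like the following (the dataclass form and the `slices` type are assumptions, not code from the patch):

```python
# Hypothetical sketch of DataSegment as imported from .utils.
# Only the three field names visible in this diff are taken from the patch.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DataSegment:
    inputs: torch.Tensor           # (num_examples, input_seq_len) token ids
    labels: torch.Tensor           # same shape; -100 marks positions the loss ignores
    slices: Optional[dict] = None  # optional metadata; passed as None in this patch
```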
@@ -54,7 +54,7 @@ def associative_recall(
         Warning: If potential data leakage is detected between the train and test sets.
     """
 
-    train_inputs, train_labels = _ar(
+    train = _ar(
         vocab_size=vocab_size,
         num_examples=num_train_examples,
         input_seq_len=input_seq_len,
@@ -63,7 +63,7 @@ def associative_recall(
         num_queries=num_queries,
         random_non_queries=random_non_queries
     )
-    test_inputs, test_labels = _ar(
+    test = _ar(
        vocab_size=vocab_size,
        num_examples=num_test_examples,
        input_seq_len=input_seq_len,
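
Each split is now bound to a single name instead of being unpacked into `(inputs, labels)`. Note that `_ar`'s `return inputs, targets` appears as unchanged context at line 163 below, so at this commit `train` and `test` still hold plain tuples; the sketch shows the presumed target style once `_ar` also returns a `DataSegment` (attribute names assumed from the constructor call at the end of the patch):

```python
# At this commit: _ar still returns a tuple, so this works.
train_inputs, train_labels = train

# Presumed end state once _ar returns a DataSegment (assumption):
# train_inputs, train_labels = train.inputs, train.labels
```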
@@ -163,113 +163,113 @@ def _ar(
     return inputs, targets
 
 
-def multiquery_ar(
-    vocab_size: int = 8_192,
-    num_train_examples: int = 100_000,
-    num_test_examples: int = 3_000,
-    input_seq_len: int = 64,
-    num_kv_pairs: int = 4,
-    train_power_a: float = 0.01,
-    test_power_a: float = 0.01,
-    random_non_queries: bool = True,
-    seed: int = 0,
-) -> SyntheticData:
-    """
-    Generates synthetic data for the multi-query associative recall task as described in
-    Arora, Eyuboglu, et al. "Zoology: Measuring and improving recall in efficient language models.".
-
-    Example:
-        `multiquery_ar(vocab_size=12, num_kv_pairs=2, input_seq_len=16, random_non_queries=False)`
-        will generate input and label sequences of the form:
+# def multiquery_ar(
+#     vocab_size: int = 8_192,
+#     num_train_examples: int = 100_000,
+#     num_test_examples: int = 3_000,
+#     input_seq_len: int = 64,
+#     num_kv_pairs: int = 4,
+#     train_power_a: float = 0.01,
+#     test_power_a: float = 0.01,
+#     random_non_queries: bool = True,
+#     seed: int = 0,
+# ) -> SyntheticData:
+#     """
+#     Generates synthetic data for the multi-query associative recall task as described in
+#     Arora, Eyuboglu, et al. "Zoology: Measuring and improving recall in efficient language models.".
+
+#     Example:
+#         `multiquery_ar(vocab_size=12, num_kv_pairs=2, input_seq_len=16, random_non_queries=False)`
+#         will generate input and label sequences of the form:
 
-                Key  Val  Key  Val            Query                         Query
-        Inputs: 2    8    4    7    0    0    4    0    0    0    0    0    2    0    0
-        Labels: -100 -100 -100 -100 -100 -100 7    -100 -100 -100 -100 -100 8    -100 -100
+#                 Key  Val  Key  Val            Query                         Query
+#         Inputs: 2    8    4    7    0    0    4    0    0    0    0    0    2    0    0
+#         Labels: -100 -100 -100 -100 -100 -100 7    -100 -100 -100 -100 -100 8    -100 -100
 
-    The -100 labels are ignored by the loss function and metrics.
+#     The -100 labels are ignored by the loss function and metrics.
 
-    We include one important note on the power law distribution. In real language data,
-    the gap between repeated bigrams follows a power law. Intuitively, if the bigram
-    "common buzzard" appears in text, the probability of the bigram appearing again
-    drops the further away from the original mention we are. In our synthetic data, we
-    can control this with the power law parameters `train_power_a` and `test_power_a`.
-    Setting these to 1.0 will result in a uniform distribution. You can visualize the
-    distribution with the following code:
-    ```
-    space = 100
-    power_a = 0.01
-    p = power_a * np.arange(1, space + 1) ** (power_a - 1)
-    p = p / p.sum()
-    plt.plot(p)
-    ```
-
-    Args:
-        vocab_size (int): The size of the vocabulary. As discussed in the Zoology
-            paper, large vocabulary sizes (>1k) can be important for highlighting
-            differences between model architectures. Defaults to 8_192.
-        num_train_examples (int): The number of training examples to generate. Defaults
-            to 100_000.
-        num_test_examples (int): The number of test examples to generate. Defaults to
-            3_000.
-        input_seq_len (int): The length of the input sequence. Defaults to 64. In
-            Figure 2 of the Zoology paper, we vary the input sequence length from
-            64 to 512 and the number of key-value pairs from 4 to 64.
-        seed (int): The seed for the random number generator.
-        num_kv_pairs (int): The number of key-value pairs.
-        train_power_a (float, optional): The power for the power law distribution for
-            training data. Defaults to 0.01.
-        test_power_a (float, optional): The power for the power law distribution for
-            test data. Defaults to 0.01.
-        random_non_queries (bool, optional): If True, replace all the 0's (as in the
-            example above) with random values in the input. Defaults to True.
-
-    Returns:
-        SyntheticData: A SyntheticData object containing the generated train and test
-            inputs and labels.
-
-    Raises:
-        Warning: If potential data leakage is detected between the train and test sets.
-    """
-
-    train_inputs, train_labels = _mqar(
-        vocab_size=vocab_size,
-        num_examples=num_train_examples,
-        input_seq_len=input_seq_len,
-        seed=seed,
-        power_a=train_power_a,
-        num_kv_pairs=num_kv_pairs,
-        random_non_queries=random_non_queries
-    )
-    test_inputs, test_labels = _mqar(
-        vocab_size=vocab_size,
-        num_examples=num_test_examples,
-        input_seq_len=input_seq_len,
-        seed=seed + 10,  # different seed for test set
-        power_a=test_power_a,
-        num_kv_pairs=num_kv_pairs,
-        random_non_queries=random_non_queries
-    )
-
-    data = SyntheticData(
-        train_inputs=train_inputs,
-        train_labels=train_labels,
-        test_inputs=test_inputs,
-        test_labels=test_labels,
-    )
-
-    # check for data leakage:
-    train_set = set([" ".join(map(str, x)) for x in data.train_inputs.tolist()])
-    test_set = set([" ".join(map(str, x)) for x in data.test_inputs.tolist()])
-    frac_test_in_train = 1 - (len(test_set - train_set) / len(test_set))
-    if frac_test_in_train > 0.001:
-        print(
-            "WARNING: Potential data leakage detected. "
-            f"{frac_test_in_train: 0.2f} of test examples are in the train set."
-        )
-    return data
+#     We include one important note on the power law distribution. In real language data,
+#     the gap between repeated bigrams follows a power law. Intuitively, if the bigram
+#     "common buzzard" appears in text, the probability of the bigram appearing again
+#     drops the further away from the original mention we are. In our synthetic data, we
+#     can control this with the power law parameters `train_power_a` and `test_power_a`.
+#     Setting these to 1.0 will result in a uniform distribution. You can visualize the
+#     distribution with the following code:
+#     ```
+#     space = 100
+#     power_a = 0.01
+#     p = power_a * np.arange(1, space + 1) ** (power_a - 1)
+#     p = p / p.sum()
+#     plt.plot(p)
+#     ```
+
+#     Args:
+#         vocab_size (int): The size of the vocabulary. As discussed in the Zoology
+#             paper, large vocabulary sizes (>1k) can be important for highlighting
+#             differences between model architectures. Defaults to 8_192.
+#         num_train_examples (int): The number of training examples to generate. Defaults
+#             to 100_000.
+#         num_test_examples (int): The number of test examples to generate. Defaults to
+#             3_000.
+#         input_seq_len (int): The length of the input sequence. Defaults to 64. In
+#             Figure 2 of the Zoology paper, we vary the input sequence length from
+#             64 to 512 and the number of key-value pairs from 4 to 64.
+#         seed (int): The seed for the random number generator.
+#         num_kv_pairs (int): The number of key-value pairs.
+#         train_power_a (float, optional): The power for the power law distribution for
+#             training data. Defaults to 0.01.
+#         test_power_a (float, optional): The power for the power law distribution for
+#             test data. Defaults to 0.01.
+#         random_non_queries (bool, optional): If True, replace all the 0's (as in the
+#             example above) with random values in the input. Defaults to True.
+
+#     Returns:
+#         SyntheticData: A SyntheticData object containing the generated train and test
+#             inputs and labels.
+
+#     Raises:
+#         Warning: If potential data leakage is detected between the train and test sets.
+#     """
+
+#     train_inputs, train_labels = _mqar(
+#         vocab_size=vocab_size,
+#         num_examples=num_train_examples,
+#         input_seq_len=input_seq_len,
+#         seed=seed,
+#         power_a=train_power_a,
+#         num_kv_pairs=num_kv_pairs,
+#         random_non_queries=random_non_queries
+#     )
+#     test_inputs, test_labels = _mqar(
+#         vocab_size=vocab_size,
+#         num_examples=num_test_examples,
+#         input_seq_len=input_seq_len,
+#         seed=seed + 10,  # different seed for test set
+#         power_a=test_power_a,
+#         num_kv_pairs=num_kv_pairs,
+#         random_non_queries=random_non_queries
+#     )
+
+#     data = SyntheticData(
+#         train_inputs=train_inputs,
+#         train_labels=train_labels,
+#         test_inputs=test_inputs,
+#         test_labels=test_labels,
+#     )
+
+#     # check for data leakage:
+#     train_set = set([" ".join(map(str, x)) for x in data.train_inputs.tolist()])
+#     test_set = set([" ".join(map(str, x)) for x in data.test_inputs.tolist()])
+#     frac_test_in_train = 1 - (len(test_set - train_set) / len(test_set))
+#     if frac_test_in_train > 0.001:
+#         print(
+#             "WARNING: Potential data leakage detected. "
+#             f"{frac_test_in_train: 0.2f} of test examples are in the train set."
+#         )
+#     return data
 
 
-def _mqar(
+def multiquery_ar(
     vocab_size: int,
     num_examples: int,
     input_seq_len: int,
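
The docstring's visualization snippet assumes `np` and `plt` are already in scope. Here is a self-contained version, plus a draw from the same distribution to show how key-query gaps would be sampled; the sampling lines are illustrative and mirror the docstring's formula, not code from this patch:

```python
import matplotlib.pyplot as plt
import numpy as np

space = 100     # maximum gap between a key-value pair and its repeated query
power_a = 0.01  # matches the train_power_a / test_power_a default of 0.01

# Power-law weights over gaps 1..space, exactly as in the docstring.
p = power_a * np.arange(1, space + 1) ** (power_a - 1)
p = p / p.sum()

plt.plot(p)
plt.xlabel("gap to the repeated query")
plt.ylabel("probability")
plt.show()

# Illustrative: sample a few gaps from this distribution.
rng = np.random.default_rng(0)
gaps = rng.choice(space, size=8, p=p) + 1  # gaps in 1..space
print(gaps)
```

With `power_a = 0.01` the mass concentrates on small gaps; setting it to 1.0 flattens `p` into a uniform distribution, as the docstring notes.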
@@ -278,7 +278,7 @@ def _mqar(
     num_kv_pairs: int = 8,
     random_non_queries: bool = True,
     **kwargs
-):
+) -> DataSegment:
     assert input_seq_len % 2 == 0, "input_seq_len must be even"
     assert vocab_size > input_seq_len
     assert num_kv_pairs * 4 <= input_seq_len
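
Read together, the asserts pin down the task geometry: `input_seq_len` must be even so the sequence splits cleanly, the vocabulary must outnumber the sequence positions, and `num_kv_pairs * 4 <= input_seq_len` guarantees (reading only the bound, not elided code) that the two slots per key-value pair fill at most half the sequence, leaving at least as much room again for the queries and filler tokens.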
@@ -328,7 +328,7 @@ def _mqar(
     # replace all the 0 with random values
     if random_non_queries:
         inputs[inputs == 0] = torch.randint(vocab_size, size=inputs.shape)[inputs == 0]
-    return inputs, labels
+    return DataSegment(inputs, labels, slices=None)
 
 
 
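
After the rename, the former private `_mqar` is the public `multiquery_ar`, returning one `DataSegment` rather than an `(inputs, labels)` tuple. A usage sketch under the assumptions above (argument values and the module path are illustrative; `seed` and `power_a` are presumably among the parameters elided at lines 276-277):

```python
from zoology.data.associative_recall import multiquery_ar  # module path assumed

segment = multiquery_ar(
    vocab_size=8_192,
    num_examples=10_000,
    input_seq_len=64,
    num_kv_pairs=8,
)

print(segment.inputs.shape)             # expected: torch.Size([10000, 64])
print((segment.labels != -100).sum(1))  # supervised positions per row, one per query
```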