
Commit 194a712

eyuboglu committed

Update code with data segments
1 parent 4ca911f commit 194a712

File tree

8 files changed: +265 additions, -330 deletions


zoology/config.py

Lines changed: 13 additions & 13 deletions

@@ -1,7 +1,7 @@
 import argparse
 from datetime import datetime
 from functools import partial
-from typing import Tuple, Union
+from typing import List, Tuple, Union

 from pydantic import BaseModel

@@ -64,19 +64,21 @@ class ModuleConfig(BaseConfig):
     def instantiate(self, **kwargs):
         return import_from_str(self.name)(**kwargs, **self.kwargs)

-
-class DataConfig(BaseConfig):
+class DataSegmentConfig(BaseConfig):
+    vocab_size: int = 8_192
+    num_examples: int = 1_000
+    input_seq_len: int = 64
     builder: FunctionConfig = None
-    seed: int = 0

-    num_train_examples: int = 10_000
-    num_test_examples: int = 1000
-    input_seq_len: int = 64
-    vocab_size: int = 8_192
+class DataConfig(BaseConfig):
+    train_configs: List[DataSegmentConfig]
+    test_configs: List[DataSegmentConfig]

     # can pass a tuple if you want a different batch size for train and test
     batch_size: Union[int, Tuple[int, int]] = 32
-
+
+    seed: int = 123
+
     cache_dir: str = None
     caching: bool = True
     force_cache: bool = False
@@ -109,8 +111,8 @@ class LoggerConfig(BaseConfig):


 class TrainConfig(BaseConfig):
-    data: DataConfig = DataConfig()
-    model: ModelConfig = ModelConfig()
+    data: DataConfig
+    model: ModelConfig
     logger: LoggerConfig = LoggerConfig()

     max_epochs: int = 100
@@ -127,5 +129,3 @@ class TrainConfig(BaseConfig):
     launch_id: str = None
     sweep_id: str = None
     run_id: str = "default"
-
-
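The net effect of this file's changes is that data is now described per segment: DataSegmentConfig carries the vocabulary size, example count, sequence length, and builder for one slice of data, while DataConfig holds lists of segments for train and test. Note that TrainConfig now requires data and model to be passed explicitly rather than falling back to default-constructed configs. A minimal sketch of how the new fields compose, using only the field names and defaults visible in the hunks above (the import path is assumed from this file's location and the values are illustrative):

```
from zoology.config import DataConfig, DataSegmentConfig

# Two training segments with different sequence lengths plus one test segment.
# batch_size keeps its (train, test) tuple form from the unchanged lines above.
data = DataConfig(
    train_configs=[
        DataSegmentConfig(vocab_size=8_192, num_examples=10_000, input_seq_len=64),
        DataSegmentConfig(vocab_size=8_192, num_examples=10_000, input_seq_len=128),
    ],
    test_configs=[
        DataSegmentConfig(vocab_size=8_192, num_examples=1_000, input_seq_len=256),
    ],
    batch_size=(32, 16),
    seed=123,
)
```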

zoology/data/ar_extrapolate.py

Lines changed: 2 additions & 2 deletions

@@ -5,7 +5,7 @@
 import torch
 from pydantic import BaseModel

-from .utils import MultiSyntheticData, SyntheticDataSection
+from .utils import SyntheticDataSection
 from .associative_recall import _mqar, _ar

 class ARConfig(BaseModel):
@@ -24,7 +24,7 @@ def ar_extrapolate(
     num_test_examples: int=3_000,
     input_seq_len: int=64,
     seed: int=0,
-) -> MultiSyntheticData:
+) -> SyntheticDataSection:

     # input seq len should be the max for all the configs
     assert input_seq_len == max([c.input_seq_len for c in train_configs + test_configs])
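The assertion in the second hunk ties the top-level input_seq_len to the longest segment. A small illustration of that constraint with hypothetical ARConfig values (only the input_seq_len attribute is confirmed by this diff; ARConfig may require other fields not shown here):

```
# Hypothetical segment configs, for illustration only.
train_configs = [ARConfig(input_seq_len=64), ARConfig(input_seq_len=128)]
test_configs = [ARConfig(input_seq_len=256)]

# ar_extrapolate expects the caller to pass the maximum across all segments:
input_seq_len = max(c.input_seq_len for c in train_configs + test_configs)  # 256
```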

zoology/data/associative_recall.py

Lines changed: 107 additions & 107 deletions

@@ -2,7 +2,7 @@
 import numpy as np
 import torch

-from .utils import SyntheticData, builder_from_single
+from .utils import DataSegment, builder_from_single


 def associative_recall(
@@ -54,7 +54,7 @@ def associative_recall(
         Warning: If potential data leakage is detected between the train and test sets.
     """

-    train_inputs, train_labels = _ar(
+    train = _ar(
         vocab_size=vocab_size,
         num_examples=num_train_examples,
         input_seq_len=input_seq_len,
@@ -63,7 +63,7 @@ def associative_recall(
         num_queries=num_queries,
         random_non_queries=random_non_queries
     )
-    test_inputs, test_labels = _ar(
+    test = _ar(
         vocab_size=vocab_size,
         num_examples=num_test_examples,
         input_seq_len=input_seq_len,
@@ -163,113 +163,113 @@ def _ar(
     return inputs, targets


-def multiquery_ar(
-    vocab_size: int=8_192,
-    num_train_examples: int=100_000,
-    num_test_examples: int=3_000,
-    input_seq_len: int=64,
-    num_kv_pairs: int=4,
-    train_power_a: float=0.01,
-    test_power_a: float=0.01,
-    random_non_queries: bool=True,
-    seed: int=0,
-) -> SyntheticData:
-    """
-    Generates synthetic data for the multi-query associative recall task as described in
-    Arora,Eyuboglu, et al. "Zoology: Measuring and improving recall in efficient language models.".
-
-    Example:
-    `multiquery_ar(vocab_size=12, num_kv_pairs=2, input_seq_len=16, random_non_queries=False)`
-    will generate input and label sequences of the form:
+# def multiquery_ar(
+#     vocab_size: int=8_192,
+#     num_train_examples: int=100_000,
+#     num_test_examples: int=3_000,
+#     input_seq_len: int=64,
+#     num_kv_pairs: int=4,
+#     train_power_a: float=0.01,
+#     test_power_a: float=0.01,
+#     random_non_queries: bool=True,
+#     seed: int=0,
+# ) -> SyntheticData:
+#     """
+#     Generates synthetic data for the multi-query associative recall task as described in
+#     Arora,Eyuboglu, et al. "Zoology: Measuring and improving recall in efficient language models.".
+
+#     Example:
+#     `multiquery_ar(vocab_size=12, num_kv_pairs=2, input_seq_len=16, random_non_queries=False)`
+#     will generate input and label sequences of the form:

-            Key   Val  Key  Val            Query                         Query
-    Inputs: 2     8    4    7    0    0    4    0    0    0    0    0    2    0    0
-    Labels: -100  -100 -100 -100 -100 -100 7    -100 -100 -100 -100 -100 8    -100 -100
+#             Key   Val  Key  Val            Query                         Query
+#     Inputs: 2     8    4    7    0    0    4    0    0    0    0    0    2    0    0
+#     Labels: -100  -100 -100 -100 -100 -100 7    -100 -100 -100 -100 -100 8    -100 -100

-    The -100 labels are ignored by the loss function and metrics.
+#     The -100 labels are ignored by the loss function and metrics.

-    We include one important note on the power law distribution. In real language data,
-    the gap between repeated bigrams follows a power law. Intuitively, if the bigram
-    "common buzzard" appears in text, the probability of the bigram appearing again
-    drops the further away from the orginal mention we are. In our synthetic, we can
-    control this with the power law parameters `train_power_a` and `test_power_a`.
-    Setting these to 1.0 will result in a uniform distribution. You can visualize the
-    distribution with the following code:
-    ```
-    space = 100
-    power_a = 0.01
-    p = power_a * np.arange(1, space + 1) ** (power_a-1)
-    p = p / p.sum()
-    plt.plot(p)
-    ```
-
-    Args:
-        vocab_size (int): The size of the vocabulary. As discussed in the Zoology
-            paper, large vocabulary sizes (>1k) can be important for highlighting
-            differences between model architectures. Defaults to 8_192.
-        num_train_examples (int): The number of training examples to generate. Defaults
-            to 100_000.
-        num_test_examples (int): The number of test examples to generate. Defaults to
-            3_000.
-        input_seq_len (int): The length of the input sequence. Defaults to 64. In
-            In Figure 2 of the Zoology paper, we vary the input sequence length from
-            64 to 512 and the number of key-value pairs from 4 to 64.
-        seed (int): The seed for the random number generator.
-        num_kv_pairs (int): The number of key-value pairs.
-        train_power_a (float, optional): The power for the power law distribution for
-            training data. Defaults to 0.01.
-        test_power_a (float, optional): The power for the power law distribution for
-            test data. Defaults to 0.01.
-        random_non_queries (bool, optional): If True, replace all the 0's (as in the
-            example above) with random values in the input. Defaults to True.
-
-    Returns:
-        SyntheticData: A SyntheticData object containing the generated train and test
-            inputs and labels.
-
-    Raises:
-        Warning: If potential data leakage is detected between the train and test sets.
-    """
-
-    train_inputs, train_labels = _mqar(
-        vocab_size=vocab_size,
-        num_examples=num_train_examples,
-        input_seq_len=input_seq_len,
-        seed=seed,
-        power_a=train_power_a,
-        num_kv_pairs=num_kv_pairs,
-        random_non_queries=random_non_queries
-    )
-    test_inputs, test_labels = _mqar(
-        vocab_size=vocab_size,
-        num_examples=num_test_examples,
-        input_seq_len=input_seq_len,
-        seed=seed + 10, # different seed for test set
-        power_a=test_power_a,
-        num_kv_pairs=num_kv_pairs,
-        random_non_queries=random_non_queries
-    )
-
-    data = SyntheticData(
-        train_inputs=train_inputs,
-        train_labels=train_labels,
-        test_inputs=test_inputs,
-        test_labels=test_labels,
-    )
-
-    # check for data leakage:
-    train_set = set([" ".join(map(str, x)) for x in data.train_inputs.tolist()])
-    test_set = set([" ".join(map(str, x)) for x in data.test_inputs.tolist()])
-    frac_test_in_train = 1 - (len(test_set - train_set) / len(test_set))
-    if frac_test_in_train > 0.001:
-        print(
-            "WARNING: Potential data leakage detected. "
-            f"{frac_test_in_train: 0.2f} of test examples are in the train set."
-        )
-    return data
+#     We include one important note on the power law distribution. In real language data,
+#     the gap between repeated bigrams follows a power law. Intuitively, if the bigram
+#     "common buzzard" appears in text, the probability of the bigram appearing again
+#     drops the further away from the orginal mention we are. In our synthetic, we can
+#     control this with the power law parameters `train_power_a` and `test_power_a`.
+#     Setting these to 1.0 will result in a uniform distribution. You can visualize the
+#     distribution with the following code:
+#     ```
+#     space = 100
+#     power_a = 0.01
+#     p = power_a * np.arange(1, space + 1) ** (power_a-1)
+#     p = p / p.sum()
+#     plt.plot(p)
+#     ```
+
+#     Args:
+#         vocab_size (int): The size of the vocabulary. As discussed in the Zoology
+#             paper, large vocabulary sizes (>1k) can be important for highlighting
+#             differences between model architectures. Defaults to 8_192.
+#         num_train_examples (int): The number of training examples to generate. Defaults
+#             to 100_000.
+#         num_test_examples (int): The number of test examples to generate. Defaults to
+#             3_000.
+#         input_seq_len (int): The length of the input sequence. Defaults to 64. In
+#             In Figure 2 of the Zoology paper, we vary the input sequence length from
+#             64 to 512 and the number of key-value pairs from 4 to 64.
+#         seed (int): The seed for the random number generator.
+#         num_kv_pairs (int): The number of key-value pairs.
+#         train_power_a (float, optional): The power for the power law distribution for
+#             training data. Defaults to 0.01.
+#         test_power_a (float, optional): The power for the power law distribution for
+#             test data. Defaults to 0.01.
+#         random_non_queries (bool, optional): If True, replace all the 0's (as in the
+#             example above) with random values in the input. Defaults to True.
+
+#     Returns:
+#         SyntheticData: A SyntheticData object containing the generated train and test
+#             inputs and labels.
+
+#     Raises:
+#         Warning: If potential data leakage is detected between the train and test sets.
+#     """
+
+#     train_inputs, train_labels = _mqar(
+#         vocab_size=vocab_size,
+#         num_examples=num_train_examples,
+#         input_seq_len=input_seq_len,
+#         seed=seed,
+#         power_a=train_power_a,
+#         num_kv_pairs=num_kv_pairs,
+#         random_non_queries=random_non_queries
+#     )
+#     test_inputs, test_labels = _mqar(
+#         vocab_size=vocab_size,
+#         num_examples=num_test_examples,
+#         input_seq_len=input_seq_len,
+#         seed=seed + 10, # different seed for test set
+#         power_a=test_power_a,
+#         num_kv_pairs=num_kv_pairs,
+#         random_non_queries=random_non_queries
+#     )
+
+#     data = SyntheticData(
+#         train_inputs=train_inputs,
+#         train_labels=train_labels,
+#         test_inputs=test_inputs,
+#         test_labels=test_labels,
+#     )
+
+#     # check for data leakage:
+#     train_set = set([" ".join(map(str, x)) for x in data.train_inputs.tolist()])
+#     test_set = set([" ".join(map(str, x)) for x in data.test_inputs.tolist()])
+#     frac_test_in_train = 1 - (len(test_set - train_set) / len(test_set))
+#     if frac_test_in_train > 0.001:
+#         print(
+#             "WARNING: Potential data leakage detected. "
+#             f"{frac_test_in_train: 0.2f} of test examples are in the train set."
+#         )
+#     return data


-def _mqar(
+def multiquery_ar(
     vocab_size: int,
     num_examples: int,
     input_seq_len: int,
@@ -278,7 +278,7 @@ def _mqar(
     num_kv_pairs: int=8,
     random_non_queries: bool=True,
     **kwargs
-):
+) -> DataSegment:
     assert input_seq_len % 2 == 0, "input_seq_len must be even"
     assert vocab_size > input_seq_len
     assert num_kv_pairs * 4 <= input_seq_len
@@ -328,7 +328,7 @@ def _mqar(
     # replace all the 0 with random values
     if random_non_queries:
         inputs[inputs == 0] = torch.randint(vocab_size, size=inputs.shape)[inputs == 0]
-    return inputs, labels
+    return DataSegment(inputs, labels, slices=None)
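With the old two-split multiquery_ar commented out and _mqar renamed in its place, the builder now produces a single segment per call; assembling train and test splits is left to the DataSegmentConfig lists in the config. A hedged usage sketch (the seed keyword argument is assumed from the old _mqar call sites; the full signature is not visible in this hunk):

```
# One data segment per call; the asserts above constrain the arguments.
segment = multiquery_ar(
    vocab_size=8_192,
    num_examples=1_000,
    input_seq_len=64,   # must be even and smaller than vocab_size
    num_kv_pairs=8,     # num_kv_pairs * 4 must fit within input_seq_len
    seed=0,
)
# Instead of an (inputs, labels) tuple, the result is a DataSegment wrapping
# inputs and labels, with slices=None in this version.
```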
