
Commit 7bc9905

Loubna ben allal (loubnabnl) authored and lvwerra committed
CodeParrot data pretokenization (huggingface#16932)
* add pretokenization arguments
* add pretokenization script
* add support for pretokenized data
* reformat code
* fix run command for training
* fix model call from config
* remove a package
* add comments on pretokenization in the readme
* remove explicit parallelization
  Co-authored-by: Leandro von Werra <[email protected]>
* update readme
  Co-authored-by: Leandro von Werra <[email protected]>
* update readme -remove username
  Co-authored-by: Leandro von Werra <[email protected]>
* update readme -remove username
  Co-authored-by: Leandro von Werra <[email protected]>
* keep data parallelization
* reformat code
* reformat code
* update readme
* reformat code
* Update examples/research_projects/codeparrot/README.md
  Co-authored-by: Leandro von Werra <[email protected]>

Co-authored-by: Leandro von Werra <[email protected]>
Co-authored-by: Loubna ben allal <[email protected]>
1 parent 098e2c4 commit 7bc9905

File tree

- examples/research_projects/codeparrot/README.md
- examples/research_projects/codeparrot/scripts/arguments.py
- examples/research_projects/codeparrot/scripts/codeparrot_training.py
- examples/research_projects/codeparrot/scripts/initialize_model.py
- examples/research_projects/codeparrot/scripts/pretokenizing.py

5 files changed: +120 −26 lines

examples/research_projects/codeparrot/README.md

Lines changed: 13 additions & 2 deletions
@@ -60,6 +60,16 @@ python scripts/preprocessing.py \
 ```
 During preprocessing the dataset is downloaded and stored locally as well as caches of the computations. Make sure you have more than 500GB free disk space to execute it.

+### Pretokenization
+The tokenization of the data might be slow during the training especially for small models. We provide code to pretokenize the data beforehand in `scripts/pretokenizing.py`, but this step is optional. The dataset is downloaded and stored locally and the tokenized data is pushed to the hub. The tokenized clean [train](https://huggingface.co/datasets/loubnabnl/tokenized-codeparrot-train) and [validation](https://huggingface.co/datasets/loubnabnl/tokenized-codeparrot-valid) datasets are available if you want to use them directly.
+
+To execute the pretokenization, for the clean train data for instance, run the following command:
+```bash
+python scripts/pretokenizing.py \
+    --dataset_name lvwerra/codeparrot-clean-train \
+    --tokenized_data_repo tokenized-codeparrot-train
+```
+
 ## Tokenizer
 Before training a new model for code we create a new tokenizer that is efficient at code tokenization. To train the tokenizer you can run the following command:
 ```bash
@@ -82,7 +92,8 @@ python scripts/initialize_model.py \
 ```
 This will initialize a new model with the architecture and configuration of `gpt2-large` and use the tokenizer to appropriately size the input embeddings. Finally, the initilaized model is pushed the the hub.

-Now that the dataset, tokenizer, and model are ready we can start training the model. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine which takes 1 day and 1 week, respectively.
+We can either pass the name of a text dataset or a pretokenized dataset which speeds up training a bit.
+Now that the tokenizer and model are also ready we can start training the model. The main training script is built with `accelerate` to scale across a wide range of platforms and infrastructure scales. We train two models with [110M](https://huggingface.co/lvwerra/codeparrot-small/) and [1.5B](https://huggingface.co/lvwerra/codeparrot/) parameters for 25-30B tokens on a 16xA100 (40GB) machine which takes 1 day and 1 week, respectively.

 First you need to configure `accelerate` and login to Weights & Biases:

@@ -94,7 +105,7 @@ wandb login
 Note that during the `accelerate` configuration we enabled FP16. Then to train the large model you can run

 ```bash
-python scripts/codeparrot_training.py
+accelerate launch scripts/codeparrot_training.py
 ```

 If you want to train the small model you need to make some modifications:
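As an aside, the ready-made pretokenized splits linked in the README text above can be pulled straight from the Hub. A minimal sketch, assuming the pushed repos keep the columns produced by `scripts/pretokenizing.py` (`input_ids` and `ratio_char_token`):

```python
# Sketch only: stream the pretokenized train split named in the README and confirm it
# already carries token ids instead of raw source code.
from datasets import load_dataset

ds = load_dataset("loubnabnl/tokenized-codeparrot-train", split="train", streaming=True)
example = next(iter(ds))
print(sorted(example.keys()))     # expected to include "input_ids" and "ratio_char_token"
print(len(example["input_ids"]))  # number of tokens in the first pretokenized file
```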

examples/research_projects/codeparrot/scripts/arguments.py

Lines changed: 28 additions & 16 deletions
@@ -9,12 +9,10 @@ class TrainingArguments:
     """

     model_ckpt: Optional[str] = field(
-        default="lvwerra/codeparrot",
-        metadata={"help": "Model name or path of model to be trained."},
+        default="lvwerra/codeparrot", metadata={"help": "Model name or path of model to be trained."}
     )
     save_dir: Optional[str] = field(
-        default="./",
-        metadata={"help": "Save dir where model repo is cloned and models updates are saved to."},
+        default="./", metadata={"help": "Save dir where model repo is cloned and models updates are saved to."}
     )
     dataset_name_train: Optional[str] = field(
         default="lvwerra/codeparrot-clean-train", metadata={"help": "Name or path of training dataset."}
@@ -39,7 +37,7 @@ class TrainingArguments:
     gradient_checkpointing: Optional[bool] = field(
         default=True, metadata={"help": "Use gradient checkpointing to reduce memory footprint."}
     )
-    max_train_steps: Optional[int] = field(default=50_000, metadata={"help": "Maximum number of training steps."})
+    max_train_steps: Optional[int] = field(default=50000, metadata={"help": "Maximum number of training steps."})
     max_eval_steps: Optional[int] = field(
         default=-1, metadata={"help": "Maximum number of evaluation steps. If -1 the full dataset is evaluated."}
     )
@@ -50,9 +48,9 @@ class TrainingArguments:
         metadata={"help": "Interval to save checkpoints. Measured as number of forward passes not training steps."},
     )
     resume_from_checkpoint: Optional[str] = field(
-        default=None,
-        metadata={"help": "States path if the training should continue from a checkpoint folder."},
+        default=None, metadata={"help": "States path if the training should continue from a checkpoint folder."}
     )
+    tokenized: Optional[bool] = field(default=False, metadata={"help": "If True the data is pretokenized."})


 @dataclass
@@ -62,8 +60,7 @@ class EvaluationArguments:
     """

     model_ckpt: Optional[str] = field(
-        default="lvwerra/codeparrot",
-        metadata={"help": "Model name or path of model to be evaluated."},
+        default="lvwerra/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
     )
     dataset_name: Optional[str] = field(
         default="lvwerra/codeparrot-clean-valid", metadata={"help": "Name or path of validation dataset."}
@@ -83,8 +80,7 @@ class HumanEvalArguments:
     """

     model_ckpt: Optional[str] = field(
-        default="lvwerra/codeparrot",
-        metadata={"help": "Model name or path of model to be evaluated."},
+        default="lvwerra/codeparrot", metadata={"help": "Model name or path of model to be evaluated."}
     )
     num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
     num_tasks: Optional[int] = field(
@@ -170,30 +166,46 @@ class TokenizerTrainingArguments:
     """

     base_tokenizer: Optional[str] = field(
-        default="gpt2",
-        metadata={"help": "Base tokenizer to build new tokenizer from."},
+        default="gpt2", metadata={"help": "Base tokenizer to build new tokenizer from."}
     )
     dataset_name: Optional[str] = field(
         default="transformersbook/codeparrot-train", metadata={"help": "Dataset to train tokenizer on."}
     )
     text_column: Optional[str] = field(default="content", metadata={"help": "Column containing text data to process."})
-    vocab_size: Optional[int] = field(default=200000, metadata={"help": "Number of examples to train tokenizer on."})
+    vocab_size: Optional[int] = field(default=200_000, metadata={"help": "Number of examples to train tokenizer on."})
     n_examples: Optional[int] = field(
         default=32768, metadata={"help": "Number of examples to train the tokenizer on."}
     )
     tokenizer_name: Optional[str] = field(default="codeparrot", metadata={"help": "Name of new tokenizer."})
     push_to_hub: Optional[bool] = field(default=True, metadata={"help": "Push saved tokenizer to the hub."})


+@dataclass
+class PretokenizationArguments:
+    """
+    Configuration for data pretokenization.
+    """
+
+    tokenizer_dir: Optional[str] = field(
+        default="lvwerra/codeparrot", metadata={"help": "Name or path to the tokenizer."}
+    )
+    dataset_name: Optional[str] = field(
+        default="lvwerra/codeparrot-clean-train", metadata={"help": "Name or path to the dataset to pretokenize."}
+    )
+    tokenized_data_repo: Optional[str] = field(
+        default="tokenized-codeparrot-train", metadata={"help": "Repo name of the pretokenized data."}
+    )
+    num_workers: Optional[int] = field(default=None, metadata={"help": "Number of workers used for code evaluation."})
+
+
 @dataclass
 class InitializationArguments:
     """
     Configuration for initializing new model.
     """

     config_name: Optional[str] = field(
-        default="gpt2-large",
-        metadata={"help": "Configuration to use for model initialization."},
+        default="gpt2-large", metadata={"help": "Configuration to use for model initialization."}
     )
     tokenizer_name: Optional[str] = field(
         default="lvwerra/codeparrot", metadata={"help": "Tokenizer attached to model."}

examples/research_projects/codeparrot/scripts/codeparrot_training.py

Lines changed: 29 additions & 7 deletions
@@ -27,30 +27,45 @@ class ConstantLengthDataset(IterableDataset):
         seq_length (int): Length of token sequences to return.
         num_of_sequences: Number of token sequences to keep in buffer.
         chars_per_token: Number of characters per token used to estimate number of tokens in text buffer.
+        tokenized: If true we use a pretokenized dataset.
     """

     def __init__(
-        self, tokenizer, dataset, infinite=False, seq_length=1024, num_of_sequences=1024, chars_per_token=3.6
+        self,
+        tokenizer,
+        dataset,
+        infinite=False,
+        seq_length=1024,
+        num_of_sequences=1024,
+        chars_per_token=3.6,
+        tokenized=False,
     ):
         self.tokenizer = tokenizer
         self.concat_token_id = tokenizer.bos_token_id
         self.dataset = dataset
         self.seq_length = seq_length
-        self.input_characters = seq_length * chars_per_token * num_of_sequences
         self.epoch = 0
         self.infinite = infinite
         self.current_size = 0
+        self.tokenized = tokenized
+
+        if self.tokenized:
+            self.max_buffer_size = seq_length * num_of_sequences
+            self.content_field = "input_ids"
+        else:
+            self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
+            self.content_field = "content"

     def __iter__(self):
         iterator = iter(self.dataset)
         more_examples = True
         while more_examples:
             buffer, buffer_len = [], 0
             while True:
-                if buffer_len >= self.input_characters:
+                if buffer_len >= self.max_buffer_size:
                     break
                 try:
-                    buffer.append(next(iterator)["content"])
+                    buffer.append(next(iterator)[self.content_field])
                     buffer_len += len(buffer[-1])
                 except StopIteration:
                     if self.infinite:
@@ -60,7 +75,10 @@ def __iter__(self):
                     else:
                         more_examples = False
                         break
-            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
+            if self.tokenized:
+                tokenized_inputs = buffer
+            else:
+                tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
             all_token_ids = []
             for tokenized_input in tokenized_inputs:
                 all_token_ids.extend(tokenized_input + [self.concat_token_id])
@@ -102,8 +120,12 @@ def create_dataloaders(args):
     train_data = load_dataset(args.dataset_name_train, split="train", **ds_kwargs)
     train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
     valid_data = load_dataset(args.dataset_name_valid, split="train", **ds_kwargs)
-    train_dataset = ConstantLengthDataset(tokenizer, train_data, infinite=True, seq_length=args.seq_length)
-    valid_dataset = ConstantLengthDataset(tokenizer, valid_data, infinite=False, seq_length=args.seq_length)
+    train_dataset = ConstantLengthDataset(
+        tokenizer, train_data, infinite=True, seq_length=args.seq_length, tokenized=args.tokenized
+    )
+    valid_dataset = ConstantLengthDataset(
+        tokenizer, valid_data, infinite=False, seq_length=args.seq_length, tokenized=args.tokenized
+    )
     train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
     eval_dataloader = DataLoader(valid_dataset, batch_size=args.valid_batch_size)
     return train_dataloader, eval_dataloader
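To make the new `tokenized` path concrete, here is a toy, standalone illustration of the packing that `ConstantLengthDataset` performs once the buffer holds pretokenized `input_ids` (hypothetical values; the real class streams from a dataset, measures the buffer in tokens when `tokenized=True`, and yields fixed-length tensors):

```python
# Toy illustration of the concatenate-and-chunk step; values are made up.
seq_length = 8
concat_token_id = 0  # stands in for tokenizer.bos_token_id

buffer = [[11, 12, 13], [21, 22, 23, 24, 25], [31, 32]]  # pretokenized "input_ids" examples
all_token_ids = []
for tokenized_input in buffer:  # with tokenized=True the buffer is used as-is
    all_token_ids.extend(tokenized_input + [concat_token_id])

sequences = [
    all_token_ids[i : i + seq_length]
    for i in range(0, len(all_token_ids), seq_length)
    if len(all_token_ids[i : i + seq_length]) == seq_length
]
print(sequences)  # only full seq_length chunks are kept: [[11, 12, 13, 0, 21, 22, 23, 24]]
```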

examples/research_projects/codeparrot/scripts/initialize_model.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 config = AutoConfig.from_pretrained(args.config_name, **config_kwargs)

 # Initialize new model with config
-model = AutoModelForCausalLM(config)
+model = AutoModelForCausalLM.from_config(config)

 # Save model to the hub
 model.save_pretrained(args.model_name, push_to_hub=args.push_to_hub)
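The one-line fix above matters because `AutoModelForCausalLM` is a factory class and cannot be instantiated directly; `from_config` builds a freshly initialized model from a configuration without downloading pretrained weights. A small self-contained check, using a shrunken `gpt2` config purely for illustration:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Shrunken config so the example runs quickly; initialize_model.py uses gpt2-large.
config = AutoConfig.from_pretrained("gpt2", n_layer=2, n_head=2, n_embd=128)
model = AutoModelForCausalLM.from_config(config)  # random init, no pretrained weights
print(sum(p.numel() for p in model.parameters()))
# Calling AutoModelForCausalLM(config) instead raises an error directing you to
# from_config / from_pretrained, which is the bug this commit fixes.
```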
examples/research_projects/codeparrot/scripts/pretokenizing.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import multiprocessing
+import time
+
+from datasets import load_dataset
+
+from arguments import PretokenizationArguments
+from transformers import AutoTokenizer, HfArgumentParser
+
+
+def tokenize(example):
+    output = dict()
+    output["input_ids"] = tokenizer(example["content"], truncation=False)["input_ids"]
+    output["ratio_char_token"] = len(example["content"]) / len(output["input_ids"])
+    return output
+
+
+parser = HfArgumentParser(PretokenizationArguments)
+args = parser.parse_args()
+if args.num_workers is None:
+    args.num_workers = multiprocessing.cpu_count()
+tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+
+t_start = time.time()
+ds = load_dataset(args.dataset_name, split="train")
+print(f"Dataset loaded in {time.time()-t_start:.2f}s")
+
+t_start = time.time()
+ds = ds.map(
+    tokenize,
+    num_proc=args.num_workers,
+    remove_columns=[
+        "repo_name",
+        "path",
+        "copies",
+        "size",
+        "content",
+        "license",
+        "hash",
+        "line_mean",
+        "line_max",
+        "alpha_frac",
+        "autogenerated",
+    ],
+)
+print(f"Dataset tokenized in {time.time()-t_start:.2f}s")
+
+t_start = time.time()
+ds.push_to_hub(args.tokenized_data_repo)
+print(f"Data pushed to the hub in {time.time()-t_start:.2f}s")
