
Commit 60cdd0b

bigning and dakinggg authored
add finutuning with streaming dataset example (#945)
* add convert
* fix
* fix convert
* add jsonl
* revert setup
* test precommit
* pre-commit
* test pre-commit
* v0
* review comments
* temporarily trigger test
* test
* fix yaml
* comments
* comments
* comments
* add unit test
* comments

---------

Co-authored-by: Daniel King <[email protected]>
1 parent 9f10184 commit 60cdd0b

6 files changed: 166 additions, 14 deletions


llmfoundry/data/finetuning/collator.py

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ def __init__(

     def __call__(self, examples: List[Dict[str,
                                            Any]]) -> Dict[str, torch.Tensor]:
-        for check_key in ['input_ids', 'labels', 'attention_mask']:
+        for check_key in ['input_ids', 'labels']:
             if check_key not in examples[0]:
                 raise KeyError(
                     f'Examples returned by dataset do not include required key: {check_key}'
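With pre-tokenized streaming samples, examples reaching the collator carry only `input_ids` and `labels`, so `attention_mask` is no longer required up front; it is presumably regenerated when the batch is padded. A toy illustration of what now passes the key check (the values are made up, not upstream code):

example = {'input_ids': [101, 2023, 2003], 'labels': [101, 2023, 2003]}  # toy token ids
for check_key in ['input_ids', 'labels']:
    assert check_key in example, f'missing required key: {check_key}'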

llmfoundry/data/finetuning/dataloader.py

Lines changed: 4 additions & 0 deletions
@@ -152,6 +152,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
             sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
             sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
             batching_method=cfg.dataset.get('batching_method', 'random'),
+            max_seq_len=cfg.dataset.max_seq_len,
         )

     else:
@@ -284,6 +285,9 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
             'HuggingFace dataset or set `remote` to use a streaming ' +\
             'dataset, but both were None.'
         )
+    if dataset_cfg.get('max_seq_len') is None:
+        raise ValueError(
+            'In the dataset config, you must set the `max_seq_len`')


 def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
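With this change, the dataset section of a streaming finetuning config must carry `max_seq_len`, which is now also forwarded to the streaming dataset. A minimal sketch of a dataset config that satisfies the new check; the paths and values are hypothetical:

from omegaconf import OmegaConf as om

dataset_cfg = om.create({
    'remote': 's3://my-bucket/ft-data',  # hypothetical remote; leave unset to read only from `local`
    'local': '/tmp/ft-data',             # hypothetical local directory
    'split': 'train',
    'shuffle': True,
    'decoder_only_format': True,
    'max_seq_len': 2048,                 # now required: omitting it raises ValueError
})

assert dataset_cfg.get('max_seq_len') is not None  # mirrors the new _validate_config check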

llmfoundry/data/finetuning/tasks.py

Lines changed: 23 additions & 0 deletions
@@ -42,6 +42,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:

 import datasets as hf_datasets
 import huggingface_hub as hf_hub
+import numpy as np
 from composer.utils import dist
 from streaming import StreamingDataset
 from transformers import PreTrainedTokenizerBase
@@ -332,6 +333,7 @@ def __init__(self,
                  sampling_method: str = 'balanced',
                  sampling_granularity: int = 1,
                  batching_method: str = 'random',
+                 max_seq_len: int = 2048,
                  **kwargs: Any):

         if len(kwargs) > 0:
@@ -371,10 +373,31 @@ def __init__(self,
         )

         self.tokenizer = tokenizer
+        self.max_seq_len = max_seq_len

     # How to process a sample
     def __getitem__(self, idx: int) -> Dict[str, Any]:
         sample = super().__getitem__(idx)
+        if 'input_ids' in sample:
+            # Already tokenized data
+            if isinstance(sample['input_ids'], bytes):
+                sample['input_ids'] = np.frombuffer(
+                    sample['input_ids'],
+                    dtype=np.int64)[:self.max_seq_len].tolist().copy()
+                sample['labels'] = np.frombuffer(
+                    sample['labels'],
+                    dtype=np.int64)[:self.max_seq_len].tolist().copy()
+            elif isinstance(sample['input_ids'], np.ndarray):
+                sample['input_ids'] = sample[
+                    'input_ids'][:self.max_seq_len].tolist().copy()
+                sample['labels'] = sample['labels'][:self.max_seq_len].tolist(
+                ).copy()
+            else:
+                raise ValueError(
+                    f'Expect input_ids to be bytes or numpy.ndarray type, but got {type(sample["input_ids"])}'
+                )
+
+            return sample
         return tokenize_formatted_example(sample, tokenizer=self.tokenizer)
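The block added to `__getitem__` above decodes pre-tokenized samples that arrive either as raw bytes or as numpy arrays, truncating both to `max_seq_len`. A standalone sketch of the same logic, with illustrative names rather than the upstream method:

import numpy as np


def decode_pretokenized(sample: dict, max_seq_len: int) -> dict:
    """Truncate a pre-tokenized sample and return plain Python lists."""
    for key in ('input_ids', 'labels'):
        value = sample[key]
        if isinstance(value, bytes):
            # bytes columns were written via np.asarray(...).tobytes()
            sample[key] = np.frombuffer(value, dtype=np.int64)[:max_seq_len].tolist()
        elif isinstance(value, np.ndarray):
            # ndarray columns (the new format) can be sliced directly
            sample[key] = value[:max_seq_len].tolist()
        else:
            raise ValueError(f'Expected bytes or np.ndarray for {key}, got {type(value)}')
    return sample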

scripts/data_prep/convert_finetuning_dataset.py

Lines changed: 3 additions & 2 deletions
@@ -202,7 +202,7 @@ def main(args: Namespace) -> None:
     tokenizer_kwargs.update({'model_max_length': args.max_seq_len})
     if args.tokenizer:
         tokenizer = build_tokenizer(args.tokenizer, tokenizer_kwargs)
-        columns = {'input_ids': 'bytes', 'labels': 'bytes'}
+        columns = {'input_ids': 'ndarray:uint32', 'labels': 'ndarray:uint32'}
     else:
         columns = {'prompt': 'str', 'response': 'str'}

@@ -255,7 +255,8 @@ def main(args: Namespace) -> None:
                     sample_to_write = {}
                     # convert to bytes
                     for key in columns.keys():
-                        sample_to_write[key] = np.asarray(sample[key]).tobytes()
+                        sample_to_write[key] = np.asarray(sample[key],
+                                                          dtype=np.uint32)
                     out.write(sample_to_write)
                 else:
                     encoded_sample = {
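The conversion script now writes tokenized columns as `ndarray:uint32` instead of serializing them to bytes. A short sketch of that write path using streaming's MDSWriter; the output directory and token ids are made up:

import numpy as np
from streaming import MDSWriter

columns = {'input_ids': 'ndarray:uint32', 'labels': 'ndarray:uint32'}
samples = [{'input_ids': [1, 2, 3], 'labels': [2, 3, 4]}]  # toy token ids

with MDSWriter(columns=columns, out='./my_data/train', compression=None) as out:
    for sample in samples:
        # store uint32 arrays directly; no .tobytes() round-trip needed anymore
        out.write({key: np.asarray(sample[key], dtype=np.uint32) for key in columns})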
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+max_seq_len: 512
+global_seed: 17
+
+data_local: ./my_data
+data_remote: # If blank, files must be present in data_local
+
+# Run Name
+run_name: # If left blank, will be read from env var $RUN_NAME
+
+# Model
+model:
+  name: hf_causal_lm
+  pretrained_model_name_or_path: gpt2
+  pretrained: true # false: only use the architecture; true: initialize with pretrained weights
+
+# Tokenizer
+tokenizer:
+  name: gpt2
+  kwargs:
+    model_max_length: ${max_seq_len}
+
+# Dataloaders
+train_loader:
+  name: finetuning
+  dataset:
+    ############
+    remote: ${data_remote}
+    local: ${data_local}
+    split: train
+    ############
+    shuffle: true
+    max_seq_len: ${max_seq_len}
+    decoder_only_format: true
+  drop_last: true
+  num_workers: 8
+
+# Optimization
+scheduler:
+  name: cosine_with_warmup
+  t_warmup: 100ba
+  alpha_f: 0.1
+
+optimizer:
+  name: decoupled_adamw
+  lr: 6.0e-4
+  betas:
+  - 0.9
+  - 0.95
+  eps: 1.0e-08
+  weight_decay: 0.0
+
+algorithms:
+  gradient_clipping:
+    clipping_type: norm
+    clipping_threshold: 1.0
+
+max_duration: 1ep
+eval_interval: 1
+eval_first: false
+eval_subset_num_batches: -1
+global_train_batch_size: 8
+
+# System
+seed: ${global_seed}
+device_eval_batch_size: 8
+device_train_microbatch_size: 8
+# device_train_microbatch_size: auto
+precision: fp32
+
+# Logging
+progress_bar: false
+log_to_console: true
+console_log_interval: 1ba
+
+callbacks:
+  speed_monitor:
+    window_size: 10
+  lr_monitor: {}
+  memory_monitor: {}
+  runtime_estimator: {}
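The YAML above is the new streaming finetuning example config added by this commit. A small sketch of loading it with omegaconf and checking that the `${max_seq_len}` and `${data_local}` references resolve; the filename is hypothetical:

from omegaconf import OmegaConf as om

cfg = om.load('finetune_streaming_example.yaml')  # hypothetical name for the file above
om.resolve(cfg)

assert cfg.train_loader.dataset.max_seq_len == 512  # inherited from the top-level max_seq_len
assert cfg.train_loader.dataset.local == './my_data'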

tests/data/test_dataloader.py

Lines changed: 55 additions & 11 deletions
@@ -12,6 +12,7 @@
 from typing import ContextManager, Literal, Optional, Union
 from unittest.mock import MagicMock, patch

+import numpy as np
 import pytest
 import torch
 import transformers
@@ -25,7 +26,8 @@
 from llmfoundry.data import build_dataloader
 from llmfoundry.data.finetuning.tasks import (DOWNLOADED_FT_DATASETS_DIRPATH,
                                               SUPPORTED_EXTENSIONS,
-                                              is_valid_ift_example)
+                                              is_valid_ift_example,
+                                              tokenize_formatted_example)
 from llmfoundry.data.text_data import (ConcatenatedSequenceCollatorWrapper,
                                        build_text_dataloader,
                                        get_tokens_per_batch_func)
@@ -49,23 +51,51 @@ def get_abs_data_path(data_local: str):
     return os.path.join(os.getcwd(), data_local)


-def build_mock_ft_streaming_dataset(data_path: str, split: str):
-    columns = {'prompt': 'str', 'response': 'str'}
+def build_mock_ft_streaming_dataset(
+        data_path: str,
+        split: str,
+        pretokenize: bool,
+        use_bytes: bool,
+        tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None):
+    if pretokenize:
+        if use_bytes:
+            columns = {'input_ids': 'bytes', 'labels': 'bytes'}
+        else:
+            columns = {
+                'input_ids': 'ndarray:uint32',
+                'labels': 'ndarray:uint32'
+            }
+    else:
+        columns = {'prompt': 'str', 'response': 'str'}

     dataset = [{
         'prompt': 'This is just a test1',
         'response': 'Hello World1'
     }, {
         'prompt': 'This is just a test2',
         'response': 'Hello world2'
+    }, {
+        'prompt': 'This is just a test3',
+        'response': 'Hello world3'
     }]

     output_path = os.path.join(data_path, split)

     with MDSWriter(columns=columns, out=output_path,
                    compression=None) as output_writer:
         for sample in dataset:
-            output_writer.write(sample)
+            if pretokenize:
+                sample = tokenize_formatted_example(sample, tokenizer=tokenizer)
+                sample_to_write = {}
+                for key in columns.keys():
+                    if use_bytes:
+                        sample_to_write[key] = np.asarray(sample[key]).tobytes()
+                    else:
+                        sample_to_write[key] = np.asarray(sample[key],
+                                                          dtype=np.uint32)
+                output_writer.write(sample_to_write)
+            else:
+                output_writer.write(sample)


 @pytest.mark.parametrize('tokenizer_name', ['gpt2', 'facebook/opt-125m'])
@@ -517,13 +547,25 @@ def test_finetuning_dataloader_custom_split_remote(split: str):
     assert split in dest_arg, 'split destination should match split name'


-def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):
+@pytest.mark.parametrize('pretokenize', [True, False])
+@pytest.mark.parametrize('use_bytes', [True, False])
+def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool,
+                                         tmp_path: pathlib.Path):
     max_seq_len = 2048

     remote_path = os.path.join(tmp_path, 'remote')
     local_path = os.path.join(tmp_path, 'local')

-    build_mock_ft_streaming_dataset(remote_path, 'train')
+    tokenizer = build_tokenizer(
+        tokenizer_name='gpt2',
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+
+    build_mock_ft_streaming_dataset(remote_path,
+                                    'train',
+                                    pretokenize,
+                                    use_bytes=use_bytes,
+                                    tokenizer=tokenizer)

     cfg = {
         'name': 'finetuning',
@@ -547,12 +589,14 @@ def test_finetuning_dataloader_streaming(tmp_path: pathlib.Path):

     cfg = om.create(cfg)

-    tokenizer = build_tokenizer(
-        tokenizer_name='gpt2',
-        tokenizer_kwargs={'model_max_length': max_seq_len},
-    )
+    dataloader = build_finetuning_dataloader(cfg, tokenizer, 2).dataloader

-    _ = build_finetuning_dataloader(cfg, tokenizer, 4)
+    expected_keys = ['input_ids', 'labels']
+    for batch in dataloader:
+        for key in expected_keys:
+            assert key in batch
+            assert batch[key].shape[0] == 2
+        break


 def test_finetuning_dataloader_is_valid_ift_example():
