
Commit 1a69081

[Megatron dataset] Support loading megatron dataset (#6489)
* Adapt to Megatron
* fix_print_dataset
* fix BlendableDataset
* fix BlendableDataset
* fix skip_warmup
* fix
* fix
* fix
* fix
* fix
* fix
* cache fix
* make new dataset
* fix loss mask
* fix model_zoo/gpt
* fix model_zoo/gpt
* fix model_zoo/gpt
* fix gpt test
* fix legacy
* fix legacy
* hf_model
* remove legacy
* merge develop gpt
* fix model_zoo/gpt for megatron
* merge develop
* resolve conflict
* fix check_rank_flag for data_cache_path
* fix check_rank_flag for data_cache_path
* remove hcg
* fix model_zoo/gpt eval
1 parent d012c87 commit 1a69081

File tree

15 files changed: +2728 additions, -1060 deletions


llm/gpt-3/dataset.py

Lines changed: 0 additions & 444 deletions
This file was deleted.

llm/gpt-3/run_pretrain.py

Lines changed: 44 additions & 61 deletions
@@ -49,7 +49,7 @@
     ),
 }
 
-from dataset import GPTDataset, get_train_valid_test_split_
+from paddlenlp.data.causal_dataset import build_train_valid_test_datasets, print_rank_0
 
 
 def add_start_docstrings(*docstr):
@@ -86,7 +86,6 @@ class DataArguments:
     input_dir: str = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
-    cache_prefix: str = field(default=None, metadata={"help": "The prefix of the cached dataset."})
     split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."})
 
     max_seq_length: int = field(
@@ -101,6 +100,13 @@ class DataArguments:
         metadata={"help": "Use share folder for data dir and output dir on multi machine."},
     )
 
+    data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."})
+    skip_warmup: bool = field(
+        default=True,
+        metadata={"help": "Whether to skip the warmup process of mmap files."},
+    )
+    data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."})
+
 
 @dataclass
 class ModelArguments:
@@ -143,7 +149,7 @@ def create_pretrained_dataset(
     tokenizer,
 ):
 
-    train_valid_test_num_samples = [
+    train_val_test_num_samples = [
         training_args.per_device_train_batch_size
         * training_args.dataset_world_size
         * training_args.max_steps
@@ -155,72 +161,50 @@ def create_pretrained_dataset(
         training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
     ]
 
-    input_prefix = data_file[0]
-
-    for suffix in ["_ids.npy", "_idx.npz"]:
-        if not os.path.isfile(input_prefix + suffix):
-            raise ValueError("File Not found, %s" % (input_prefix + suffix))
-
-    sample_ids = np.load(input_prefix + "_ids.npy", mmap_mode="r", allow_pickle=True)
-    # All documment ids, extend as 1-D array.
-
-    process_data = np.load(input_prefix + "_idx.npz")
-    # The len(sample_lens) num of docs
-    # The sum(sample_lens) should equal len(sample_ids)
-    sample_lens = process_data["lens"]
-
-    splits = get_train_valid_test_split_(data_args.split, len(sample_lens))
-    assert len(sample_lens) >= splits[-1], "The document nums should larger than max of splits, but %s < %s" % (
-        len(sample_lens),
-        splits[-1],
+    print_rank_0(" > datasets target sizes (minimum size):")
+    print_rank_0("    train: {}".format(train_val_test_num_samples[0]))
+    print_rank_0("    validation: {}".format(train_val_test_num_samples[1]))
+    print_rank_0("    test: {}".format(train_val_test_num_samples[2]))
+
+    # Build the datasets.
+    train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets(
+        data_prefix=data_file,
+        data_impl=data_args.data_impl,
+        splits_string=data_args.split,
+        train_val_test_num_samples=train_val_test_num_samples,
+        seq_length=data_args.max_seq_length,
+        seed=training_args.seed,
+        skip_warmup=data_args.skip_warmup,
+        data_cache_path=data_args.data_cache,
     )
 
     def print_dataset(data, mode="train"):
         logger.info(f"Sample data for {mode} mode")
-        input_ids, loss_mask, attention_mask, position_ids, labels = data
+        # input_ids, loss_mask, attention_mask, position_ids, labels = data
+        input_ids = data["text"]
+
         logger.info(tokenizer._decode(input_ids))
-        # logger.info(tokenizer._decode(labels))
-        # logger.info(tokenizer.convert_ids_to_tokens(input_ids))
-
-    def build_dataset(index, name):
-        dataset = GPTDataset(
-            file_prefix=os.path.join(data_args.cache_prefix, os.path.basename(input_prefix)),
-            build_data_file=training_args.local_process_index == 0,
-            micro_batch_size=training_args.per_device_train_batch_size
-            if name == "train"
-            else training_args.per_device_eval_batch_size,
-            name="gpt_" + name,
-            max_seq_len=data_args.max_seq_length,
-            num_samples=train_valid_test_num_samples[index],
-            documents=np.arange(splits[index], splits[index + 1]),
-            sample_ids=sample_ids,
-            sample_lens=sample_lens,
-            eos_id=tokenizer.eos_token_id,
-            seed=training_args.seed,
-        )
-        print_dataset(dataset[0], name)
-        return dataset
 
     from paddlenlp.data import Stack
 
     def _collate_data(data, stack_fn=Stack()):
-        num_fields = len(data[0])
-        out = [None] * num_fields
-        # 0:input_ids, 1:loss_mask, 2:attention_mask, 3:position_ids, 4:labels
-        for i in (0, 1, 2, 3, 4):
-            out[i] = stack_fn([x[i] for x in data])
+        tokens_ = stack_fn(x["text"] for x in data)
+
+        labels = tokens_[:, 1:]
+        tokens = tokens_[:, :-1]
+
+        # Attention mask.
+        attention_mask = paddle.ones(tokens.shape, dtype=paddle.int64)
 
         return {
-            "input_ids": out[0],
-            "attention_mask": out[2],
-            "labels": out[4],
+            "input_ids": tokens,
+            "attention_mask": attention_mask,
+            "labels": labels,
         }
 
-    # Note, data should be broardcast to all devices.
-    # for train, valid, test, the distinct data num is data_world_size
-    train_dataset = build_dataset(0, "train")
-    valid_dataset = build_dataset(1, "valid")
-    test_dataset = build_dataset(2, "test")
+    print_dataset(train_dataset[0])
+    print_dataset(valid_dataset[0])
+    print_dataset(test_dataset[0])
 
     return train_dataset, valid_dataset, test_dataset, _collate_data
 
@@ -233,9 +217,10 @@ def get_train_data_file(args):
     files = [
         os.path.join(args.input_dir, f)
         for f in os.listdir(args.input_dir)
-        if (os.path.isfile(os.path.join(args.input_dir, f)) and "_idx.npz" in str(f))
+        if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f)))
     ]
     files = [x.replace("_idx.npz", "") for x in files]
+    files = [x.replace(".idx", "") for x in files]
 
     if len(files) > 1:
         ret = []
@@ -333,10 +318,8 @@ def main():
     if model_args.tokenizer_name_or_path is None:
        model_args.tokenizer_name_or_path = model_args.model_name_or_path
 
-    if data_args.cache_prefix is None:
-        data_args.cache_prefix = data_args.input_dir
-    else:
-        os.makedirs(data_args.cache_prefix, exist_ok=True)
+    if data_args.data_cache is not None:
+        os.makedirs(data_args.data_cache, exist_ok=True)
 
     set_seed(training_args)
     paddle.set_device(training_args.device)
llm/llama/dataset.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

llm/llama/run_pretrain.py

Lines changed: 41 additions & 63 deletions
@@ -50,10 +50,11 @@
     ),
 }
 
-from dataset import GPTDataset, get_train_valid_test_split_
 from fused_layers import mock_layers
 from modeling_pp import LlamaForCausalLMPipe
 
+from paddlenlp.data.causal_dataset import build_train_valid_test_datasets, print_rank_0
+
 
 def add_start_docstrings(*docstr):
     def docstring_decorator(fn):
@@ -95,7 +96,6 @@ class DataArguments:
     input_dir: str = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
-    cache_prefix: str = field(default=None, metadata={"help": "The prefix of the cached dataset."})
     split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."})
 
     max_seq_length: int = field(
@@ -111,6 +111,13 @@ class DataArguments:
     )
     train_data_size: int = field(default=-1, metadata={"help": "Number of dataset for training"})
 
+    data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."})
+    skip_warmup: bool = field(
+        default=True,
+        metadata={"help": "Whether to skip the warmup process of mmap files."},
+    )
+    data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."})
+
 
 @dataclass
 class ModelArguments:
@@ -200,7 +207,7 @@ def create_pretrained_dataset(
     tokenizer,
 ):
 
-    train_valid_test_num_samples = [
+    train_val_test_num_samples = [
         training_args.per_device_train_batch_size
         * training_args.dataset_world_size
         * training_args.max_steps
@@ -212,74 +219,46 @@ def create_pretrained_dataset(
         training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
     ]
 
-    input_prefix = data_file[0]
-
-    for suffix in ["_ids.npy", "_idx.npz"]:
-        if not os.path.isfile(input_prefix + suffix):
-            raise ValueError("File Not found, %s" % (input_prefix + suffix))
-
-    sample_ids = np.load(input_prefix + "_ids.npy", mmap_mode="r", allow_pickle=True)
-    # All documment ids, extend as 1-D array.
-
-    process_data = np.load(input_prefix + "_idx.npz")
-    # The len(sample_lens) num of docs
-    # The sum(sample_lens) should equal len(sample_ids)
-    sample_lens = process_data["lens"]
-
-    splits = get_train_valid_test_split_(data_args.split, len(sample_lens))
-    assert len(sample_lens) >= splits[-1], "The document nums should larger than max of splits, but %s < %s" % (
-        len(sample_lens),
-        splits[-1],
+    print_rank_0(" > datasets target sizes (minimum size):")
+    print_rank_0("    train: {}".format(train_val_test_num_samples[0]))
+    print_rank_0("    validation: {}".format(train_val_test_num_samples[1]))
+    print_rank_0("    test: {}".format(train_val_test_num_samples[2]))
+
+    # Build the datasets.
+    train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets(
+        data_prefix=data_file,
+        data_impl=data_args.data_impl,
+        splits_string=data_args.split,
+        train_val_test_num_samples=train_val_test_num_samples,
+        seq_length=data_args.max_seq_length,
+        seed=training_args.seed,
+        skip_warmup=data_args.skip_warmup,
+        data_cache_path=data_args.data_cache,
    )
 
     def print_dataset(data, mode="train"):
         logger.info(f"Sample data for {mode} mode")
-        input_ids, loss_mask, attention_mask, position_ids, labels = data
+        # input_ids, loss_mask, attention_mask, position_ids, labels = data
+        input_ids = data["text"]
+
         logger.info(tokenizer._decode(input_ids))
-        # logger.info(tokenizer._decode(labels))
-        # logger.info(tokenizer.convert_ids_to_tokens(input_ids))
-
-    def build_dataset(index, name):
-        dataset = GPTDataset(
-            file_prefix=os.path.join(data_args.cache_prefix, os.path.basename(input_prefix)),
-            build_data_file=training_args.local_process_index == 0,
-            micro_batch_size=training_args.per_device_train_batch_size
-            if name == "train"
-            else training_args.per_device_eval_batch_size,
-            name="gpt_" + name,
-            max_seq_len=data_args.max_seq_length,
-            num_samples=train_valid_test_num_samples[index],
-            documents=np.arange(splits[index], splits[index + 1]),
-            sample_ids=sample_ids,
-            sample_lens=sample_lens,
-            eos_id=tokenizer.eos_token_id,
-            seed=training_args.seed,
-        )
-        print_dataset(dataset[0], name)
-        return dataset
 
     from paddlenlp.data import Stack
 
     def _collate_data(data, stack_fn=Stack()):
-        num_fields = len(data[0])
-        out = [None] * num_fields
-        # 0:input_ids, 1:loss_mask, 2:attention_mask, 3:position_ids, 4:labels
-        for i in (0, 1, 2, 3, 4):
-            out[i] = stack_fn([x[i] for x in data])
+        tokens_ = stack_fn(x["text"] for x in data)
+
+        labels = tokens_[:, 1:]
+        tokens = tokens_[:, :-1]
 
         return {
-            "input_ids": out[0],
-            # "token_type_ids": out[1],
-            # "attention_mask": out[2],
-            # "loss_mask": out[3],
-            "labels": out[4],
+            "input_ids": tokens,
+            "labels": labels,
         }
 
-    # Note, data should be broardcast to all devices.
-    # for train, valid, test, the distinct data num is data_world_size
-    train_dataset = build_dataset(0, "train")
-    valid_dataset = build_dataset(1, "valid")
-    test_dataset = build_dataset(2, "test")
+    print_dataset(train_dataset[0], "train")
+    print_dataset(valid_dataset[0], "valid")
+    print_dataset(test_dataset[0], "test")
 
     return train_dataset, valid_dataset, test_dataset, _collate_data
 
@@ -292,9 +271,10 @@ def get_train_data_file(args):
     files = [
         os.path.join(args.input_dir, f)
         for f in os.listdir(args.input_dir)
-        if (os.path.isfile(os.path.join(args.input_dir, f)) and "_idx.npz" in str(f))
+        if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f)))
     ]
     files = [x.replace("_idx.npz", "") for x in files]
+    files = [x.replace(".idx", "") for x in files]  # add
 
     if len(files) > 1:
         ret = []
@@ -396,10 +376,8 @@ def main():
     if model_args.tokenizer_name_or_path is None:
         model_args.tokenizer_name_or_path = model_args.model_name_or_path
 
-    if data_args.cache_prefix is None:
-        data_args.cache_prefix = data_args.input_dir
-    else:
-        os.makedirs(data_args.cache_prefix, exist_ok=True)
+    if data_args.data_cache is not None:
+        os.makedirs(data_args.data_cache, exist_ok=True)
 
     set_seed(training_args)
     paddle.set_device(training_args.device)
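
Note on data-file discovery: in both pretraining scripts, `get_train_data_file` now accepts Megatron-style indexed datasets, i.e. a `<prefix>.idx` / `<prefix>.bin` pair, alongside the legacy `<prefix>_idx.npz` files, and hands only the shared prefix to `build_train_valid_test_datasets`. A self-contained sketch of that filtering logic; the directory contents named in the comment are hypothetical:

import os

def find_data_prefixes(input_dir):
    # Keep files that look like either a legacy npz index or a Megatron .idx index.
    files = [
        os.path.join(input_dir, f)
        for f in os.listdir(input_dir)
        if os.path.isfile(os.path.join(input_dir, f))
        and ("_idx.npz" in f or ".idx" in f)
    ]
    # Strip the index suffix; the loader later resolves the ".bin"/".idx" pair itself.
    files = [x.replace("_idx.npz", "") for x in files]
    files = [x.replace(".idx", "") for x in files]
    return files

# e.g. a directory holding "wiki_text_document.bin" and "wiki_text_document.idx"
# would yield ["<input_dir>/wiki_text_document"].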

llm/llama/run_trainer.sh

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,6 @@
 
 set -x
 unset CUDA_VISIBLE_DEVICES
-
 task_name="llama_hybid"
 rm -rf output/$task_name/
 rm -rf "output/$task_name""_log"
@@ -56,4 +55,5 @@ python -u -m paddle.distributed.launch \
     --recompute 1 \
     --do_train \
     --do_eval \
-    --device "gpu"
+    --device "gpu" \
+    --data_impl "mmap"
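
Note on the added launcher flag: `--data_impl "mmap"` matches the new `data_impl` field and means the preprocessed token file is read through a memory map rather than loaded into RAM up front, which is also why the `skip_warmup` option exists (warmup simply touches the mapped pages early). A rough illustration of the underlying idea with `numpy.memmap` on a hypothetical token file; this is not PaddleNLP's actual reader:

import numpy as np

# Hypothetical preprocessed corpus: a flat array of uint16 token ids on disk,
# e.g. the ".bin" half of a Megatron "<prefix>.bin" / "<prefix>.idx" pair.
tokens = np.memmap("llama_openwebtext.bin", dtype=np.uint16, mode="r")

# Slices are paged in from disk lazily; nothing is read until it is touched.
first_sample = tokens[:2048]
print(first_sample.shape, first_sample.dtype)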
