 ),
}

-from dataset import GPTDataset, get_train_valid_test_split_
 from fused_layers import mock_layers
 from modeling_pp import LlamaForCausalLMPipe

+from paddlenlp.data.causal_dataset import build_train_valid_test_datasets, print_rank_0
+

 def add_start_docstrings(*docstr):
     def docstring_decorator(fn):
@@ -95,7 +96,6 @@ class DataArguments:
     input_dir: str = field(
         default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
     )
-    cache_prefix: str = field(default=None, metadata={"help": "The prefix of the cached dataset."})
     split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."})

     max_seq_length: int = field(
@@ -111,6 +111,13 @@ class DataArguments:
     )
     train_data_size: int = field(default=-1, metadata={"help": "Number of dataset for training"})

+    data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."})
+    skip_warmup: bool = field(
+        default=True,
+        metadata={"help": "Whether to skip the warmup process of mmap files."},
+    )
+    data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."})
+

 @dataclass
 class ModelArguments:
@@ -200,7 +207,7 @@ def create_pretrained_dataset(
     tokenizer,
 ):

-    train_valid_test_num_samples = [
+    train_val_test_num_samples = [
         training_args.per_device_train_batch_size
         * training_args.dataset_world_size
         * training_args.max_steps
@@ -212,74 +219,46 @@ def create_pretrained_dataset(
         training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters,
     ]

-    input_prefix = data_file[0]
-
-    for suffix in ["_ids.npy", "_idx.npz"]:
-        if not os.path.isfile(input_prefix + suffix):
-            raise ValueError("File Not found, %s" % (input_prefix + suffix))
-
-    sample_ids = np.load(input_prefix + "_ids.npy", mmap_mode="r", allow_pickle=True)
-    # All documment ids, extend as 1-D array.
-
-    process_data = np.load(input_prefix + "_idx.npz")
-    # The len(sample_lens) num of docs
-    # The sum(sample_lens) should equal len(sample_ids)
-    sample_lens = process_data["lens"]
-
-    splits = get_train_valid_test_split_(data_args.split, len(sample_lens))
-    assert len(sample_lens) >= splits[-1], "The document nums should larger than max of splits, but %s < %s" % (
-        len(sample_lens),
-        splits[-1],
+    print_rank_0(" > datasets target sizes (minimum size):")
+    print_rank_0(" train: {}".format(train_val_test_num_samples[0]))
+    print_rank_0(" validation: {}".format(train_val_test_num_samples[1]))
+    print_rank_0(" test: {}".format(train_val_test_num_samples[2]))
+
+    # Build the datasets.
+    train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets(
+        data_prefix=data_file,
+        data_impl=data_args.data_impl,
+        splits_string=data_args.split,
+        train_val_test_num_samples=train_val_test_num_samples,
+        seq_length=data_args.max_seq_length,
+        seed=training_args.seed,
+        skip_warmup=data_args.skip_warmup,
+        data_cache_path=data_args.data_cache,
     )

     def print_dataset(data, mode="train"):
         logger.info(f"Sample data for {mode} mode")
-        input_ids, loss_mask, attention_mask, position_ids, labels = data
+        # input_ids, loss_mask, attention_mask, position_ids, labels = data
+        input_ids = data["text"]
+
         logger.info(tokenizer._decode(input_ids))
-        # logger.info(tokenizer._decode(labels))
-        # logger.info(tokenizer.convert_ids_to_tokens(input_ids))
-
-    def build_dataset(index, name):
-        dataset = GPTDataset(
-            file_prefix=os.path.join(data_args.cache_prefix, os.path.basename(input_prefix)),
-            build_data_file=training_args.local_process_index == 0,
-            micro_batch_size=training_args.per_device_train_batch_size
-            if name == "train"
-            else training_args.per_device_eval_batch_size,
-            name="gpt_" + name,
-            max_seq_len=data_args.max_seq_length,
-            num_samples=train_valid_test_num_samples[index],
-            documents=np.arange(splits[index], splits[index + 1]),
-            sample_ids=sample_ids,
-            sample_lens=sample_lens,
-            eos_id=tokenizer.eos_token_id,
-            seed=training_args.seed,
-        )
-        print_dataset(dataset[0], name)
-        return dataset

     from paddlenlp.data import Stack

     def _collate_data(data, stack_fn=Stack()):
-        num_fields = len(data[0])
-        out = [None] * num_fields
-        # 0:input_ids, 1:loss_mask, 2:attention_mask, 3:position_ids, 4:labels
-        for i in (0, 1, 2, 3, 4):
-            out[i] = stack_fn([x[i] for x in data])
+        tokens_ = stack_fn(x["text"] for x in data)
+
+        labels = tokens_[:, 1:]
+        tokens = tokens_[:, :-1]

         return {
-            "input_ids": out[0],
-            # "token_type_ids": out[1],
-            # "attention_mask": out[2],
-            # "loss_mask": out[3],
-            "labels": out[4],
+            "input_ids": tokens,
+            "labels": labels,
         }

-    # Note, data should be broardcast to all devices.
-    # for train, valid, test, the distinct data num is data_world_size
-    train_dataset = build_dataset(0, "train")
-    valid_dataset = build_dataset(1, "valid")
-    test_dataset = build_dataset(2, "test")
+    print_dataset(train_dataset[0], "train")
+    print_dataset(valid_dataset[0], "valid")
+    print_dataset(test_dataset[0], "test")

     return train_dataset, valid_dataset, test_dataset, _collate_data

@@ -292,9 +271,10 @@ def get_train_data_file(args):
     files = [
         os.path.join(args.input_dir, f)
         for f in os.listdir(args.input_dir)
-        if (os.path.isfile(os.path.join(args.input_dir, f)) and "_idx.npz" in str(f))
+        if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f)))
     ]
     files = [x.replace("_idx.npz", "") for x in files]
+    files = [x.replace(".idx", "") for x in files]  # add

     if len(files) > 1:
         ret = []
@@ -396,10 +376,8 @@ def main():
     if model_args.tokenizer_name_or_path is None:
         model_args.tokenizer_name_or_path = model_args.model_name_or_path

-    if data_args.cache_prefix is None:
-        data_args.cache_prefix = data_args.input_dir
-    else:
-        os.makedirs(data_args.cache_prefix, exist_ok=True)
+    if data_args.data_cache is not None:
+        os.makedirs(data_args.data_cache, exist_ok=True)

     set_seed(training_args)
     paddle.set_device(training_args.device)
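
For context on the new `_collate_data` above: each sample from the indexed dataset is a dict with a "text" array of token ids, and labels are produced by next-token shifting, so labels[i][t] is the token that follows input_ids[i][t]. A minimal, self-contained sketch of the same logic in plain NumPy (the script itself uses paddlenlp.data.Stack, which stacks equal-length samples the same way; the toy batch below is illustrative only):

import numpy as np

def collate(batch):
    # Stack the "text" arrays into a (batch, seq_len + 1) matrix.
    tokens_ = np.stack([sample["text"] for sample in batch])
    return {
        "input_ids": tokens_[:, :-1],  # what the model reads
        "labels": tokens_[:, 1:],      # the next-token targets
    }

# Two toy samples of 5 token ids each -> inputs and labels of length 4.
batch = [{"text": np.array([1, 2, 3, 4, 5])}, {"text": np.array([6, 7, 8, 9, 10])}]
out = collate(batch)
print(out["input_ids"].shape, out["labels"].shape)  # (2, 4) (2, 4)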
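Likewise, a sketch of the updated prefix discovery in get_train_data_file: both the legacy "_idx.npz" index files and the mmap-style ".idx" files are now accepted, and either suffix is stripped so the remaining path is the data prefix passed on to build_train_valid_test_datasets (the helper name and example file names below are hypothetical):

import os

def find_data_prefixes(input_dir):
    # Keep only index files, in either the legacy or the mmap format.
    files = [
        os.path.join(input_dir, f)
        for f in os.listdir(input_dir)
        if os.path.isfile(os.path.join(input_dir, f))
        and ("_idx.npz" in f or ".idx" in f)
    ]
    # Strip the suffix; what remains is the prefix shared with the data file.
    files = [x.replace("_idx.npz", "") for x in files]
    files = [x.replace(".idx", "") for x in files]
    return files

# e.g. "data/corpus.idx" + "data/corpus.bin" -> prefix "data/corpus"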