Commit 0d56490

support the loss mask for the pretrain (#8034)
1 parent dff7dcc

2 files changed: +59 −10 lines


paddlenlp/data/causal_dataset.py

Lines changed: 32 additions & 6 deletions
@@ -363,26 +363,52 @@ def __getitem__(self, idx):
         if doc_index_f == doc_index_l:
             doc_ids.append(self.doc_idx[doc_index_f])
 
-            sample = self.indexed_dataset.get(
+            sample, mask = self.indexed_dataset.get(
                 self.doc_idx[doc_index_f], offset=offset_f, length=offset_l - offset_f + 1
             )
         else:
             # Otherwise, get the rest of the initial document.
             doc_ids.append(self.doc_idx[doc_index_f])
-            sample_list = [self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)]
+            sample, mask = self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
+            append_mask = True
+            if mask is None:
+                append_mask = False
+
+            sample_list = [sample]
+            mask_list = []
+            mask_list = [mask]
             # Loop over all in between documents and add the entire document.
             for i in range(doc_index_f + 1, doc_index_l):
                 doc_ids.append(self.doc_idx[i])
-                sample_list.append(self.indexed_dataset.get(self.doc_idx[i]))
+                sample, mask = self.indexed_dataset.get(self.doc_idx[i])
+                sample_list.append(sample)
+                if append_mask:
+                    mask_list.append(mask)
+
             # And finally add the relevant portion of last document.
             doc_ids.append(self.doc_idx[doc_index_l])
-            sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1))
+            sample, mask = self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1)
+            sample_list.append(sample)
+            if append_mask:
+                mask_list.append(mask)
             sample = np.concatenate(sample_list)
+            if append_mask:
+                mask = np.concatenate(mask_list)
         # print(sample)
         if self.return_doc_ids:  # for retro preprocessing
-            return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)}
+            if mask is None:
+                return {"text": np.array(sample, dtype=np.int64), "doc_ids": np.array(doc_ids, dtype=np.int64)}
+            else:
+                return {
+                    "text": np.array(sample, dtype=np.int64),
+                    "doc_ids": np.array(doc_ids, dtype=np.int64),
+                    "mask": np.array(mask, dtype=np.int64),
+                }
         else:
-            return {"text": np.array(sample, dtype=np.int64)}
+            if mask is None:
+                return {"text": np.array(sample, dtype=np.int64)}
+            else:
+                return {"text": np.array(sample, dtype=np.int64), "mask": np.array(mask, dtype=np.int64)}
 
 
 def _build_index_mappings(
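
With this change, the dataset's __getitem__ returns a "mask" entry only when a companion loss-mask file exists, so downstream code has to handle both shapes. Below is a minimal consumer sketch, not part of this commit: collate_batch is a hypothetical helper, and it assumes the convention that mask == 0 marks tokens excluded from the loss (the diff itself does not pin down the semantics).

import numpy as np

def collate_batch(samples, ignore_index=-100):
    # Stack per-sample token ids into [batch_size, seq_len].
    tokens = np.stack([s["text"] for s in samples])
    input_ids = tokens[:, :-1]
    labels = tokens[:, 1:].copy()
    if "mask" in samples[0]:
        # Assumed convention: mask[i] == 0 excludes token i from the loss;
        # shift by one so the flag masks the position where that token
        # is *predicted*.
        loss_mask = np.stack([s["mask"] for s in samples])[:, 1:]
        labels[loss_mask == 0] = ignore_index
    return {"input_ids": input_ids, "labels": labels}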

paddlenlp/data/indexed_dataset.py

Lines changed: 27 additions & 4 deletions
@@ -124,6 +124,10 @@ def data_file_path(prefix_path):
     return prefix_path + ".bin"
 
 
+def loss_mask_file_path(prefix_path):
+    return prefix_path + ".lsm"
+
+
 def create_doc_idx(sizes):
     doc_idx = [0]
     for i, s in enumerate(sizes):
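
The new helper mirrors data_file_path and index_file_path: the loss mask lives in a sibling file that shares the dataset prefix. For example (illustrative path):

from paddlenlp.data.indexed_dataset import loss_mask_file_path

print(loss_mask_file_path("/data/corpus"))  # -> "/data/corpus.lsm"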
@@ -444,6 +448,7 @@ def __init__(self, path, skip_warmup=False):
         self._path = None
         self._index = None
         self._bin_buffer = None
+        self._loss_mask_buffer = None
 
         self._do_init(path, skip_warmup)
 
@@ -466,12 +471,18 @@ def _do_init(self, path, skip_warmup):
             _warmup_mmap_file(data_file_path(self._path))
         print_rank_0("    creating numpy buffer of mmap...")
         self._bin_buffer_mmap = np.memmap(data_file_path(self._path), mode="r", order="C")
+        if os.path.exists(loss_mask_file_path(self._path)):
+            self._loss_mask_buffer_mmap = np.memmap(loss_mask_file_path(self._path), mode="r", order="C")
+            self._loss_mask_buffer = memoryview(self._loss_mask_buffer_mmap)
         print_rank_0("    creating memory view of numpy buffer...")
         self._bin_buffer = memoryview(self._bin_buffer_mmap)
 
     def __del__(self):
         self._bin_buffer_mmap._mmap.close()
+        if hasattr(self, "_loss_mask_buffer_mmap"):
+            self._loss_mask_buffer_mmap._mmap.close()
         del self._bin_buffer_mmap
+        del self._loss_mask_buffer
         del self._index
 
     def __len__(self):
@@ -507,8 +518,12 @@ def get(self, idx, offset=0, length=None):
         if length is None:
             length = size - offset
         ptr += offset * np.dtype(self._index.dtype).itemsize
+        mask_ptr = ptr // np.dtype(self._index.dtype).itemsize
         np_array = np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=length, offset=ptr)
-        return np_array
+        mask_array = None
+        if self._loss_mask_buffer is not None:
+            mask_array = np.frombuffer(self._loss_mask_buffer, dtype=np.uint8, count=length, offset=mask_ptr)
+        return np_array, mask_array
 
     @property
     def sizes(self):
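
The mask offset arithmetic works because the .lsm file holds exactly one uint8 per token of the .bin file: ptr is a byte offset into the token file, so dividing by the token dtype's itemsize recovers the token index, which is also the byte offset into the mask file. A small self-contained check with illustrative values:

import numpy as np

token_dtype = np.uint16                  # assume 2 bytes per token in .bin
doc_start_token = 1000                   # document starts at token index 1000
offset = 7                               # skip 7 tokens into the document

ptr = (doc_start_token + offset) * np.dtype(token_dtype).itemsize  # bytes into .bin
mask_ptr = ptr // np.dtype(token_dtype).itemsize                   # bytes into .lsm

assert mask_ptr == doc_start_token + offset  # one uint8 per token: indices coincide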
@@ -533,20 +548,28 @@ def exists(path):
     return os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
 
 
-def make_builder(out_file, impl, save_dtype):
+def make_builder(out_file, impl, save_dtype, loss_mask_file=None):
     if impl == "mmap":
-        return MMapIndexedDatasetBuilder(out_file, dtype=save_dtype)
+        return MMapIndexedDatasetBuilder(out_file, dtype=save_dtype, loss_mask_file=loss_mask_file)
     else:
         return IndexedDatasetBuilder(out_file, dtype=save_dtype)
 
 
 class MMapIndexedDatasetBuilder(object):
-    def __init__(self, out_file, dtype):
+    def __init__(self, out_file, dtype, loss_mask_file=None):
         self._data_file = open(out_file, "wb")
+        self._loss_mask_file = None
+        if loss_mask_file is not None:
+            self._loss_mask_file = open(loss_mask_file, "wb")
         self._dtype = dtype
         self._sizes = []
         self._doc_idx = [0]
 
+    def flush_loss_mask_item(self, loss_mask_lst):
+        for loss_mask in loss_mask_lst:
+            tensor = np.array(loss_mask, dtype=np.uint8)
+            self._loss_mask_file.write(tensor.tobytes(order="C"))
+
     def add_item(self, tensor):
         tensor = np.array(tensor, dtype=self._dtype)
         self._data_file.write(tensor.tobytes(order="C"))
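
Putting the pieces together, a preprocessing script could write the token file and its .lsm companion, then read them back, roughly as follows. This is a sketch under assumptions: impl == "mmap"; end_document and finalize are the builder's pre-existing methods (not shown in this diff); the prefix and data are illustrative.

import numpy as np
from paddlenlp.data.indexed_dataset import (
    MMapIndexedDataset,
    data_file_path,
    index_file_path,
    loss_mask_file_path,
    make_builder,
)

prefix = "/tmp/demo_dataset"                     # illustrative path
documents = [[5, 6, 7, 8], [9, 10, 11]]          # token ids per document
loss_masks = [[0, 0, 1, 1], [1, 1, 1]]           # one flag per token

builder = make_builder(
    data_file_path(prefix),
    impl="mmap",
    save_dtype=np.uint16,
    loss_mask_file=loss_mask_file_path(prefix),  # enables the .lsm writer
)
for doc, mask in zip(documents, loss_masks):
    builder.add_item(doc)
    builder.flush_loss_mask_item([mask])         # uint8 flags, same order as the .bin tokens
    builder.end_document()
builder.finalize(index_file_path(prefix))
if builder._loss_mask_file is not None:
    builder._loss_mask_file.close()              # safe even if finalize already closed it

# The reader memory-maps the .lsm automatically, and get() now returns a pair.
dataset = MMapIndexedDataset(prefix)
tokens, mask = dataset.get(0)                    # -> [5 6 7 8], [0 0 1 1]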
