@@ -124,6 +124,10 @@ def data_file_path(prefix_path):
124124 return prefix_path + ".bin"
125125
126126
127+ def loss_mask_file_path (prefix_path ):
128+ return prefix_path + ".lsm"
129+
130+
127131def create_doc_idx (sizes ):
128132 doc_idx = [0 ]
129133 for i , s in enumerate (sizes ):
@@ -444,6 +448,7 @@ def __init__(self, path, skip_warmup=False):
444448 self ._path = None
445449 self ._index = None
446450 self ._bin_buffer = None
451+ self ._loss_mask_buffer = None
447452
448453 self ._do_init (path , skip_warmup )
449454
@@ -466,12 +471,18 @@ def _do_init(self, path, skip_warmup):
466471 _warmup_mmap_file (data_file_path (self ._path ))
467472 print_rank_0 (" creating numpy buffer of mmap..." )
468473 self ._bin_buffer_mmap = np .memmap (data_file_path (self ._path ), mode = "r" , order = "C" )
474+ if os .path .exists (loss_mask_file_path (self ._path )):
475+ self ._loss_mask_buffer_mmap = np .memmap (loss_mask_file_path (self ._path ), mode = "r" , order = "C" )
476+ self ._loss_mask_buffer = memoryview (self ._loss_mask_buffer_mmap )
469477 print_rank_0 (" creating memory view of numpy buffer..." )
470478 self ._bin_buffer = memoryview (self ._bin_buffer_mmap )
471479
472480 def __del__ (self ):
473481 self ._bin_buffer_mmap ._mmap .close ()
482+ if hasattr (self , "_loss_mask_buffer_mmap" ):
483+ self ._loss_mask_buffer_mmap ._mmap .close ()
474484 del self ._bin_buffer_mmap
485+ del self ._loss_mask_buffer
475486 del self ._index
476487
477488 def __len__ (self ):
@@ -507,8 +518,12 @@ def get(self, idx, offset=0, length=None):
507518 if length is None :
508519 length = size - offset
509520 ptr += offset * np .dtype (self ._index .dtype ).itemsize
521+ mask_ptr = ptr // np .dtype (self ._index .dtype ).itemsize
510522 np_array = np .frombuffer (self ._bin_buffer , dtype = self ._index .dtype , count = length , offset = ptr )
511- return np_array
523+ mask_array = None
524+ if self ._loss_mask_buffer is not None :
525+ mask_array = np .frombuffer (self ._loss_mask_buffer , dtype = np .uint8 , count = length , offset = mask_ptr )
526+ return np_array , mask_array
512527
513528 @property
514529 def sizes (self ):
@@ -533,20 +548,28 @@ def exists(path):
533548 return os .path .exists (index_file_path (path )) and os .path .exists (data_file_path (path ))
534549
535550
536- def make_builder (out_file , impl , save_dtype ):
551+ def make_builder (out_file , impl , save_dtype , loss_mask_file = None ):
537552 if impl == "mmap" :
538- return MMapIndexedDatasetBuilder (out_file , dtype = save_dtype )
553+ return MMapIndexedDatasetBuilder (out_file , dtype = save_dtype , loss_mask_file = loss_mask_file )
539554 else :
540555 return IndexedDatasetBuilder (out_file , dtype = save_dtype )
541556
542557
543558class MMapIndexedDatasetBuilder (object ):
544- def __init__ (self , out_file , dtype ):
559+ def __init__ (self , out_file , dtype , loss_mask_file = None ):
545560 self ._data_file = open (out_file , "wb" )
561+ self ._loss_mask_file = None
562+ if loss_mask_file is not None :
563+ self ._loss_mask_file = open (loss_mask_file , "wb" )
546564 self ._dtype = dtype
547565 self ._sizes = []
548566 self ._doc_idx = [0 ]
549567
568+ def flush_loss_mask_item (self , loss_mask_lst ):
569+ for loss_mask in loss_mask_lst :
570+ tensor = np .array (loss_mask , dtype = np .uint8 )
571+ self ._loss_mask_file .write (tensor .tobytes (order = "C" ))
572+
550573 def add_item (self , tensor ):
551574 tensor = np .array (tensor , dtype = self ._dtype )
552575 self ._data_file .write (tensor .tobytes (order = "C" ))
0 commit comments