@@ -135,7 +135,9 @@ def get_num_samples_in_batch(batch: dict) -> int:

     response_tokens = len(batch['labels']) if 'labels' in batch else 0

-    return {'ntokens': input_ids_tokens + decoder_input_ids_tokens + response_tokens}
+    return {
+        'ntokens': input_ids_tokens + decoder_input_ids_tokens + response_tokens
+    }


 def token_counts(FT_API_args):
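Note on the hunk above: the counter reports a mapping with a single 'ntokens' key, summing encoder tokens, decoder tokens, and label tokens. A toy illustration of the arithmetic, where the batch shapes and the numel-based counting rule are assumptions for illustration (the lines computing input_ids_tokens and decoder_input_ids_tokens sit outside this hunk):

    import torch

    # Hypothetical 4 x 128 batch; real batches come from the fine-tuning dataloader.
    batch = {
        'input_ids': torch.ones(4, 128, dtype=torch.long),
        'labels': torch.ones(4, 128, dtype=torch.long),
    }

    input_ids_tokens = batch['input_ids'].numel()  # 512, assuming every position counts
    decoder_input_ids_tokens = 0  # no decoder inputs in this toy batch
    # len() of a 2-D tensor is its first dimension, so this adds the batch size (4),
    # not the number of label tokens.
    response_tokens = len(batch['labels']) if 'labels' in batch else 0
    print({'ntokens': input_ids_tokens + decoder_input_ids_tokens + response_tokens})  # 516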
@@ -270,7 +272,7 @@ def count_shards(mds_root: str):
                                               merge_shard_groups)

 log = logging.getLogger(__name__)
-DONE_FILENAME = '.text_to_mds_conversion_done'
+DONE_FILENAME = '/Volumes/main/mosaic_hackathon/managed-volume/text_to_mds_conversion_done'


 def parse_args(tokenizer,
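DONE_FILENAME is the sentinel written after a successful conversion so later runs can skip work unless reprocess is set; the change points it at an absolute Unity Catalog volume path instead of a hidden file resolved relative to the output folder. Note that os.path.join drops its first component when the second is absolute, so with this change every output folder shares the one volume-level marker. A minimal sketch of the marker pattern, with simplified helper signatures (the real helpers in this file also track object names):

    import os

    def write_done_file(folder: str, args_str: str) -> None:
        # Record the arguments used, so a config change forces reprocessing.
        with open(os.path.join(folder, DONE_FILENAME), 'w') as f:
            f.write(args_str)

    def is_already_processed(folder: str, args_str: str) -> bool:
        done_path = os.path.join(folder, DONE_FILENAME)
        if not os.path.isfile(done_path):
            return False
        with open(done_path) as f:
            return f.read() == args_str  # skip only when the args match exactly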
@@ -499,6 +501,8 @@ def download_and_convert(
         bos_text (str): Text to prepend to each example to separate concatenated samples
         no_wrap (bool): Whether to let text examples wrap across multiple training examples
         compression (str): The compression algorithm to use for MDS writing
+    Returns:
+        (int): token count of the current group
     """
     object_store = maybe_create_object_store_from_uri(input_folder)

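maybe_create_object_store_from_uri comes from composer.utils: it returns an ObjectStore client for a remote URI and None for a plain local path, which is what lets the same download path serve both cases. A small usage sketch (the bucket and folder names are illustrative):

    from composer.utils import maybe_create_object_store_from_uri

    remote = maybe_create_object_store_from_uri('s3://my-bucket/raw-texts')  # an ObjectStore
    local = maybe_create_object_store_from_uri('/tmp/raw-texts')             # None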
@@ -521,14 +525,18 @@ def download_and_convert(
         no_wrap=no_wrap,
     )

-    columns = {'tokens': 'bytes'}
+    token_count = sum([1 for _ in dataset])
+
+    # columns = {'tokens': 'bytes'}

-    log.info('Converting to MDS format...')
-    with MDSWriter(out=output_folder,
-                   columns=columns,
-                   compression=compression) as out:
-        for sample in tqdm(dataset):
-            out.write(sample)
+    # log.info('Converting to MDS format...')
+    # with MDSWriter(out=output_folder,
+    #                columns=columns,
+    #                compression=compression) as out:
+    #     for sample in tqdm(dataset):
+    #         out.write(sample)
+
+    return token_count


 def is_remote_path(path: str) -> bool:
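The hunk above swaps MDS writing for counting. ConcatTokensDataset yields one sample per fixed-length group of concat_tokens tokens, so iterating it tokenizes the whole corpus, and sum([1 for _ in dataset]) counts groups rather than raw tokens (a final partial group is not yielded, so this slightly undercounts the corpus). If the docstring's "token count" is the goal, the group count has to be scaled; a sketch using names already in scope in this function:

    group_count = sum(1 for _ in dataset)      # one per concat_tokens-sized sample
    token_count = group_count * concat_tokens  # every yielded group is exactly this long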
@@ -616,7 +624,7 @@ def convert_text_to_mds(
     processes: int,
     args_str: str,
     reprocess: bool,
-):
+) -> int:
     """Convert a folder of text files to MDS format.

     Args:
@@ -631,6 +639,8 @@ def convert_text_to_mds(
         processes (int): The number of processes to use.
         args_str (str): String representation of the arguments
         reprocess (bool): Whether to always reprocess the given folder of text files
+    Returns:
+        (int): total tokens of the dataset
     """
     is_remote_output = is_remote_path(output_folder)

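With the int return type in place, a caller can size a training run straight from the conversion step. A hypothetical invocation (the paths, tokenizer, and values are placeholders; the keyword arguments follow the parameters shown above, plus the compression argument this file documents elsewhere):

    total_tokens = convert_text_to_mds(
        tokenizer_name='EleutherAI/gpt-neox-20b',
        output_folder='/Volumes/main/mosaic_hackathon/managed-volume/mds-out',
        input_folder='/Volumes/main/mosaic_hackathon/managed-volume/raw-texts',
        concat_tokens=2048,
        eos_text='<|endoftext|>',
        bos_text='',
        no_wrap=False,
        compression='zstd',
        processes=1,
        args_str='{}',
        reprocess=False,
    )
    print(f'{total_tokens} tokens in the converted dataset')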
@@ -658,12 +668,13 @@ def convert_text_to_mds(
                              processes, tokenizer_name, concat_tokens, eos_text,
                              bos_text, no_wrap, compression)
         with ProcessPoolExecutor(max_workers=processes) as executor:
-            list(executor.map(download_and_convert_starargs, args))
+            pool = list(executor.map(download_and_convert_starargs, args))

         # Merge the mds shards from each of the processes into a single folder
-        merge_shard_groups(local_output_folder)
+        # merge_shard_groups(local_output_folder)
+        total_tokens = sum(pool)
     else:
-        download_and_convert(object_names, local_output_folder, input_folder,
+        total_tokens = download_and_convert(object_names, local_output_folder, input_folder,
                              tokenizer_name, concat_tokens, eos_text, bos_text,
                              no_wrap, compression)

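executor.map returns results in input order, one per argument group, so collecting them into pool and summing is a plain scatter-gather reduction over worker processes; merge_shard_groups is disabled because no shards are written in counting mode. A self-contained sketch of the same pattern:

    from concurrent.futures import ProcessPoolExecutor

    def count_group(n: int) -> int:
        # Stand-in for download_and_convert: each worker returns a per-group count.
        return n * n

    if __name__ == '__main__':
        with ProcessPoolExecutor(max_workers=4) as executor:
            pool = list(executor.map(count_group, [1, 2, 3, 4]))
        total = sum(pool)
        print(total)  # 30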
@@ -683,6 +694,8 @@ def convert_text_to_mds(
         output_object_store.upload_object(
             remote_path, os.path.join(local_output_folder, file))

+    return total_tokens
+

 def _args_str(original_args: Namespace) -> str:
     """Create a string from the args to determine whether to reprocess.
@@ -801,8 +814,8 @@ def plot_hist(data, save_plot_path=None):

     # Aesthetics
     plt.title('Histogram of Token Counts')
-    plt.xlabel('Token Count')
-    plt.ylabel('Frequency')
+    plt.xlabel('Number of Tokens per Sample')
+    plt.ylabel('Number of Samples')

     # Grid and Layout
     plt.grid(axis='y', alpha=0.75)
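The relabeling makes the axes self-describing: x is tokens per sample, y is how many samples land in each bin. A standalone sketch of the same plot on fabricated data:

    import matplotlib.pyplot as plt
    import numpy as np

    data = np.random.lognormal(mean=6.0, sigma=0.5, size=1000)  # fake token counts
    plt.hist(data, bins=30)
    plt.title('Histogram of Token Counts')
    plt.xlabel('Number of Tokens per Sample')
    plt.ylabel('Number of Samples')
    plt.grid(axis='y', alpha=0.75)
    plt.show()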
@@ -855,6 +868,19 @@ def pandas_processing_fn(df: pd.DataFrame,
     hf_dataset = hf_datasets.Dataset.from_pandas(df=df)
     tokenizer = AutoTokenizer.from_pretrained(args['tokenizer'])
     tokenizer.model_max_length = 5000000000  # Hack to prevent warnings from HuggingFace
+
+    if bos_text + eos_text == '':
+        test_tokens = tokenizer('test')
+        if test_tokens['input_ids'][
+                0] != tokenizer.bos_token_id and test_tokens['input_ids'][
+                    -1] != tokenizer.eos_token_id:
+            tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. '
+            tok_error_msg += 'Concatenating with this tokenizer will result in sequences being '
+            tok_error_msg += 'attached without a separating token. Please use another tokenizer, '
+            tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. '
+            tok_error_msg += '--bos_text=<|endoftext|>.'
+            raise ValueError(tok_error_msg)
+
     dataset = ConcatTokensDataset(
         hf_dataset=hf_dataset,
         max_length=args.get('concat_tokens', None),
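The guard rejects tokenizers that add no sentinel at either end when no explicit bos_text/eos_text is given, since concatenation would otherwise fuse documents with no separator. For instance, gpt2 adds no special tokens by default and would trip the check, while facebook/opt-125m prepends its BOS token and passes. A quick probe (gpt2 added for contrast with the model named in the error message):

    from transformers import AutoTokenizer

    for name in ('gpt2', 'facebook/opt-125m'):
        tok = AutoTokenizer.from_pretrained(name)
        ids = tok('test')['input_ids']
        has_sentinel = ids[0] == tok.bos_token_id or ids[-1] == tok.eos_token_id
        print(name, has_sentinel)  # expect: gpt2 False, facebook/opt-125m True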
@@ -893,15 +919,16 @@ def pandas_processing_fn(df: pd.DataFrame,
 except ImportError as e:
     e.msg = get_import_exception_message(e.name,
                                          extra_deps='spark')  # pyright: ignore
-    raise e
+    # raise e

 try:
     from dask.dataframe import DataFrame as DaskDataFrame
     from dask.distributed import Client, LocalCluster
 except ImportError as e:
     e.msg = get_import_exception_message(e.name,
                                          extra_deps='dask')  # pyright: ignore
-    raise e
+    # raise e
+    DaskDataFrame = None

 try:
     from streaming import MDSWriter
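Silencing the re-raises turns these hard dependencies into optional ones, but a later isinstance(df, DaskDataFrame) would hit a NameError if the import failed, hence the DaskDataFrame = None fallback. The sentinel itself still needs guarding, because isinstance with None as its second argument raises TypeError. A generic sketch of the pattern:

    try:
        from dask.dataframe import DataFrame as DaskDataFrame
    except ImportError:
        DaskDataFrame = None  # sentinel: dask is unavailable

    def is_dask_frame(df) -> bool:
        # Check the sentinel before isinstance: isinstance(df, None) is a TypeError.
        return DaskDataFrame is not None and isinstance(df, DaskDataFrame)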
@@ -912,7 +939,7 @@ def pandas_processing_fn(df: pd.DataFrame,
 except ImportError as e:
     e.msg = get_import_exception_message(
         e.name, extra_deps='streaming')  # pyright: ignore
-    raise e
+    # raise e

 logger = logging.getLogger(__name__)