
Commit f1cfe9e

XiaohanZhangCMU, xiaohanzhan-db, and mvpatel2000 authored
Validation (#898)
* add validation script
* update
* change token count function
* reorganize cells
* Add unit tests
* Add a printout for CPT
* update question
* Add questions
* Fix lints
* update format
* update
* nb source
* add validation script
* update
* change token count function
* reorganize cells
* Add unit tests
* Add a printout for CPT
* update question
* Add questions
* Fix lints
* update format
* update
* nb source
* Remove license insert for validation notebook
* Add validation utils
* Minor cleanups (#858)
* nits
* logger
* add log
* lint
* update utils/__init__.py to include extra validation functions
* update notebook
* update
* update
* Read UC delta table (#773)
* initial commit
* use databricks-sql to read delta table and convert to json
* update
* update
* update
* add mocked unittest
* Fix lints
* update
* update
* restructure code
* Add timer for optimizing
* Add db-connect
* add wrapper
* update
* add install dbconnect
* update
* update
* patch dbconnect to allow multiple return formats
* update
* add arrow
* use compression
* clean up
* Add cluster rt check
* Fix lints
* remove patch.py for CI
* update
* update
* updat
* update
* fix tests
* fix lint
* update
* update
* Add more tests
* update
* update
* update
* change to download_json
* update
* fix lints
* Add decompressed option for arrow
* format json to jsonl
* Add comments
* Make cf_collect_type global option
* fix comments
* fix lints
* fix comments
* Fix lints
* change to use workspaceclient
* Add CPT support
* Rewire method assignment logic
* Fix bug in stripping https
* Add tests for rewired method assignment logic
* Fix lints
* Fix lints
* Removed logger set_level
* Remove pyspark. It conflicts with databricks-connect
* Update the comment
* skip cluster version check when cluster_id is serverless
* Add use_serverless flag
* update tests with use_serverless flag
* Fix lints

---------

Co-authored-by: Xiaohan Zhang <[email protected]>

* Add download remote function to util
* update
* remove fused layernorm (#859)
* update
* update
* update
* update
* update
* update
* update
* update
* update
* Remove hardcoded combined.jsonl with a flag (#861)
* Remove hardcoded combined.jsonl with a flag
* update
* change output_json_path output_json_folder

---------

Co-authored-by: Xiaohan Zhang <[email protected]>

* bump (#828)
* Add dask and dataframe_to_mds
* update
* update
* update
* update
* Add notebook
* update
* update
* remove script and tests, keep notebook
* update
* update
* update
* update
* Always initialize dist (#864)
* fix dev
* lint
* remove gpu
* updated notebook
* remove scripts keep notebook
* update notebook. rephrase.
* update
* Add response tokens
* update
* update
* Disable MDSWrite, return token counts
* Change plot settings
* update notebook
* update

---------

Co-authored-by: Xiaohan Zhang <[email protected]>
Co-authored-by: xiaohanzhan-db <xiaohanzhan-db>
Co-authored-by: Mihir Patel <[email protected]>
1 parent a9218d6 commit f1cfe9e

File tree

2 files changed: +815 −1162 lines changed


llmfoundry/utils/validation_utils.py

Lines changed: 45 additions & 18 deletions
@@ -135,7 +135,9 @@ def get_num_samples_in_batch(batch: dict) -> int:

     response_tokens = len(batch['labels']) if 'labels' in batch else 0

-    return {'ntokens': input_ids_tokens + decoder_input_ids_tokens + response_tokens}
+    return {
+        'ntokens': input_ids_tokens + decoder_input_ids_tokens + response_tokens
+    }


 def token_counts(FT_API_args):
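Note: get_num_samples_in_batch returns a dict with an 'ntokens' entry. A minimal sketch of totaling that value over a dataloader; the helper name and the dataloader are illustrative, not part of this commit:

    from typing import Iterable

    def estimate_total_tokens(dataloader: Iterable[dict]) -> int:
        # Illustrative helper: sum the per-batch 'ntokens' value reported by
        # get_num_samples_in_batch over one pass of a dataloader of batch dicts.
        return sum(get_num_samples_in_batch(batch)['ntokens'] for batch in dataloader)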
@@ -270,7 +272,7 @@ def count_shards(mds_root: str):
                                    merge_shard_groups)

 log = logging.getLogger(__name__)
-DONE_FILENAME = '.text_to_mds_conversion_done'
+DONE_FILENAME = '/Volumes/main/mosaic_hackathon/managed-volume/text_to_mds_conversion_done'


 def parse_args(tokenizer,
@@ -499,6 +501,8 @@ def download_and_convert(
         bos_text (str): Text to prepend to each example to separate concatenated samples
         no_wrap: (bool): Whether to let text examples wrap across multiple training examples
         compression (str): The compression algorithm to use for MDS writing
+    Returns:
+        (int): token count of the current group
     """
     object_store = maybe_create_object_store_from_uri(input_folder)

@@ -521,14 +525,18 @@ def download_and_convert(
             no_wrap=no_wrap,
         )

-        columns = {'tokens': 'bytes'}
+        token_count = sum([ 1 for _ in dataset])
+
+        # columns = {'tokens': 'bytes'}

-        log.info('Converting to MDS format...')
-        with MDSWriter(out=output_folder,
-                       columns=columns,
-                       compression=compression) as out:
-            for sample in tqdm(dataset):
-                out.write(sample)
+        # log.info('Converting to MDS format...')
+        # with MDSWriter(out=output_folder,
+        #                columns=columns,
+        #                compression=compression) as out:
+        #     for sample in tqdm(dataset):
+        #         out.write(sample)
+
+    return token_count


 def is_remote_path(path: str) -> bool:
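With the MDSWriter block disabled, download_and_convert only iterates the dataset and returns how many samples it yielded. A rough sketch of turning that count into a token estimate, under the assumption (not stated in the diff) that every yielded sample is a fixed-length concatenation of `concat_tokens` tokens:

    def estimate_group_tokens(dataset, concat_tokens: int) -> int:
        # Count the concatenated samples the dataset yields, then scale by the
        # assumed fixed sample length to approximate the group's token total.
        num_samples = sum(1 for _ in dataset)
        return num_samples * concat_tokens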
@@ -616,7 +624,7 @@ def convert_text_to_mds(
     processes: int,
     args_str: str,
     reprocess: bool,
-):
+)->int:
     """Convert a folder of text files to MDS format.

     Args:
@@ -631,6 +639,8 @@ def convert_text_to_mds(
         processes (int): The number of processes to use.
         args_str (str): String representation of the arguments
        reprocess (bool): Whether to always reprocess the given folder of text files
+    Returns:
+        (int): total tokens of the dataset
     """
     is_remote_output = is_remote_path(output_folder)

@@ -658,12 +668,13 @@ def convert_text_to_mds(
                              processes, tokenizer_name, concat_tokens, eos_text,
                              bos_text, no_wrap, compression)
         with ProcessPoolExecutor(max_workers=processes) as executor:
-            list(executor.map(download_and_convert_starargs, args))
+            pool = list(executor.map(download_and_convert_starargs, args))

         # Merge the mds shards from each of the processes into a single folder
-        merge_shard_groups(local_output_folder)
+        # merge_shard_groups(local_output_folder)
+        total_tokens = sum(pool)
     else:
-        download_and_convert(object_names, local_output_folder, input_folder,
+        total_tokens = download_and_convert(object_names, local_output_folder, input_folder,
                              tokenizer_name, concat_tokens, eos_text, bos_text,
                              no_wrap, compression)

@@ -683,6 +694,8 @@ def convert_text_to_mds(
             output_object_store.upload_object(
                 remote_path, os.path.join(local_output_folder, file))

+    return total_tokens
+

 def _args_str(original_args: Namespace) -> str:
     """Create a string from the args to determine whether to reprocess.
@@ -801,8 +814,8 @@ def plot_hist(data, save_plot_path=None):

     # Aesthetics
     plt.title('Histogram of Token Counts')
-    plt.xlabel('Token Count')
-    plt.ylabel('Frequency')
+    plt.xlabel('Number of Tokens per Sample')
+    plt.ylabel('Count of Frequency')

     # Grid and Layout
     plt.grid(axis='y', alpha=0.75)
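A hypothetical call of plot_hist with the relabeled axes; the pandas Series input is a guess at the expected data format, not something this diff specifies:

    import pandas as pd

    token_counts_per_sample = pd.Series([512, 480, 2048, 1024, 768])
    plot_hist(token_counts_per_sample, save_plot_path='token_hist.png')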
@@ -855,6 +868,19 @@ def pandas_processing_fn(df: pd.DataFrame,
     hf_dataset = hf_datasets.Dataset.from_pandas(df=df)
     tokenizer = AutoTokenizer.from_pretrained(args['tokenizer'])
     tokenizer.model_max_length = 5000000000  # Hack to prevent warnings from HuggingFace
+
+    if bos_text + eos_text == '':
+        test_tokens = tokenizer('test')
+        if test_tokens['input_ids'][
+                0] != tokenizer.bos_token_id and test_tokens['input_ids'][
+                    -1] != tokenizer.eos_token_id:
+            tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. '
+            tok_error_msg += 'Concatenating with this tokenizer will result in sequences being '
+            tok_error_msg += 'attached without a separating token. Please use another tokenizer, '
+            tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. '
+            tok_error_msg += '--bos_text=<|endoftext|>.'
+            raise ValueError(tok_error_msg)
+
     dataset = ConcatTokensDataset(
         hf_dataset=hf_dataset,
         max_length=args.get('concat_tokens', None),
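The new guard only runs when neither bos_text nor eos_text is supplied, and it raises if the tokenizer adds neither a BOS nor an EOS token, since concatenated samples would then run together. The same check pulled out into a standalone sketch; the tokenizer choice and message wording are illustrative:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
    test_tokens = tokenizer('test')
    inserts_bos = test_tokens['input_ids'][0] == tokenizer.bos_token_id
    inserts_eos = test_tokens['input_ids'][-1] == tokenizer.eos_token_id
    if not inserts_bos and not inserts_eos:
        raise ValueError('Tokenizer adds neither BOS nor EOS; concatenated samples would '
                         'run together. Use another tokenizer or pass bos_text/eos_text.')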
@@ -893,15 +919,16 @@ def pandas_processing_fn(df: pd.DataFrame,
 except ImportError as e:
     e.msg = get_import_exception_message(e.name,
                                          extra_deps='spark')  # pyright: ignore
-    raise e
+    #raise e

 try:
     from dask.dataframe import DataFrame as DaskDataFrame
     from dask.distributed import Client, LocalCluster
 except ImportError as e:
     e.msg = get_import_exception_message(e.name,
                                          extra_deps='dask')  # pyright: ignore
-    raise e
+    #raise e
+    DaskDataFrame = None

 try:
     from streaming import MDSWriter
@@ -912,7 +939,7 @@ def pandas_processing_fn(df: pd.DataFrame,
 except ImportError as e:
     e.msg = get_import_exception_message(
         e.name, extra_deps='streaming')  # pyright: ignore
-    raise e
+    #raise e

 logger = logging.getLogger(__name__)

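The last three hunks soften the optional imports: the `raise e` lines are commented out and `DaskDataFrame` falls back to `None` when dask is missing. A generic sketch of that soft-import pattern with a hypothetical guard function:

    try:
        from dask.dataframe import DataFrame as DaskDataFrame
    except ImportError:
        # Soft failure: record that dask is unavailable instead of raising at import time.
        DaskDataFrame = None

    def dataframe_to_mds_requires_dask(df) -> None:
        # Hypothetical guard: fail only when a dask-dependent code path is actually used.
        if DaskDataFrame is None:
            raise RuntimeError('dask is required for this code path; please install it.')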