@@ -3070,15 +3070,6 @@ def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset:
3070 3070                     return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split)
3071 3071                 raise NonExistentDatasetError
30723072
3073      -         def pbar_total(num_shards: int, batch_size: Optional[int]) -> int:
3074      -             total = len(self)
3075      -             if len(existing_cache_files) < num_shards:
3076      -                 total -= len(existing_cache_files) * total // num_shards
3077      -             if batched and drop_last_batch:
3078      -                 batch_size = batch_size or 1
3079      -                 return total // num_shards // batch_size * num_shards * batch_size
3080      -             return total
3081-
3082 3073         existing_cache_file_map: dict[int, list[str]] = defaultdict(list)
3083 3074         if cache_file_name is not None:
3084 3075             if os.path.exists(cache_file_name):
@@ -3199,9 +3190,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
3199 3190                         "missing from the cache."
3200 3191                     )
3201 3192
     3193 +         pbar_total = len(self)
     3194 +         pbar_initial = len(existing_cache_files) * pbar_total // num_shards
     3195 +         if batched and drop_last_batch:
     3196 +             batch_size = batch_size or 1
     3197 +             pbar_initial = pbar_initial // num_shards // batch_size * num_shards * batch_size
     3198 +             pbar_total = pbar_total // num_shards // batch_size * num_shards * batch_size
     3199 +
3202 3200         with hf_tqdm(
3203 3201             unit=" examples",
3204      -             total=pbar_total(num_shards, batch_size),
     3202 +             initial=pbar_initial,
     3203 +             total=pbar_total,
3205 3204             desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""),
3206 3205         ) as pbar:
3207 3206             shards_done = 0
0 commit comments