Skip to content

Commit bb7f9b5

Browse files
author
Matthew Hoffman
committed
Set initial value if there are already existing cache files
#7434 (comment)
1 parent 79dc83b commit bb7f9b5

File tree

1 file changed: 9 additions (+9) and 10 deletions (-10)

src/datasets/arrow_dataset.py

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -3070,15 +3070,6 @@ def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset:
30703070
return Dataset.from_file(shard_kwargs["cache_file_name"], info=info, split=shard.split)
30713071
raise NonExistentDatasetError
30723072

3073-
def pbar_total(num_shards: int, batch_size: Optional[int]) -> int:
3074-
total = len(self)
3075-
if len(existing_cache_files) < num_shards:
3076-
total -= len(existing_cache_files) * total // num_shards
3077-
if batched and drop_last_batch:
3078-
batch_size = batch_size or 1
3079-
return total // num_shards // batch_size * num_shards * batch_size
3080-
return total
3081-
30823073
existing_cache_file_map: dict[int, list[str]] = defaultdict(list)
30833074
if cache_file_name is not None:
30843075
if os.path.exists(cache_file_name):
@@ -3199,9 +3190,17 @@ def format_new_fingerprint(new_fingerprint: str, rank: int) -> str:
31993190
"missing from the cache."
32003191
)
32013192

3193+
pbar_total = len(self)
3194+
pbar_initial = len(existing_cache_files) * pbar_total // num_shards
3195+
if batched and drop_last_batch:
3196+
batch_size = batch_size or 1
3197+
pbar_initial = pbar_initial // num_shards // batch_size * num_shards * batch_size
3198+
pbar_total = pbar_total // num_shards // batch_size * num_shards * batch_size
3199+
32023200
with hf_tqdm(
32033201
unit=" examples",
3204-
total=pbar_total(num_shards, batch_size),
3202+
initial=pbar_initial,
3203+
total=pbar_total,
32053204
desc=(desc or "Map") + (f" (num_proc={num_proc})" if num_proc is not None and num_proc > 1 else ""),
32063205
) as pbar:
32073206
shards_done = 0

0 commit comments

Comments (0)