
Commit b6d8afa

tomvdw (The TensorFlow Datasets Authors) authored and committed
More efficient determination of shard lengths and total size
PiperOrigin-RevId: 812896990
1 parent e8f59cd

File tree

4 files changed: +71 -18 lines changed


tensorflow_datasets/core/file_adapters.py

Lines changed: 26 additions & 0 deletions

@@ -19,6 +19,7 @@
 
 import abc
 from collections.abc import Iterable, Iterator
+import concurrent.futures
 import enum
 import itertools
 import os
@@ -187,6 +188,31 @@ def num_examples(cls, filename: epath.PathLike) -> int:
       n += 1
     return n
 
+  @classmethod
+  def shard_lengths_and_sizes(
+      cls,
+      filename_template: naming.ShardedFileTemplate,
+      num_shards: int | None = None,
+  ) -> list[tuple[int, int]]:
+    """Returns the number of examples in each shard."""
+    if num_shards is not None:
+      shards = filename_template.sharded_filepaths(num_shards=num_shards)
+    else:
+      shards = filename_template.data_dir.glob(filename_template.glob_pattern())
+    shards = sorted([os.fspath(s) for s in shards])
+
+    def _get_length_and_size(shard: tuple[int, str]) -> tuple[int, int, int]:
+      index, shard = shard
+      length = cls.num_examples(shard)
+      size = epath.Path(shard).stat().length
+      return index, length, size
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
+      results = executor.map(_get_length_and_size, enumerate(shards))
+    # Sort results by the index and remove the index from the tuple.
+    sorted_results = sorted(results, key=lambda x: x[0])
+    return [(length, size) for _, length, size in sorted_results]
+
 
 class TfRecordFileAdapter(FileAdapter):
   """File adapter for TFRecord file format."""

tensorflow_datasets/core/file_adapters_test.py

Lines changed: 40 additions & 0 deletions

@@ -19,10 +19,12 @@
 import pathlib
 from typing import Type, TypeAlias
 
+from etils import epath
 import pytest
 from tensorflow_datasets import testing
 from tensorflow_datasets.core import dataset_builder
 from tensorflow_datasets.core import file_adapters
+from tensorflow_datasets.core import naming
 
 
 FileFormat: TypeAlias = file_adapters.FileFormat
@@ -138,3 +140,41 @@ def test_prase_file_format(format_enum_value, file_format):
 def test_convert_path_to_file_format(path, file_format, expected_path):
   converted_path = file_adapters.convert_path_to_file_format(path, file_format)
   assert os.fspath(converted_path) == expected_path
+
+
+@pytest.mark.parametrize(
+    'adapter_cls',
+    (
+        (file_adapters.TfRecordFileAdapter),
+        (file_adapters.ArrayRecordFileAdapter),
+    ),
+)
+def test_shard_lengths(
+    tmp_path: pathlib.Path, adapter_cls: file_adapters.FileAdapter
+):
+  file_template = naming.ShardedFileTemplate(
+      data_dir=tmp_path,
+      dataset_name='data',
+      filetype_suffix=adapter_cls.FILE_SUFFIX,
+      split='train',
+  )
+  tmp_path_1 = file_template.sharded_filepath(shard_index=0, num_shards=2)
+  tmp_path_2 = file_template.sharded_filepath(shard_index=1, num_shards=2)
+  adapter_cls.write_examples(
+      tmp_path_1, [(0, b'0'), (1, b'1'), (2, b'2222'), (3, b'33333')]
+  )
+  adapter_cls.write_examples(tmp_path_2, [(3, b'3'), (4, b'4'), (5, b'555')])
+  size_1 = epath.Path(tmp_path_1).stat().length
+  size_2 = epath.Path(tmp_path_2).stat().length
+  expected_shard_lengths = [(4, size_1), (3, size_2)]
+
+  # First test without passing the number of shards explicitly.
+  actual_no_num_shards = adapter_cls.shard_lengths_and_sizes(file_template)
+  assert actual_no_num_shards == expected_shard_lengths, 'no num_shards passed'
+
+  # Now test with passing the number of shards explicitly.
+  actual_with_num_shards = adapter_cls.shard_lengths_and_sizes(
+      file_template,
+      num_shards=2,
+  )
+  assert actual_with_num_shards == expected_shard_lengths, 'num_shards passed'
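
The pattern the test exercises, tagging each work item with its index so results can be put back in input order after parallel measurement, also works on plain local files. A self-contained sketch, assuming newline-delimited records as a stand-in for the adapters' record formats:

import concurrent.futures
import os


def lengths_and_sizes(paths: list[str]) -> list[tuple[int, int]]:
  """Returns (record_count, byte_size) per path, measured in parallel."""

  def _measure(item: tuple[int, str]) -> tuple[int, int, int]:
    index, path = item
    with open(path, 'rb') as f:
      length = sum(1 for _ in f)  # One record per line in this sketch.
    size = os.stat(path).st_size
    return index, length, size

  with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
    results = executor.map(_measure, enumerate(paths))
  # Re-sort by index and drop it, mirroring shard_lengths_and_sizes above.
  ordered = sorted(results, key=lambda r: r[0])
  return [(length, size) for _, length, size in ordered]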

tensorflow_datasets/core/utils/file_utils.py

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@
 from __future__ import annotations
 
 import collections
-from collections.abc import Iterator, Sequence
+from collections.abc import Iterable, Iterator, Sequence
 import contextlib
 import dataclasses
 import functools

tensorflow_datasets/core/writer.py

Lines changed: 4 additions & 17 deletions

@@ -18,7 +18,6 @@
 from __future__ import annotations
 
 from collections.abc import Iterable, Iterator, Sequence
-import concurrent.futures
 import dataclasses
 import functools
 import itertools
@@ -822,23 +821,11 @@ def finalize(self) -> tuple[list[int], int]:
       in each shard, and size of the files (in bytes).
     """
     logging.info("Finalizing writer for %s", self._filename_template.split)
-    # We don't know the number of shards, the length of each shard, nor the
-    # total size, so we compute them here.
-    shards = self._filename_template.data_dir.glob(
-        self._filename_template.glob_pattern()
+    shard_lengths_and_sizes = self._file_adapter.shard_lengths_and_sizes(
+        self._filename_template, num_shards=self._num_shards
     )
-
-    def _get_length_and_size(shard: epath.Path) -> tuple[epath.Path, int, int]:
-      length = self._file_adapter.num_examples(shard)
-      size = shard.stat().length
-      return shard, length, size
-
-    with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
-      shard_sizes = executor.map(_get_length_and_size, shards)
-
-    shard_sizes = sorted(shard_sizes, key=lambda x: x[0])
-    shard_lengths: list[int] = [x[1] for x in shard_sizes]
-    total_size_bytes: int = sum([x[2] for x in shard_sizes])
+    shard_lengths = [length for length, _ in shard_lengths_and_sizes]
+    total_size_bytes = sum(size for _, size in shard_lengths_and_sizes)
 
     logging.info(
         "Found %d shards with a total size of %d bytes.",
