Skip to content

Commit 79dc83b

Browse files
author
Matthew Hoffman
committed
Simplify existing existing_cache_file_map with string_to_dict
#7434 (comment)
1 parent 7f50b98 commit 79dc83b

File tree

1 file changed

+19
-34
lines changed

1 file changed

+19
-34
lines changed

src/datasets/arrow_dataset.py

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3052,6 +3052,13 @@ def map(
30523052
cache_file_name = self._get_cache_file_path(new_fingerprint)
30533053
dataset_kwargs["cache_file_name"] = cache_file_name
30543054

3055+
if cache_file_name is not None:
3056+
cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
3057+
if not cache_file_ext:
3058+
raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
3059+
else:
3060+
cache_file_prefix = cache_file_ext = None
3061+
30553062
def load_processed_shard_from_cache(shard_kwargs: dict[str, Any]) -> Dataset:
30563063
"""Load a processed shard from cache if it exists, otherwise throw an error."""
30573064
shard = shard_kwargs["shard"]
@@ -3072,39 +3079,19 @@ def pbar_total(num_shards: int, batch_size: Optional[int]) -> int:
30723079
return total // num_shards // batch_size * num_shards * batch_size
30733080
return total
30743081

3075-
def get_existing_cache_file_map(
3076-
cache_file_name: Optional[str],
3077-
) -> dict[int, list[str]]:
3078-
cache_files_by_num_proc: dict[int, list[str]] = defaultdict(list)
3079-
if cache_file_name is None:
3080-
return cache_files_by_num_proc
3082+
existing_cache_file_map: dict[int, list[str]] = defaultdict(list)
3083+
if cache_file_name is not None:
30813084
if os.path.exists(cache_file_name):
3082-
cache_files_by_num_proc[1] = [cache_file_name]
3085+
existing_cache_file_map[1] = [cache_file_name]
30833086

3084-
suffix_pattern_parts: list[str] = []
3085-
for literal_text, field_name, format_spec, _ in string_formatter.parse(suffix_template):
3086-
suffix_pattern_parts.append(re.escape(literal_text))
3087-
if field_name:
3088-
# TODO: we may want to place restrictions on acceptable format_spec or we will fail to match
3089-
# someone's hexidecimal or scientific notation format 😵
3090-
suffix_pattern_parts.append(f"(?P<{field_name}>\\d+)")
3091-
suffix_pattern = "".join(suffix_pattern_parts)
3092-
3093-
cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
3094-
if not cache_file_ext:
3095-
raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
3096-
3097-
cache_file_pattern = "^" + re.escape(cache_file_prefix) + suffix_pattern + re.escape(cache_file_ext) + "$"
3098-
cache_file_regex = re.compile(cache_file_pattern)
3087+
assert cache_file_prefix is not None and cache_file_ext is not None
3088+
cache_file_with_suffix_pattern = cache_file_prefix + suffix_template + cache_file_ext
30993089

31003090
for cache_file in glob.iglob(f"{cache_file_prefix}*{cache_file_ext}"):
3101-
if m := cache_file_regex.match(cache_file):
3102-
file_num_proc = int(m.group("num_proc"))
3103-
cache_files_by_num_proc[file_num_proc].append(cache_file)
3104-
3105-
return cache_files_by_num_proc
3106-
3107-
existing_cache_file_map = get_existing_cache_file_map(cache_file_name)
3091+
suffix_variable_map = string_to_dict(cache_file, cache_file_with_suffix_pattern)
3092+
if suffix_variable_map is not None:
3093+
file_num_proc = int(suffix_variable_map["num_proc"])
3094+
existing_cache_file_map[file_num_proc].append(cache_file)
31083095

31093096
num_shards = num_proc or 1
31103097
if existing_cache_file_map:
@@ -3122,7 +3109,7 @@ def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]:
31223109

31233110
num_shards = min(existing_cache_file_map, key=select_existing_cache_files)
31243111

3125-
existing_cache_files = existing_cache_file_map.get(num_shards, [])
3112+
existing_cache_files = existing_cache_file_map[num_shards]
31263113

31273114
def format_cache_file_name(
31283115
cache_file_name: Optional[str],
@@ -3131,9 +3118,7 @@ def format_cache_file_name(
31313118
if not cache_file_name:
31323119
return cache_file_name
31333120

3134-
cache_file_prefix, cache_file_ext = os.path.splitext(cache_file_name)
3135-
if not cache_file_ext:
3136-
raise ValueError(f"Expected cache_file_name to have an extension, but got: {cache_file_name}")
3121+
assert cache_file_prefix is not None and cache_file_ext is not None
31373122

31383123
if isinstance(rank, int):
31393124
cache_file_name = (
@@ -5835,7 +5820,7 @@ def push_to_hub(
58355820
@transmit_format
58365821
@fingerprint_transform(inplace=False)
58375822
def add_column(
5838-
self, name: str, column: Union[list, np.array], new_fingerprint: str, feature: Optional[FeatureType] = None
5823+
self, name: str, column: Union[list, np.ndarray], new_fingerprint: str, feature: Optional[FeatureType] = None
58395824
):
58405825
"""Add column to Dataset.
58415826

0 commit comments

Comments
 (0)