@@ -3052,6 +3052,13 @@ def map(
30523052 cache_file_name = self ._get_cache_file_path (new_fingerprint )
30533053 dataset_kwargs ["cache_file_name" ] = cache_file_name
30543054
3055+ if cache_file_name is not None :
3056+ cache_file_prefix , cache_file_ext = os .path .splitext (cache_file_name )
3057+ if not cache_file_ext :
3058+ raise ValueError (f"Expected cache_file_name to have an extension, but got: { cache_file_name } " )
3059+ else :
3060+ cache_file_prefix = cache_file_ext = None
3061+
30553062 def load_processed_shard_from_cache (shard_kwargs : dict [str , Any ]) -> Dataset :
30563063 """Load a processed shard from cache if it exists, otherwise throw an error."""
30573064 shard = shard_kwargs ["shard" ]
@@ -3072,39 +3079,19 @@ def pbar_total(num_shards: int, batch_size: Optional[int]) -> int:
30723079 return total // num_shards // batch_size * num_shards * batch_size
30733080 return total
30743081
3075- def get_existing_cache_file_map (
3076- cache_file_name : Optional [str ],
3077- ) -> dict [int , list [str ]]:
3078- cache_files_by_num_proc : dict [int , list [str ]] = defaultdict (list )
3079- if cache_file_name is None :
3080- return cache_files_by_num_proc
3082+ existing_cache_file_map : dict [int , list [str ]] = defaultdict (list )
3083+ if cache_file_name is not None :
30813084 if os .path .exists (cache_file_name ):
3082- cache_files_by_num_proc [1 ] = [cache_file_name ]
3085+ existing_cache_file_map [1 ] = [cache_file_name ]
30833086
3084- suffix_pattern_parts : list [str ] = []
3085- for literal_text , field_name , format_spec , _ in string_formatter .parse (suffix_template ):
3086- suffix_pattern_parts .append (re .escape (literal_text ))
3087- if field_name :
3088- # TODO: we may want to place restrictions on acceptable format_spec or we will fail to match
3089- # someone's hexidecimal or scientific notation format 😵
3090- suffix_pattern_parts .append (f"(?P<{ field_name } >\\ d+)" )
3091- suffix_pattern = "" .join (suffix_pattern_parts )
3092-
3093- cache_file_prefix , cache_file_ext = os .path .splitext (cache_file_name )
3094- if not cache_file_ext :
3095- raise ValueError (f"Expected cache_file_name to have an extension, but got: { cache_file_name } " )
3096-
3097- cache_file_pattern = "^" + re .escape (cache_file_prefix ) + suffix_pattern + re .escape (cache_file_ext ) + "$"
3098- cache_file_regex = re .compile (cache_file_pattern )
3087+ assert cache_file_prefix is not None and cache_file_ext is not None
3088+ cache_file_with_suffix_pattern = cache_file_prefix + suffix_template + cache_file_ext
30993089
31003090 for cache_file in glob .iglob (f"{ cache_file_prefix } *{ cache_file_ext } " ):
3101- if m := cache_file_regex .match (cache_file ):
3102- file_num_proc = int (m .group ("num_proc" ))
3103- cache_files_by_num_proc [file_num_proc ].append (cache_file )
3104-
3105- return cache_files_by_num_proc
3106-
3107- existing_cache_file_map = get_existing_cache_file_map (cache_file_name )
3091+ suffix_variable_map = string_to_dict (cache_file , cache_file_with_suffix_pattern )
3092+ if suffix_variable_map is not None :
3093+ file_num_proc = int (suffix_variable_map ["num_proc" ])
3094+ existing_cache_file_map [file_num_proc ].append (cache_file )
31083095
31093096 num_shards = num_proc or 1
31103097 if existing_cache_file_map :
@@ -3122,7 +3109,7 @@ def select_existing_cache_files(mapped_num_proc: int) -> tuple[float, ...]:
31223109
31233110 num_shards = min (existing_cache_file_map , key = select_existing_cache_files )
31243111
3125- existing_cache_files = existing_cache_file_map . get ( num_shards , [])
3112+ existing_cache_files = existing_cache_file_map [ num_shards ]
31263113
31273114 def format_cache_file_name (
31283115 cache_file_name : Optional [str ],
@@ -3131,9 +3118,7 @@ def format_cache_file_name(
31313118 if not cache_file_name :
31323119 return cache_file_name
31333120
3134- cache_file_prefix , cache_file_ext = os .path .splitext (cache_file_name )
3135- if not cache_file_ext :
3136- raise ValueError (f"Expected cache_file_name to have an extension, but got: { cache_file_name } " )
3121+ assert cache_file_prefix is not None and cache_file_ext is not None
31373122
31383123 if isinstance (rank , int ):
31393124 cache_file_name = (
@@ -5835,7 +5820,7 @@ def push_to_hub(
58355820 @transmit_format
58365821 @fingerprint_transform (inplace = False )
58375822 def add_column (
5838- self , name : str , column : Union [list , np .array ], new_fingerprint : str , feature : Optional [FeatureType ] = None
5823+ self , name : str , column : Union [list , np .ndarray ], new_fingerprint : str , feature : Optional [FeatureType ] = None
58395824 ):
58405825 """Add column to Dataset.
58415826
0 commit comments