@@ -1303,28 +1303,17 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
13031303
13041304 try :
13051305 # Loop over single examples or batches and write to buffer/file if examples are to be updated
1306+ pbar_iterable = self if not batched else range (0 , len (self ), batch_size )
1307+ pbar_desc = "#" + str (rank ) if rank is not None else None
1308+ pbar = tqdm (pbar_iterable , disable = not_verbose , position = rank , unit = "ba" , desc = pbar_desc )
13061309 if not batched :
1307- for i , example in enumerate (
1308- tqdm (
1309- self ,
1310- disable = not_verbose ,
1311- position = rank ,
1312- unit = "ex" ,
1313- desc = "#" + str (rank ) if rank is not None else None ,
1314- )
1315- ):
1310+ for i , example in enumerate (pbar ):
13161311 example = apply_function_on_filtered_inputs (example , i )
13171312 if update_data :
13181313 example = cast_to_python_objects (example )
13191314 writer .write (example )
13201315 else :
1321- for i in tqdm (
1322- range (0 , len (self ), batch_size ),
1323- disable = not_verbose ,
1324- position = rank ,
1325- unit = "ba" ,
1326- desc = "#" + str (rank ) if rank is not None else None ,
1327- ):
1316+ for i in pbar :
13281317 if drop_last_batch and i + batch_size > self .num_rows :
13291318 continue
13301319 batch = self [i : i + batch_size ]
@@ -1561,12 +1550,16 @@ def select(
15611550 if keep_in_memory or indices_cache_file_name is None :
15621551 buf_writer = pa .BufferOutputStream ()
15631552 tmp_file = None
1564- writer = ArrowWriter (stream = buf_writer , writer_batch_size = writer_batch_size , fingerprint = new_fingerprint )
1553+ writer = ArrowWriter (
1554+ stream = buf_writer , writer_batch_size = writer_batch_size , fingerprint = new_fingerprint , unit = "indices"
1555+ )
15651556 else :
15661557 buf_writer = None
15671558 logger .info ("Caching indices mapping at %s" , indices_cache_file_name )
15681559 tmp_file = tempfile .NamedTemporaryFile ("wb" , dir = os .path .dirname (indices_cache_file_name ), delete = False )
1569- writer = ArrowWriter (path = tmp_file .name , writer_batch_size = writer_batch_size , fingerprint = new_fingerprint )
1560+ writer = ArrowWriter (
1561+ path = tmp_file .name , writer_batch_size = writer_batch_size , fingerprint = new_fingerprint , unit = "indices"
1562+ )
15701563
15711564 indices_array = pa .array (indices , type = pa .uint64 ())
15721565 # Check if we need to convert indices
0 commit comments