Skip to content

Commit 7a21e75

Browse files
committed
minor
1 parent 9c716d5 commit 7a21e75

File tree

2 files changed

+15
-19
lines changed

2 files changed

+15
-19
lines changed

src/nlp/arrow_dataset.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,28 +1303,17 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
13031303

13041304
try:
13051305
# Loop over single examples or batches and write to buffer/file if examples are to be updated
1306+
pbar_iterable = self if not batched else range(0, len(self), batch_size)
1307+
pbar_desc = "#" + str(rank) if rank is not None else None
1308+
pbar = tqdm(pbar_iterable, disable=not_verbose, position=rank, unit="ba", desc=pbar_desc)
13061309
if not batched:
1307-
for i, example in enumerate(
1308-
tqdm(
1309-
self,
1310-
disable=not_verbose,
1311-
position=rank,
1312-
unit="ex",
1313-
desc="#" + str(rank) if rank is not None else None,
1314-
)
1315-
):
1310+
for i, example in enumerate(pbar):
13161311
example = apply_function_on_filtered_inputs(example, i)
13171312
if update_data:
13181313
example = cast_to_python_objects(example)
13191314
writer.write(example)
13201315
else:
1321-
for i in tqdm(
1322-
range(0, len(self), batch_size),
1323-
disable=not_verbose,
1324-
position=rank,
1325-
unit="ba",
1326-
desc="#" + str(rank) if rank is not None else None,
1327-
):
1316+
for i in pbar:
13281317
if drop_last_batch and i + batch_size > self.num_rows:
13291318
continue
13301319
batch = self[i : i + batch_size]
@@ -1561,12 +1550,16 @@ def select(
15611550
if keep_in_memory or indices_cache_file_name is None:
15621551
buf_writer = pa.BufferOutputStream()
15631552
tmp_file = None
1564-
writer = ArrowWriter(stream=buf_writer, writer_batch_size=writer_batch_size, fingerprint=new_fingerprint)
1553+
writer = ArrowWriter(
1554+
stream=buf_writer, writer_batch_size=writer_batch_size, fingerprint=new_fingerprint, unit="indices"
1555+
)
15651556
else:
15661557
buf_writer = None
15671558
logger.info("Caching indices mapping at %s", indices_cache_file_name)
15681559
tmp_file = tempfile.NamedTemporaryFile("wb", dir=os.path.dirname(indices_cache_file_name), delete=False)
1569-
writer = ArrowWriter(path=tmp_file.name, writer_batch_size=writer_batch_size, fingerprint=new_fingerprint)
1560+
writer = ArrowWriter(
1561+
path=tmp_file.name, writer_batch_size=writer_batch_size, fingerprint=new_fingerprint, unit="indices"
1562+
)
15701563

15711564
indices_array = pa.array(indices, type=pa.uint64())
15721565
# Check if we need to convert indices

src/nlp/arrow_writer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
134134
disable_nullable: bool = False,
135135
update_features: bool = False,
136136
with_metadata: bool = True,
137+
unit: str = "examples",
137138
):
138139
if path is None and stream is None:
139140
raise ValueError("At least one of path and stream must be provided.")
@@ -161,6 +162,7 @@ def __init__(
161162
self.writer_batch_size = writer_batch_size or DEFAULT_MAX_BATCH_SIZE
162163
self.update_features = update_features
163164
self.with_metadata = with_metadata
165+
self.unit = unit
164166

165167
self._num_examples = 0
166168
self._num_bytes = 0
@@ -290,8 +292,9 @@ def finalize(self, close_stream=True):
290292
if close_stream:
291293
self.stream.close()
292294
logger.info(
293-
"Done writing %s examples in %s bytes %s.",
295+
"Done writing %s %s in %s bytes %s.",
294296
self._num_examples,
297+
self.unit,
295298
self._num_bytes,
296299
self._path if self._path else "",
297300
)

0 commit comments

Comments (0)