2 changes: 1 addition & 1 deletion src/datatrove/io.py
@@ -21,7 +21,7 @@ def __init__(self, fs, mode: str = "wt", compression: str | None = "infer"):

def get_file(self, filename):
"""
-Opens file `filename` if it hasn't been opened yet. Otherwise just returns it from the file cache
+Opens file `filename` if it hasn't been opened yet. Otherwise, just returns it from the file cache
Args:
filename: name of the file to open/get if previously opened

1 change: 1 addition & 0 deletions src/datatrove/pipeline/writers/__init__.py
@@ -1 +1,2 @@
from .jsonl import JsonlWriter
+from .parquet import ParquetWriter
11 changes: 6 additions & 5 deletions src/datatrove/pipeline/writers/disk_base.py
@@ -1,7 +1,7 @@
import dataclasses
from abc import ABC, abstractmethod
from string import Template
-from typing import Callable
+from typing import IO, Callable

from datatrove.data import Document, DocumentsPipeline
from datatrove.io import DataFolderLike, get_datafolder
@@ -31,6 +31,7 @@ def __init__(
output_filename: str = None,
compression: str | None = "infer",
adapter: Callable = None,
+mode: str = "wt",
):
"""
Base writer block to save data to disk.
@@ -47,7 +48,7 @@ def __init__(
if self.compression == "gzip" and not output_filename.endswith(".gz"):
output_filename += ".gz"
self.output_filename = Template(output_filename)
self.output_mg = self.output_folder.get_output_file_manager(mode="wt", compression=compression)
self.output_mg = self.output_folder.get_output_file_manager(mode=mode, compression=compression)
self.adapter = adapter if adapter else _default_adapter

def __enter__(self):
@@ -81,13 +82,13 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs
)

@abstractmethod
-def _write(self, document: dict, file_handler):
+def _write(self, document: dict, file_handler: IO, filename: str):
"""
Main method that subclasses should implement. Receives an adapted (after applying self.adapter) dictionary with data to save to `file_handler`
Args:
document: dictionary with the data to save
file_handler: file_handler where it should be saved

+filename: to use as a key for writer helpers and other data
Returns:

"""
@@ -105,7 +106,7 @@ def write(self, document: Document, rank: int = 0, **kwargs):

"""
output_filename = self._get_output_filename(document, rank, **kwargs)
-self._write(self.adapter(document), self.output_mg.get_file(output_filename))
+self._write(self.adapter(document), self.output_mg.get_file(output_filename), output_filename)
self.stat_update(self._get_output_filename(document, "XXXXX", **kwargs))
self.stat_update(StatHints.total)
self.update_doc_stats(document)
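The widened `_write(document, file_handler, filename)` contract and the new `mode` argument let subclasses keep per-file helper state and request binary handles. A minimal sketch of what a custom writer could look like on top of this interface (the TsvWriter name, header handling, and tab-separated format are illustrative assumptions, not part of this PR):

# Hypothetical example, not from this PR: a writer with per-file state keyed by `filename`.
from typing import IO

from datatrove.io import DataFolderLike
from datatrove.pipeline.writers.disk_base import DiskWriter


class TsvWriter(DiskWriter):
    default_output_filename: str = "${rank}.tsv"
    name = "TSV writer (example)"

    def __init__(self, output_folder: DataFolderLike):
        # text mode is the default; passed explicitly here to show the new parameter
        super().__init__(output_folder, output_filename=None, compression="infer", adapter=None, mode="wt")
        self._wrote_header = set()

    def _write(self, document: dict, file_handler: IO, filename: str):
        # `filename` keys the per-file helper state: write a header line once per output file
        if filename not in self._wrote_header:
            file_handler.write("\t".join(document.keys()) + "\n")
            self._wrote_header.add(filename)
        file_handler.write("\t".join(str(v) for v in document.values()) + "\n")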
4 changes: 2 additions & 2 deletions src/datatrove/pipeline/writers/jsonl.py
@@ -18,5 +18,5 @@ def __init__(
):
super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter)

-def _write(self, document: dict, file: IO):
-    file.write(json.dumps(document, ensure_ascii=False) + "\n")
+def _write(self, document: dict, file_handler: IO, _filename: str):
+    file_handler.write(json.dumps(document, ensure_ascii=False) + "\n")
55 changes: 55 additions & 0 deletions src/datatrove/pipeline/writers/parquet.py
@@ -0,0 +1,55 @@
from collections import defaultdict
from typing import IO, Callable

from datatrove.io import DataFolderLike
from datatrove.pipeline.writers.disk_base import DiskWriter


class ParquetWriter(DiskWriter):
    default_output_filename: str = "${rank}.parquet"
    name = "📒 Parquet"
    _requires_dependencies = ["pyarrow"]

    def __init__(
        self,
        output_folder: DataFolderLike,
        output_filename: str = None,
        compression: str | None = None,
        adapter: Callable = None,
        batch_size: int = 1000,
    ):
        super().__init__(output_folder, output_filename, compression, adapter, mode="wb")
        self._writers = {}
        self._batches = defaultdict(list)
        self.batch_size = batch_size

    def _write_batch(self, filename):
        if not self._batches[filename]:
            return
        import pyarrow as pa

        # prepare batch
        batch = pa.RecordBatch.from_pylist(self._batches.pop(filename))
        # write batch
        self._writers[filename].write_batch(batch)

    def _write(self, document: dict, file_handler: IO, filename: str):
        import pyarrow as pa
        import pyarrow.parquet as pq

        if filename not in self._writers:
            self._writers[filename] = pq.ParquetWriter(
                file_handler, schema=pa.RecordBatch.from_pylist([document]).schema
            )
        self._batches[filename].append(document)
        if len(self._batches[filename]) == self.batch_size:
            self._write_batch(filename)

    def close(self):
        for filename in list(self._batches.keys()):
            self._write_batch(filename)
        for writer in self._writers.values():
            writer.close()
        self._batches.clear()
        self._writers.clear()
        super().close()
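For reference, a standalone sketch of the pyarrow pattern the writer relies on: the Parquet schema is inferred from the first document, rows are buffered, and each full buffer is flushed as one RecordBatch (so each flush becomes one row group). The example.parquet path and toy rows are assumptions for illustration:

# Not part of the PR: the underlying pyarrow calls, shown outside datatrove.
import pyarrow as pa
import pyarrow.parquet as pq

rows = [{"text": "hello", "id": "0"}, {"text": "more text", "id": "1"}]

schema = pa.RecordBatch.from_pylist(rows[:1]).schema  # schema inferred from the first record
with pq.ParquetWriter("example.parquet", schema=schema) as writer:
    writer.write_batch(pa.RecordBatch.from_pylist(rows))  # one buffered batch -> one row group

print(pq.read_table("example.parquet").to_pylist())  # round-trip check

Because the schema is fixed from the first document written to each file, later documents that adapt to a different set of columns would not match it, so adapted documents should stay homogeneous per output file.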
33 changes: 33 additions & 0 deletions tests/pipeline/test_parquet_writer.py
@@ -0,0 +1,33 @@
import shutil
import tempfile
import unittest

from datatrove.data import Document
from datatrove.pipeline.readers.parquet import ParquetReader
from datatrove.pipeline.writers.parquet import ParquetWriter

from ..utils import require_pyarrow


@require_pyarrow
class TestParquetWriter(unittest.TestCase):
    def setUp(self):
        # Create a temporary directory
        self.tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp_dir)

    def test_write(self):
        data = [
            Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"})
            for i, text in enumerate(["hello", "text2", "more text"])
        ]
        with ParquetWriter(output_folder=self.tmp_dir, batch_size=2) as w:
            for doc in data:
                w.write(doc)
        reader = ParquetReader(self.tmp_dir)
        c = 0
        for read_doc, original in zip(reader(), data):
            read_doc.metadata.pop("file_path", None)
            assert read_doc == original
            c += 1
        assert c == len(data)
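Beyond the unit test, the new writer plugs into a pipeline like any other DiskWriter. A usage sketch (the input path, reader choice, and task count are assumptions, not taken from this PR):

# Illustrative only: convert jsonl shards to parquet with the new writer.
from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.writers import ParquetWriter

LocalPipelineExecutor(
    pipeline=[
        JsonlReader("data/jsonl_input"),  # assumed input folder
        ParquetWriter(output_folder="data/parquet_output", batch_size=1000),
    ],
    tasks=4,
).run()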