2 changes: 1 addition & 1 deletion src/datatrove/io.py
@@ -21,7 +21,7 @@ def __init__(self, fs, mode: str = "wt", compression: str | None = "infer"):

def get_file(self, filename):
"""
Opens file `filename` if it hasn't been opened yet. Otherwise just returns it from the file cache
Opens file `filename` if it hasn't been opened yet. Otherwise, just returns it from the file cache
Args:
filename: name of the file to open/get if previously opened

1 change: 1 addition & 0 deletions src/datatrove/pipeline/writers/__init__.py
@@ -1 +1,2 @@
from .jsonl import JsonlWriter
from .parquet import ParquetWriter
11 changes: 6 additions & 5 deletions src/datatrove/pipeline/writers/disk_base.py
@@ -1,7 +1,7 @@
import dataclasses
from abc import ABC, abstractmethod
from string import Template
from typing import Callable
from typing import IO, Callable

from datatrove.data import Document, DocumentsPipeline
from datatrove.io import DataFolderLike, get_datafolder
@@ -31,6 +31,7 @@ def __init__(
output_filename: str = None,
compression: str | None = "infer",
adapter: Callable = None,
mode: str = "wt",
):
"""
Base writer block to save data to disk.
@@ -47,7 +48,7 @@ def __init__(
if self.compression == "gzip" and not output_filename.endswith(".gz"):
output_filename += ".gz"
self.output_filename = Template(output_filename)
self.output_mg = self.output_folder.get_output_file_manager(mode="wt", compression=compression)
self.output_mg = self.output_folder.get_output_file_manager(mode=mode, compression=compression)
self.adapter = adapter if adapter else _default_adapter

def __enter__(self):
@@ -81,13 +82,13 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs
)

@abstractmethod
def _write(self, document: dict, file_handler):
def _write(self, document: dict, file_handler: IO, filename: str):
"""
Main method that subclasses should implement. Receives an adapted (after applying self.adapter) dictionary with data to save to `file_handler`
Args:
document: dictionary with the data to save
file_handler: file_handler where it should be saved

filename: to use as a key for writer helpers and other data
Returns:

"""
@@ -105,7 +106,7 @@ def write(self, document: Document, rank: int = 0, **kwargs):

"""
output_filename = self._get_output_filename(document, rank, **kwargs)
self._write(self.adapter(document), self.output_mg.get_file(output_filename))
self._write(self.adapter(document), self.output_mg.get_file(output_filename), output_filename)
self.stat_update(self._get_output_filename(document, "XXXXX", **kwargs))
self.stat_update(StatHints.total)
self.update_doc_stats(document)
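With the new signature, `_write` also receives the output filename, which subclasses can use as a key for per-file state (as the ParquetWriter below does for its per-file writers and batches). A minimal sketch of a hypothetical third-party writer, not part of this PR, that uses the filename to emit a CSV header once per output file:

import csv
from typing import IO, Callable

from datatrove.io import DataFolderLike
from datatrove.pipeline.writers.disk_base import DiskWriter


class CsvWriter(DiskWriter):
    # hypothetical example, not part of datatrove
    default_output_filename: str = "${rank}.csv"
    name = "CSV (example)"

    def __init__(self, output_folder: DataFolderLike, output_filename: str = None, adapter: Callable = None):
        super().__init__(output_folder, output_filename=output_filename, adapter=adapter, mode="wt")
        self._header_written = set()

    def _write(self, document: dict, file_handler: IO, filename: str):
        # `filename` keys the per-file state: write the header only once per output file
        writer = csv.DictWriter(file_handler, fieldnames=list(document.keys()))
        if filename not in self._header_written:
            writer.writeheader()
            self._header_written.add(filename)
        # non-string values (e.g. the metadata dict) are written via their str() form
        writer.writerow(document)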
4 changes: 2 additions & 2 deletions src/datatrove/pipeline/writers/jsonl.py
@@ -18,5 +18,5 @@ def __init__(
):
super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter)

def _write(self, document: dict, file: IO):
file.write(json.dumps(document, ensure_ascii=False) + "\n")
def _write(self, document: dict, file_handler: IO, _filename: str):
file_handler.write(json.dumps(document, ensure_ascii=False) + "\n")
56 changes: 56 additions & 0 deletions src/datatrove/pipeline/writers/parquet.py
@@ -0,0 +1,56 @@
from collections import defaultdict
from typing import IO, Callable

from datatrove.io import DataFolderLike
from datatrove.pipeline.writers.disk_base import DiskWriter


class ParquetWriter(DiskWriter):
default_output_filename: str = "${rank}.parquet"
name = "📒 Parquet"
_requires_dependencies = ["pyarrow"]

def __init__(
self,
output_folder: DataFolderLike,
output_filename: str = None,
compression: str | None = None,
adapter: Callable = None,
batch_size: int = 1000,
):
super().__init__(output_folder, output_filename, compression, adapter, mode="wb")
self._writers = {}
self._batches = defaultdict(list)
self.batch_size = batch_size

def _write_batch(self, filename):
if not self._batches[filename]:
return
import pyarrow as pa

names = list(self._batches[filename][0].keys())
# prepare batch
batch = pa.record_batch(list(zip(*[d.values() for d in self._batches.pop(filename)])), names=names)
# write batch
self._writers[filename].write_batch(batch)

def _write(self, document: dict, file_handler: IO, filename: str):
import pyarrow as pa
import pyarrow.parquet as pq

if filename not in self._writers:
self._writers[filename] = pq.ParquetWriter(
file_handler, schema=pa.table({name: [val] for name, val in document.items()}).schema
)
Contributor:
The Document's attributes have fixed types, so I wonder if it would make more sense to pass pa.schema({"text": pa.string(), "id": pa.string(), "media": pa.struct({"type": pa.int32(), "url": pa.string(), "alt": pa.string(), "local_path": pa.string()}), "metadata": pa.string()}) for the schema.

Parquet still doesn't support unions (see apache/parquet-format#44), so we would have to work around this limitation by turning the metadata value into a string using json.dumps(metadata). Then, to make the ParquetReader compatible with this format, we would also have to add metadata to the schema (pa.schema(fields, metadata=...)), which the reader would check and perform deserialization (using json.loads) on the other side if needed.

But the current solution is good enough, so this can also be addressed later.

PS: To be extra strict, the default nullability of non-nullable fields ("text", "id", etc.) in the above schema can be disabled with pa.field(pa_type, nullable=False)

Author (Collaborator):
They used to have fixed types but now we support an adapter so that people can choose their output format (still a dictionary, but they can do whatever they want with the fields)

Regarding unions, does this mean if we have different value types in metadata (let's say strings and floats) then this doesn't work?

Regarding nullability, the problem would also be the custom user formats

Author (Collaborator):
maybe we could also have pa.RecordBatch.from_pylist([document]).schema here instead?

mariosasko (Contributor), Feb 21, 2024:
> They used to have fixed types but now we support an adapter so that people can choose their output format (still a dictionary, but they can do whatever they want with the fields)

We could only use the fixed schema if adapter is not specified.

> Regarding unions, does this mean if we have different value types in metadata (let's say strings and floats) then this doesn't work?

JSON supports these types, so it will work.

> maybe we could also have pa.RecordBatch.from_pylist([document]).schema here instead?

Yes, this would be cleaner indeed.

Author (Collaborator):
I see. I think maybe for now we will keep the current format, so that even when people upload to the Hub directly there isn't a big JSON field.

self._batches[filename].append(document)
if len(self._batches[filename]) == self.batch_size:
self._write_batch(filename)

def close(self):
for filename in list(self._batches.keys()):
self._write_batch(filename)
for writer in self._writers.values():
writer.close()
self._batches.clear()
self._writers.clear()
super().close()
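For reference, a rough sketch of the fixed-schema alternative discussed in the review thread above; it was not merged, and the names here (DOCUMENT_SCHEMA, to_row) are illustrative only. Field names and types mirror the schema proposed in the thread, metadata is stored as a JSON string because Parquet has no union type, and the required fields are marked non-nullable:

import json

import pyarrow as pa

# Hypothetical fixed schema mirroring the one proposed in the review thread.
# `metadata` is serialized to a JSON string since Parquet cannot store union types.
DOCUMENT_SCHEMA = pa.schema(
    [
        pa.field("text", pa.string(), nullable=False),
        pa.field("id", pa.string(), nullable=False),
        pa.field(
            "media",
            pa.struct(
                [
                    pa.field("type", pa.int32()),
                    pa.field("url", pa.string()),
                    pa.field("alt", pa.string()),
                    pa.field("local_path", pa.string()),
                ]
            ),
        ),
        pa.field("metadata", pa.string()),
    ]
)


def to_row(document: dict) -> dict:
    # adapt a document dict to the fixed schema by JSON-encoding its free-form metadata
    return {**document, "metadata": json.dumps(document.get("metadata", {}))}

The per-file schema inference kept in the PR (or pa.RecordBatch.from_pylist([document]).schema, as suggested in the thread) avoids committing to these fixed types when a custom adapter changes the output fields.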
33 changes: 33 additions & 0 deletions tests/pipeline/test_parquet_writer.py
@@ -0,0 +1,33 @@
import shutil
import tempfile
import unittest

from datatrove.data import Document
from datatrove.pipeline.readers.parquet import ParquetReader
from datatrove.pipeline.writers.parquet import ParquetWriter

from ..utils import require_pyarrow


@require_pyarrow
class TestParquetWriter(unittest.TestCase):
def setUp(self):
# Create a temporary directory
self.tmp_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.tmp_dir)

def test_write(self):
data = [
Document(text=text, id=str(i), metadata={"somedata": 2 * i})
for i, text in enumerate(["hello", "text2", "more text"])
]
with ParquetWriter(output_folder=self.tmp_dir, batch_size=2) as w:
for doc in data:
w.write(doc)
reader = ParquetReader(self.tmp_dir)
c = 0
for read_doc, original in zip(reader(), data):
read_doc.metadata.pop("file_path", None)
assert read_doc == original
c += 1
assert c == len(data)
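Beyond the round-trip test above, a minimal usage sketch of the new writer inside a datatrove pipeline; the paths and task count are placeholders, and the JsonlReader/LocalPipelineExecutor usage is assumed from the rest of the library rather than from this PR:

from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.writers import ParquetWriter

# read jsonl documents and re-write them as parquet, flushing a batch every 1000 documents per file
pipeline = [
    JsonlReader("data/jsonl/"),
    ParquetWriter("data/parquet/", batch_size=1000),
]
LocalPipelineExecutor(pipeline=pipeline, tasks=4).run()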