Adds parquet writer #103
Changes from 2 commits
The writers package `__init__.py` (`@@ -1 +1,2 @@`):

```diff
 from .jsonl import JsonlWriter
+from .parquet import ParquetWriter
```
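With this re-export, the new writer can be imported from the package root alongside the existing JSONL writer (illustrative usage, not part of the diff):

```python
from datatrove.pipeline.writers import JsonlWriter, ParquetWriter
```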
New module `parquet.py` in the writers package (`@@ -0,0 +1,56 @@`):

```python
from collections import defaultdict
from typing import IO, Callable

from datatrove.io import DataFolderLike
from datatrove.pipeline.writers.disk_base import DiskWriter


class ParquetWriter(DiskWriter):
    default_output_filename: str = "${rank}.parquet"
    name = "📒 Parquet"
    _requires_dependencies = ["pyarrow"]

    def __init__(
        self,
        output_folder: DataFolderLike,
        output_filename: str = None,
        compression: str | None = None,
        adapter: Callable = None,
        batch_size: int = 1000,
    ):
        super().__init__(output_folder, output_filename, compression, adapter, mode="wb")
        self._writers = {}
        self._batches = defaultdict(list)
        self.batch_size = batch_size

    def _write_batch(self, filename):
        if not self._batches[filename]:
            return
        import pyarrow as pa

        names = list(self._batches[filename][0].keys())
        # prepare batch
        batch = pa.record_batch(list(zip(*[d.values() for d in self._batches.pop(filename)])), names=names)
        # write batch
        self._writers[filename].write_batch(batch)

    def _write(self, document: dict, file_handler: IO, filename: str):
        import pyarrow as pa
        import pyarrow.parquet as pq

        if filename not in self._writers:
            self._writers[filename] = pq.ParquetWriter(
                file_handler, schema=pa.table({name: [val] for name, val in document.items()}).schema
            )
        self._batches[filename].append(document)
        if len(self._batches[filename]) == self.batch_size:
            self._write_batch(filename)

    def close(self):
        for filename in list(self._batches.keys()):
            self._write_batch(filename)
        for writer in self._writers.values():
            writer.close()
        self._batches.clear()
        self._writers.clear()
        super().close()
```

Review discussion on the schema inference in `_write`:

- Parquet still doesn't support unions (see apache/parquet-format#44), so we would have to work around this limitation by turning the […]. But the current solution is good enough, so this can also be addressed later. PS: to be extra strict, the default nullability of the non-nullable fields ("text", "id", etc.) in the above schema can be disabled with […].
- They used to have fixed types, but now we support an […]. Regarding unions, does this mean if we have different value types in […]? Regarding nullability, the problem would also be the custom user formats.
- Maybe we could also have […].
- We could only use the fixed schema if […]. JSON supports these types, so it will work. Yes, this would be cleaner indeed.
- I see. I think maybe for now we will keep the current format, so that even when people upload to the Hub directly and so on there isn't a big JSON field.
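The thread above touches on two possible refinements: marking the always-present fields as non-nullable, and working around Parquet's lack of union types by serializing metadata into a single string column. A minimal pyarrow sketch of both ideas, purely illustrative and not part of the diff (the field names and the JSON-string metadata column are assumptions):

```python
import pyarrow as pa

# What the writer's inference produces for a first document shaped like a
# datatrove Document dict: strings for text/id, a struct for the metadata dict.
first_doc = {"text": "hello", "id": "0", "metadata": {"somedata": 0}}
inferred = pa.table({name: [val] for name, val in first_doc.items()}).schema
# inferred is roughly: text: string, id: string, metadata: struct<somedata: int64>

# A stricter, hand-written alternative: non-nullable fields for columns that are
# always present, and metadata stored as a JSON string to sidestep unions.
explicit = pa.schema(
    [
        pa.field("text", pa.string(), nullable=False),
        pa.field("id", pa.string(), nullable=False),
        pa.field("metadata", pa.string()),  # e.g. json.dumps(document["metadata"])
    ]
)
```

The trade-off noted at the end of the thread is that a JSON-encoded metadata column is less convenient to query when the files are uploaded to the Hub as-is, which is why the inferred schema was kept for now.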
New test (`@@ -0,0 +1,33 @@`):

```python
import shutil
import tempfile
import unittest

from datatrove.data import Document
from datatrove.pipeline.readers.parquet import ParquetReader
from datatrove.pipeline.writers.parquet import ParquetWriter

from ..utils import require_pyarrow


@require_pyarrow
class TestParquetWriter(unittest.TestCase):
    def setUp(self):
        # Create a temporary directory
        self.tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp_dir)

    def test_write(self):
        data = [
            Document(text=text, id=str(i), metadata={"somedata": 2 * i})
            for i, text in enumerate(["hello", "text2", "more text"])
        ]
        with ParquetWriter(output_folder=self.tmp_dir, batch_size=2) as w:
            for doc in data:
                w.write(doc)
        reader = ParquetReader(self.tmp_dir)
        c = 0
        for read_doc, original in zip(reader(), data):
            read_doc.metadata.pop("file_path", None)
            assert read_doc == original
            c += 1
        assert c == len(data)
```
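Beyond the round-trip through `ParquetReader`, a written file can also be sanity-checked with pyarrow directly; a hypothetical snippet (the path is an assumption — the actual filename comes from the `${rank}.parquet` template):

```python
import pyarrow.parquet as pq

# Not part of the PR: open one of the files the writer produced and inspect
# the schema that was inferred from the first document written to it.
table = pq.read_table("my_output_folder/00000.parquet")  # adjust to the real filename
print(table.schema)
print(table.num_rows)
```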