Skip to content

Commit 94ccd1b

Browse files
authored
Write pdf in map (#7487)
* write pdf in map * fix * support lists of images or PDFs in map() output
1 parent 18aa0cb commit 94ccd1b

File tree

4 files changed

+54
-26
lines changed

4 files changed

+54
-26
lines changed

src/datasets/arrow_writer.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from fsspec.core import url_to_fs
2626

2727
from . import config
28-
from .features import Audio, Features, Image, Value, Video
28+
from .features import Audio, Features, Image, Pdf, Value, Video
2929
from .features.features import (
3030
FeatureType,
3131
_ArrayXDExtensionType,
@@ -42,7 +42,7 @@
4242
from .keyhash import DuplicatedKeysError, KeyHasher
4343
from .table import array_cast, cast_array_to_feature, embed_table_storage, table_cast
4444
from .utils import logging
45-
from .utils.py_utils import asdict, first_non_null_value
45+
from .utils.py_utils import asdict, first_non_null_non_empty_value
4646

4747

4848
logger = logging.get_logger(__name__)
@@ -189,9 +189,23 @@ def _infer_custom_type_and_encode(data: Iterable) -> tuple[Iterable, Optional[Fe
189189
if config.PIL_AVAILABLE and "PIL" in sys.modules:
190190
import PIL.Image
191191

192-
non_null_idx, non_null_value = first_non_null_value(data)
192+
non_null_idx, non_null_value = first_non_null_non_empty_value(data)
193193
if isinstance(non_null_value, PIL.Image.Image):
194194
return [Image().encode_example(value) if value is not None else None for value in data], Image()
195+
if isinstance(non_null_value, list) and isinstance(non_null_value[0], PIL.Image.Image):
196+
return [[Image().encode_example(x) for x in value] if value is not None else None for value in data], [
197+
Image()
198+
]
199+
if config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules:
200+
import pdfplumber
201+
202+
non_null_idx, non_null_value = first_non_null_non_empty_value(data)
203+
if isinstance(non_null_value, pdfplumber.pdf.PDF):
204+
return [Pdf().encode_example(value) if value is not None else None for value in data], Pdf()
205+
if isinstance(non_null_value, list) and isinstance(non_null_value[0], pdfplumber.pdf.PDF):
206+
return [[Pdf().encode_example(x) for x in value] if value is not None else None for value in data], [
207+
Pdf()
208+
]
195209
return data, None
196210

197211
def __arrow_array__(self, type: Optional[pa.DataType] = None):
@@ -222,7 +236,7 @@ def __arrow_array__(self, type: Optional[pa.DataType] = None):
222236
# efficient np array to pyarrow array
223237
if isinstance(data, np.ndarray):
224238
out = numpy_to_pyarrow_listarray(data)
225-
elif isinstance(data, list) and data and isinstance(first_non_null_value(data)[1], np.ndarray):
239+
elif isinstance(data, list) and data and isinstance(first_non_null_non_empty_value(data)[1], np.ndarray):
226240
out = list_of_np_array_to_pyarrow_listarray(data)
227241
else:
228242
trying_cast_to_python_objects = True

src/datasets/features/features.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from ..utils.py_utils import asdict, first_non_null_value, zip_dict
4343
from .audio import Audio
4444
from .image import Image, encode_pil_image
45-
from .pdf import Pdf
45+
from .pdf import Pdf, encode_pdfplumber_pdf
4646
from .translation import Translation, TranslationVariableLanguages
4747
from .video import Video
4848

@@ -299,6 +299,9 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
299299
if config.PIL_AVAILABLE and "PIL" in sys.modules:
300300
import PIL.Image
301301

302+
if config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules:
303+
import pdfplumber
304+
302305
if isinstance(obj, np.ndarray):
303306
if obj.ndim == 0:
304307
return obj[()], True
@@ -367,6 +370,8 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
367370
)
368371
elif config.PIL_AVAILABLE and "PIL" in sys.modules and isinstance(obj, PIL.Image.Image):
369372
return encode_pil_image(obj), True
373+
elif config.PDFPLUMBER_AVAILABLE and "pdfplumber" in sys.modules and isinstance(obj, pdfplumber.pdf.PDF):
374+
return encode_pdfplumber_pdf(obj), True
370375
elif isinstance(obj, pd.Series):
371376
return (
372377
_cast_to_python_objects(

src/datasets/features/pdf.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) -
9696
return {"path": None, "bytes": value}
9797
elif pdfplumber is not None and isinstance(value, pdfplumber.pdf.PDF):
9898
# convert the pdfplumber.pdf.PDF to bytes
99-
return self.encode_pdfplumber_pdf(value)
99+
return encode_pdfplumber_pdf(value)
100100
elif value.get("path") is not None and os.path.isfile(value["path"]):
101101
# we set "bytes": None to not duplicate the data if they're already available locally
102102
return {"bytes": None, "path": value.get("path")}
@@ -108,26 +108,6 @@ def encode_example(self, value: Union[str, bytes, dict, "pdfplumber.pdf.PDF"]) -
108108
f"A pdf sample should have one of 'path' or 'bytes' but they are missing or None in {value}."
109109
)
110110

111-
def encode_pdfplumber_pdf(pdf: "pdfplumber.pdf.PDF") -> dict:
112-
"""
113-
Encode a pdfplumber.pdf.PDF object into a dictionary.
114-
115-
If the PDF has an associated file path, returns the path. Otherwise, serializes
116-
the PDF content into bytes.
117-
118-
Args:
119-
pdf (pdfplumber.pdf.PDF): A pdfplumber PDF object.
120-
121-
Returns:
122-
dict: A dictionary with "path" or "bytes" field.
123-
"""
124-
if hasattr(pdf, "stream") and hasattr(pdf.stream, "name") and pdf.stream.name:
125-
# Return the path if the PDF has an associated file path
126-
return {"path": pdf.stream.name, "bytes": None}
127-
else:
128-
# Convert the PDF to bytes if no path is available
129-
return {"path": None, "bytes": pdf_to_bytes(pdf)}
130-
131111
def decode_example(self, value: dict, token_per_repo_id=None) -> "pdfplumber.pdf.PDF":
132112
"""Decode example pdf file into pdf data.
133113
@@ -235,3 +215,24 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
235215
path_array = pa.array([None] * len(storage), type=pa.string())
236216
storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
237217
return array_cast(storage, self.pa_type)
218+
219+
220+
def encode_pdfplumber_pdf(pdf: "pdfplumber.pdf.PDF") -> dict:
221+
"""
222+
Encode a pdfplumber.pdf.PDF object into a dictionary.
223+
224+
If the PDF has an associated file path, returns the path. Otherwise, serializes
225+
the PDF content into bytes.
226+
227+
Args:
228+
pdf (pdfplumber.pdf.PDF): A pdfplumber PDF object.
229+
230+
Returns:
231+
dict: A dictionary with "path" or "bytes" field.
232+
"""
233+
if hasattr(pdf, "stream") and hasattr(pdf.stream, "name") and pdf.stream.name:
234+
# Return the path if the PDF has an associated file path
235+
return {"path": pdf.stream.name, "bytes": None}
236+
else:
237+
# Convert the PDF to bytes if no path is available
238+
return {"path": None, "bytes": pdf_to_bytes(pdf)}

src/datasets/utils/py_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,14 @@ def first_non_null_value(iterable):
321321
return -1, None
322322

323323

324+
def first_non_null_non_empty_value(iterable):
325+
"""Return the index and the value of the first non-null non-empty value in the iterable. If all values are None or empty, return -1 as index."""
326+
for i, value in enumerate(iterable):
327+
if value is not None and not (isinstance(value, (dict, list)) and len(value) == 0):
328+
return i, value
329+
return -1, None
330+
331+
324332
def zip_dict(*dicts):
325333
"""Iterate over items of dictionaries grouped by their keys."""
326334
for key in unique_values(itertools.chain(*dicts)): # set merge all keys

0 commit comments

Comments
 (0)