1
1
from os import PathLike
2
2
from pathlib import Path
3
- from typing import Dict , Any , Optional , Type
3
+ from typing import Dict , Any , Optional , Type , List , Callable
4
4
import base64
5
5
from io import BytesIO
6
6
from PIL import Image
7
7
from torch .utils .data import Dataset
8
8
from pypdf import PdfReader
9
9
from tqdm import tqdm
10
10
from dataclasses import dataclass , fields
11
+ from abc import ABC , abstractmethod
11
12
12
13
from olmocr .data .renderpdf import render_pdf_to_base64png
13
14
@@ -35,71 +36,22 @@ def __post_init__(self):
35
36
raise TypeError ("is_table must be of type bool." )
36
37
if not isinstance (self .is_diagram , bool ):
37
38
raise TypeError ("is_diagram must be of type bool." )
38
-
39
39
40
- class MarkdownPDFDocumentDataset ( Dataset ):
41
- def __init__ ( self , root_dir : str | PathLike , target_longest_image_dim : int , image_transform = None , front_matter_class = None ):
42
- """
43
- Initialize the dataset by finding all markdown files with corresponding PDFs.
44
-
45
- Args :
46
- root_dir: Path to the root folder containing processed markdown and PDF files
47
- target_longest_image_dim: Target dimension for the longest side of the image
48
- image_transform: Optional transform to apply to the PDF images
49
- front_matter_class: Optional dataclass type to validate front matter against
50
- """
51
- self . root_dir = Path ( root_dir )
52
- self . image_transform = image_transform
53
- self . target_longest_image_dim = target_longest_image_dim
40
+
41
class PipelineStep(ABC):
    """Interface implemented by every step of the sample-processing pipeline."""

    @abstractmethod
    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Transform ``sample`` and return the resulting sample dict."""
        ...
48
+
49
+
50
+ class FrontMatterParser ( PipelineStep ):
51
+ """Pipeline step that parses front matter from markdown content."""
52
+
53
+ def __init__ ( self , front_matter_class : Optional [ Type ] = None ):
54
54
self .front_matter_class = front_matter_class
55
- self .samples = []
56
-
57
- # Find all markdown files recursively
58
- print (f"Scanning for markdown files in { self .root_dir } ..." )
59
- md_files = list (self .root_dir .rglob ("*.md" ))
60
-
61
- # Verify each markdown file has a corresponding PDF
62
- valid_count = 0
63
- invalid_pdfs = []
64
-
65
- print (f"Validating { len (md_files )} markdown-PDF pairs..." )
66
- for md_path in tqdm (md_files , desc = "Validating PDFs" ):
67
- # Look for PDF with same stem (filename without extension)
68
- pdf_path = md_path .with_suffix ('.pdf' )
69
-
70
- if pdf_path .exists () or pdf_path .is_symlink ():
71
- # Resolve symlink if it is one
72
- if pdf_path .is_symlink ():
73
- pdf_path = pdf_path .resolve ()
74
-
75
- # Verify the resolved path exists
76
- if pdf_path .exists ():
77
- # Validate PDF - check it loads and has exactly one page
78
- try :
79
- reader = PdfReader (str (pdf_path ))
80
- num_pages = len (reader .pages )
81
-
82
- if num_pages != 1 :
83
- invalid_pdfs .append ((pdf_path , f"Expected 1 page, found { num_pages } " ))
84
- continue
85
-
86
- self .samples .append ({
87
- 'markdown_path' : md_path ,
88
- 'pdf_path' : pdf_path
89
- })
90
- valid_count += 1
91
-
92
- except Exception as e :
93
- invalid_pdfs .append ((pdf_path , f"Failed to load: { str (e )} " ))
94
-
95
- print (f"Found { valid_count } valid markdown-PDF pairs" )
96
-
97
- if invalid_pdfs :
98
- print (f"\n Warning: { len (invalid_pdfs )} invalid PDFs found:" )
99
- for pdf_path , reason in invalid_pdfs [:5 ]: # Show first 5
100
- print (f" - { pdf_path .name } : { reason } " )
101
- if len (invalid_pdfs ) > 5 :
102
- print (f" ... and { len (invalid_pdfs ) - 5 } more" )
103
55
104
56
def _extract_front_matter_and_text (self , markdown_content : str ) -> tuple [str , str ]:
105
57
"""Extract raw front matter string and text from markdown content."""
@@ -165,6 +117,119 @@ def _parse_front_matter(self, front_matter_dict: Dict[str, Any]) -> Any:
165
117
166
118
return self .front_matter_class (** kwargs )
167
119
120
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Parse the front matter of the sample's markdown content.

    Args:
        sample: Sample dict. If ``'markdown_content'`` is missing it is read
            (UTF-8) from ``sample['markdown_path']`` and stored on the sample.
            ``'markdown_path'`` is also used in the error message.

    Returns:
        The same ``sample`` dict, updated in place with ``'text'`` (the text
        returned by ``_extract_front_matter_and_text``) and ``'front_matter'``
        (the result of ``_parse_front_matter``).

    Raises:
        ValueError: If ``_parse_front_matter`` fails. The underlying exception
            is chained as ``__cause__`` so its traceback is preserved.
    """
    # Load the markdown lazily so callers may pre-populate the content.
    if 'markdown_content' not in sample:
        sample['markdown_content'] = sample['markdown_path'].read_text(encoding='utf-8')

    front_matter_str, text = self._extract_front_matter_and_text(sample['markdown_content'])
    front_matter = self._parse_front_matter_string(front_matter_str)

    try:
        parsed_front_matter = self._parse_front_matter(front_matter)
    except Exception as e:
        # Add the file path for context, but keep the original error as the cause.
        raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}") from e

    sample['text'] = text
    sample['front_matter'] = parsed_front_matter
    return sample
141
+
142
+
143
class PDFRenderer(PipelineStep):
    """Pipeline step that renders the sample's PDF into an image."""

    def __init__(self, target_longest_image_dim: int, image_transform: Optional[Callable] = None):
        """Store the render size and the optional post-render transform."""
        self.image_transform = image_transform
        self.target_longest_image_dim = target_longest_image_dim

    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Render page 1 of ``sample['pdf_path']`` and store it under ``'image'``.

        The base64 PNG returned by ``render_pdf_to_base64png`` is decoded and
        opened as a PIL image. When a truthy ``image_transform`` was supplied,
        its return value is stored instead. The sample dict is updated in
        place and returned.
        """
        encoded_page = render_pdf_to_base64png(
            str(sample['pdf_path']),
            page_num=1,
            target_longest_image_dim=self.target_longest_image_dim,
        )
        rendered = Image.open(BytesIO(base64.b64decode(encoded_page)))

        transform = self.image_transform
        sample['image'] = transform(rendered) if transform else rendered
        return sample
169
+
170
+
171
+ class BaseMarkdownPDFDataset (Dataset ):
172
+ """Base dataset class that loads and verifies markdown-PDF pairs."""
173
+
174
def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
    """
    Initialize the dataset by finding all markdown files with corresponding PDFs.

    A markdown file is kept only when a PDF with the same stem sits next to it
    (symlinks are resolved), the PDF can be opened, and it has exactly one
    page. Markdown files without an existing PDF are skipped silently; PDFs
    that fail to load or have a different page count are collected and
    summarized on stdout.

    Args:
        root_dir: Path to the root folder containing processed markdown and PDF files
        pipeline_steps: Optional list of pipeline steps to apply to each sample
    """
    self.root_dir = Path(root_dir)
    self.pipeline_steps = pipeline_steps or []
    self.samples = []

    # Find all markdown files recursively. rglob yields entries in filesystem
    # order, so sort to keep sample indices reproducible across runs/machines.
    print(f"Scanning for markdown files in {self.root_dir}...")
    md_files = sorted(self.root_dir.rglob("*.md"))

    invalid_pdfs = []

    print(f"Validating {len(md_files)} markdown-PDF pairs...")
    for md_path in tqdm(md_files, desc="Validating PDFs"):
        # Look for PDF with same stem (filename without extension)
        pdf_path = md_path.with_suffix('.pdf')

        # Follow symlinks so the stored path points at the real file.
        if pdf_path.is_symlink():
            pdf_path = pdf_path.resolve()

        # Missing PDF or dangling symlink: nothing to pair with.
        if not pdf_path.exists():
            continue

        # Validate PDF - check it loads and has exactly one page.
        # Broad catch is deliberate: any parser failure marks the PDF invalid.
        try:
            num_pages = len(PdfReader(str(pdf_path)).pages)
        except Exception as e:
            invalid_pdfs.append((pdf_path, f"Failed to load: {e}"))
            continue

        if num_pages != 1:
            invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
            continue

        self.samples.append({
            'markdown_path': md_path,
            'pdf_path': pdf_path,
        })

    print(f"Found {len(self.samples)} valid markdown-PDF pairs")

    if invalid_pdfs:
        print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
        for pdf_path, reason in invalid_pdfs[:5]:  # Show first 5
            print(f"  - {pdf_path.name}: {reason}")
        if len(invalid_pdfs) > 5:
            print(f"  ... and {len(invalid_pdfs) - 5} more")
232
+
168
233
def __len__ (self ) -> int :
169
234
return len (self .samples )
170
235
@@ -173,40 +238,43 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
173
238
Get a single sample from the dataset.
174
239
175
240
Returns:
176
- dict containing:
177
- - 'image ': PIL Image of the rendered PDF page
241
+ dict containing at minimum :
242
+ - 'markdown_path ': Path to the markdown file
178
243
- 'pdf_path': Path to the PDF file
179
- - 'text': Text content without front matter
180
- - 'front_matter': Dict with parsed front matter
244
+
245
+ Additional fields will be added by pipeline steps.
181
246
"""
182
- sample = self .samples [idx ]
247
+ # Start with basic sample info
248
+ sample = self .samples [idx ].copy ()
183
249
184
- # Read and parse markdown file
185
- markdown_content = sample ['markdown_path' ].read_text (encoding = 'utf-8' )
186
- front_matter_str , text = self ._extract_front_matter_and_text (markdown_content )
187
- front_matter = self ._parse_front_matter_string (front_matter_str )
250
+ # Apply pipeline steps
251
+ for step in self .pipeline_steps :
252
+ sample = step .process (sample )
188
253
189
- # Render PDF to image
190
- base64_png = render_pdf_to_base64png (str (sample ['pdf_path' ]), page_num = 1 , target_longest_image_dim = self .target_longest_image_dim )
191
- png_bytes = base64 .b64decode (base64_png )
192
- image = Image .open (BytesIO (png_bytes ))
193
-
194
- # Apply transform if provided
195
- if self .image_transform :
196
- image = self .image_transform (image )
254
+ return sample
255
+
256
+
257
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
    """Convenience dataset pre-wired with front matter parsing and PDF rendering."""

    def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
        """
        Build the default two-step pipeline and pass it to the base dataset.

        Args:
            root_dir: Path to the root folder containing processed markdown and PDF files
            target_longest_image_dim: Target dimension for the longest side of the image
            image_transform: Optional transform to apply to the PDF images
            front_matter_class: Optional dataclass type to validate front matter against
        """
        # Front matter is parsed first, then the PDF page is rendered.
        parser_step = FrontMatterParser(front_matter_class)
        render_step = PDFRenderer(target_longest_image_dim, image_transform)
        super().__init__(root_dir, [parser_step, render_step])
210
278
211
279
212
280
if __name__ == "__main__" :
@@ -222,11 +290,46 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
222
290
223
291
args = parser .parse_args ()
224
292
225
- # Test dataset initialization
226
- print (f"\n Testing dataset with root directory: { args .root_dir } " )
227
- dataset = MarkdownPDFDocumentDataset (args .root_dir , target_longest_image_dim = 1024 , front_matter_class = StandardFrontMatter , image_transform = None )
293
+ # Test base dataset without any pipeline steps
294
+ print (f"\n === Testing base dataset without pipeline steps ===" )
295
+ base_dataset = BaseMarkdownPDFDataset (args .root_dir )
296
+ print (f"Dataset length: { len (base_dataset )} " )
297
+
298
+ if len (base_dataset ) > 0 :
299
+ print ("\n First sample (no pipeline):" )
300
+ sample = base_dataset [0 ]
301
+ print (f" Keys: { list (sample .keys ())} " )
302
+ print (f" Markdown: { sample ['markdown_path' ].name } " )
303
+ print (f" PDF: { sample ['pdf_path' ].name } " )
304
+
305
+ # Test with individual pipeline steps
306
+ print (f"\n === Testing with individual pipeline steps ===" )
307
+ pipeline_dataset = BaseMarkdownPDFDataset (
308
+ args .root_dir ,
309
+ pipeline_steps = [
310
+ FrontMatterParser (StandardFrontMatter ),
311
+ PDFRenderer (target_longest_image_dim = 1024 )
312
+ ]
313
+ )
314
+
315
+ if len (pipeline_dataset ) > 0 :
316
+ print ("\n First sample (with pipeline):" )
317
+ sample = pipeline_dataset [0 ]
318
+ print (f" Keys: { list (sample .keys ())} " )
319
+ print (f" Front Matter: { sample ['front_matter' ]} " )
320
+ print (f" Image size: { sample ['image' ].size } " )
321
+ print (f" Text preview: { sample ['text' ][:100 ]} ..." )
322
+
323
+ # Test the convenience dataset class
324
+ print (f"\n === Testing MarkdownPDFDocumentDataset (convenience class) ===" )
325
+ dataset = MarkdownPDFDocumentDataset (
326
+ args .root_dir ,
327
+ target_longest_image_dim = 1024 ,
328
+ front_matter_class = StandardFrontMatter ,
329
+ image_transform = None
330
+ )
228
331
229
- print (f"\n Dataset length: { len (dataset )} " )
332
+ print (f"Dataset length: { len (dataset )} " )
230
333
231
334
if len (dataset ) > 0 :
232
335
# Show first few samples
@@ -242,4 +345,4 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
242
345
print (f"Image size: { first_sample ['image' ].size } " )
243
346
print (f"PDF Path: { first_sample ['pdf_path' ]} " )
244
347
print (f"Front Matter: { first_sample ['front_matter' ]} " )
245
- print (f"Text: { first_sample ['text' ]} ..." )
348
+ print (f"Text: { first_sample ['text' ][: 200 ] } ..." )
0 commit comments