from os import PathLike
from pathlib import Path
- from typing import Dict, Any, Optional, Type, List, Callable
+ from typing import Dict, Any, Optional, Type, List, Callable, TypeAlias
import base64
from io import BytesIO
+ from functools import reduce
+ import logging
+ import yaml
from PIL import Image
from torch.utils.data import Dataset
from pypdf import PdfReader
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt

- # Import PageResponse from prompts.py instead of defining StandardFrontMatter here
+ # Type alias for samples
+ Sample: TypeAlias = Dict[str, Any]

+ # Configure logging
+ logger = logging.getLogger(__name__)

+
+ @dataclass(frozen=True, slots=True)
class PipelineStep(ABC):
    """Abstract base class for pipeline steps."""

    @abstractmethod
-     def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+     def __call__(self, sample: Sample) -> Sample:
        """Process a sample and return the modified sample."""
-         pass
+         ...

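Since steps are now frozen dataclasses implementing `__call__`, any `Sample -> Sample` callable composes with the pipeline. A minimal sketch of a hypothetical extra step written against this protocol (illustrative only, not part of this commit):

    @dataclass(frozen=True, slots=True)
    class MarkdownLoader(PipelineStep):
        """Hypothetical step: eagerly loads markdown content onto the sample."""
        encoding: str = 'utf-8'

        def __call__(self, sample: Sample) -> Sample:
            sample['markdown_content'] = sample['markdown_path'].read_text(encoding=self.encoding)
            return sample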
- class FrontMatterParser(PipelineStep):
-     """Pipeline step that parses front matter from markdown content."""
-
-     def __init__(self, front_matter_class: Optional[Type] = None):
-         self.front_matter_class = front_matter_class
+ class BaseMarkdownPDFDataset(Dataset):
+     """Base dataset class that loads and verifies markdown-PDF pairs."""

-     def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[str, str]:
-         """Extract raw front matter string and text from markdown content."""
-         if markdown_content.startswith('---\n'):
-             parts = markdown_content.split('---\n', 2)
-             if len(parts) >= 3:
-                 return parts[1].strip(), parts[2].strip()
+     def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
+         """
+         Initialize the dataset by finding all markdown files with corresponding PDFs.

-         return '', markdown_content
-
-     def _parse_front_matter_string(self, front_matter_str: str) -> Dict[str, Any]:
-         """Parse front matter string into a dictionary."""
-         front_matter = {}
+         Args:
+             root_dir: Path to the root folder containing processed markdown and PDF files
+             pipeline_steps: Optional list of pipeline steps to apply to each sample
+         """
+         self.root_dir = Path(root_dir)
+         self.pipeline_steps = pipeline_steps or []
+         self.samples = []
+
+         # Find all markdown files recursively
+         logger.info(f"Scanning for markdown files in {self.root_dir}...")
+         md_files = list(self.root_dir.rglob("*.md"))

-         if not front_matter_str:
-             return front_matter
+         # Verify each markdown file has a corresponding PDF
+         valid_count = 0
+         invalid_pdfs = []
+
+         logger.info(f"Validating {len(md_files)} markdown-PDF pairs...")
+         for md_path in tqdm(md_files, desc="Validating PDFs"):
+             # Look for PDF with same stem (filename without extension)
+             pdf_path = md_path.with_suffix('.pdf')

-         for line in front_matter_str.split('\n'):
-             if ': ' in line:
-                 key, value = line.split(': ', 1)
-                 # Simple type inference
-                 if value.lower() == 'true':
-                     front_matter[key] = True
-                 elif value.lower() == 'false':
-                     front_matter[key] = False
-                 elif value.isdigit():
-                     front_matter[key] = int(value)
-                 else:
-                     front_matter[key] = value
+             if pdf_path.exists() or pdf_path.is_symlink():
+                 # Resolve symlink if it is one
+                 if pdf_path.is_symlink():
+                     pdf_path = pdf_path.resolve()
+
+                 # Verify the resolved path exists
+                 if pdf_path.exists():
+                     # Validate PDF - check it loads and has exactly one page
+                     try:
+                         reader = PdfReader(str(pdf_path))
+                         num_pages = len(reader.pages)
+
+                         if num_pages != 1:
+                             invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
+                             continue
+
+                         self.samples.append({
+                             'markdown_path': md_path,
+                             'pdf_path': pdf_path
+                         })
+                         valid_count += 1
+
+                     except Exception as e:
+                         invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
+
+         logger.info(f"Found {valid_count} valid markdown-PDF pairs")

-         return front_matter
+         if invalid_pdfs:
+             logger.warning(f"{len(invalid_pdfs)} invalid PDFs found:")
+             for pdf_path, reason in invalid_pdfs[:5]:  # Show first 5
+                 logger.warning(f"  - {pdf_path.name}: {reason}")
+             if len(invalid_pdfs) > 5:
+                 logger.warning(f"  ... and {len(invalid_pdfs) - 5} more")
+
+     def __len__(self) -> int:
+         return len(self.samples)
+
+     def __getitem__(self, idx: int) -> Dict[str, Any]:
+         """
+         Get a single sample from the dataset.
+
+         Returns:
+             dict containing at minimum:
+                 - 'markdown_path': Path to the markdown file
+                 - 'pdf_path': Path to the PDF file
+
+             Additional fields will be added by pipeline steps.
+         """
+         # Start with basic sample info
+         sample = self.samples[idx].copy()
+
+         # Apply pipeline steps using reduce
+         return reduce(lambda s, f: f(s), self.pipeline_steps, sample)
+
+
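For reference, `reduce(lambda s, f: f(s), self.pipeline_steps, sample)` folds the sample through each step from left to right, which is exactly the `for step in ...: sample = step(sample)` loop it replaces. A self-contained illustration:

    from functools import reduce

    steps = [
        lambda s: {**s, 'text': 'hello'},            # stand-ins for pipeline steps
        lambda s: {**s, 'length': len(s['text'])},
    ]
    print(reduce(lambda s, f: f(s), steps, {}))      # {'text': 'hello', 'length': 5}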
+ @dataclass(frozen=True, slots=True)
+ class FrontMatterParser(PipelineStep):
+     """Pipeline step that parses YAML front matter from markdown content."""
+     front_matter_class: Optional[Type] = None
+
+     def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[Dict[str, Any], str]:
+         """Extract YAML front matter and text from markdown content."""
+         if markdown_content.startswith('---\n'):
+             try:
+                 # Find the closing --- delimiter
+                 end_index = markdown_content.find('\n---\n', 4)
+                 if end_index != -1:
+                     front_matter_str = markdown_content[4:end_index]
+                     text = markdown_content[end_index + 5:].strip()
+
+                     # Parse YAML
+                     front_matter = yaml.safe_load(front_matter_str) or {}
+                     return front_matter, text
+             except yaml.YAMLError as e:
+                 logger.warning(f"Failed to parse YAML front matter: {e}")
+
+         return {}, markdown_content.strip()
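A quick illustration of the extraction logic above on a hypothetical input; note that `yaml.safe_load` handles the type coercion (booleans, numbers) that the removed hand-rolled parser approximated with string matching:

    import yaml

    content = "---\nprimary_language: en\nis_table: false\n---\nBody text."
    end = content.find('\n---\n', 4)           # locate the closing delimiter
    print(yaml.safe_load(content[4:end]))      # {'primary_language': 'en', 'is_table': False}
    print(content[end + 5:].strip())           # Body text.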

    def _parse_front_matter(self, front_matter_dict: Dict[str, Any], text: str) -> Any:
        """Parse front matter dictionary into dataclass instance if front_matter_class is specified."""
@@ -103,15 +180,14 @@ def _parse_front_matter(self, front_matter_dict: Dict[str, Any], text: str) -> A

        return self.front_matter_class(**kwargs)

-     def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+     def __call__(self, sample: Sample) -> Sample:
        """Parse front matter from markdown content."""
        # Read markdown content if not already loaded
        if 'markdown_content' not in sample:
            sample['markdown_content'] = sample['markdown_path'].read_text(encoding='utf-8')

        # Extract and parse front matter
-         front_matter_str, text = self._extract_front_matter_and_text(sample['markdown_content'])
-         front_matter = self._parse_front_matter_string(front_matter_str)
+         front_matter, text = self._extract_front_matter_and_text(sample['markdown_content'])

        # Parse front matter to dataclass if specified
        try:
@@ -126,14 +202,13 @@ def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        return sample


+ @dataclass(frozen=True, slots=True)
class PDFRenderer(PipelineStep):
    """Pipeline step that renders PDF to image."""
+     target_longest_image_dim: int
+     image_transform: Optional[Callable] = None

-     def __init__(self, target_longest_image_dim: int, image_transform: Optional[Callable] = None):
-         self.target_longest_image_dim = target_longest_image_dim
-         self.image_transform = image_transform
-
-     def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+     def __call__(self, sample: Sample) -> Sample:
        """Render PDF to image."""
        # Render PDF to image
        base64_png = render_pdf_to_base64png(
@@ -152,91 +227,17 @@ def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        sample['image'] = image

        return sample
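Replacing the handwritten `__init__` with field declarations relies on `@dataclass(frozen=True, slots=True)` generating the constructor; `frozen=True` makes each configured step immutable and `slots=True` (Python 3.10+) drops the per-instance `__dict__`. A standalone sketch with a hypothetical class:

    from dataclasses import FrozenInstanceError, dataclass

    @dataclass(frozen=True, slots=True)
    class RendererConfig:
        target_longest_image_dim: int
        image_transform: object = None

    cfg = RendererConfig(target_longest_image_dim=1024)
    try:
        cfg.target_longest_image_dim = 512   # frozen: assignment raises
    except FrozenInstanceError:
        print("step configuration is immutable")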
-
-
- class BaseMarkdownPDFDataset(Dataset):
-     """Base dataset class that loads and verifies markdown-PDF pairs."""
-
-     def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
-         """
-         Initialize the dataset by finding all markdown files with corresponding PDFs.
-
-         Args:
-             root_dir: Path to the root folder containing processed markdown and PDF files
-             pipeline_steps: Optional list of pipeline steps to apply to each sample
-         """
-         self.root_dir = Path(root_dir)
-         self.pipeline_steps = pipeline_steps or []
-         self.samples = []
-
-         # Find all markdown files recursively
-         print(f"Scanning for markdown files in {self.root_dir}...")
-         md_files = list(self.root_dir.rglob("*.md"))
-
-         # Verify each markdown file has a corresponding PDF
-         valid_count = 0
-         invalid_pdfs = []
-
-         print(f"Validating {len(md_files)} markdown-PDF pairs...")
-         for md_path in tqdm(md_files, desc="Validating PDFs"):
-             # Look for PDF with same stem (filename without extension)
-             pdf_path = md_path.with_suffix('.pdf')
-
-             if pdf_path.exists() or pdf_path.is_symlink():
-                 # Resolve symlink if it is one
-                 if pdf_path.is_symlink():
-                     pdf_path = pdf_path.resolve()
-
-                 # Verify the resolved path exists
-                 if pdf_path.exists():
-                     # Validate PDF - check it loads and has exactly one page
-                     try:
-                         reader = PdfReader(str(pdf_path))
-                         num_pages = len(reader.pages)
-
-                         if num_pages != 1:
-                             invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
-                             continue
-
-                         self.samples.append({
-                             'markdown_path': md_path,
-                             'pdf_path': pdf_path
-                         })
-                         valid_count += 1
-
-                     except Exception as e:
-                         invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
-
-         print(f"Found {valid_count} valid markdown-PDF pairs")
-
-         if invalid_pdfs:
-             print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
-             for pdf_path, reason in invalid_pdfs[:5]:  # Show first 5
-                 print(f"  - {pdf_path.name}: {reason}")
-             if len(invalid_pdfs) > 5:
-                 print(f"  ... and {len(invalid_pdfs) - 5} more")

-     def __len__(self) -> int:
-         return len(self.samples)
+
+ @dataclass(frozen=True, slots=True)
+ class PromptBuilder(PipelineStep):
+     """Pipeline step that builds prompts using the finetuning prompt template."""
+     base_text_field: str = 'text'

-     def __getitem__(self, idx: int) -> Dict[str, Any]:
-         """
-         Get a single sample from the dataset.
-
-         Returns:
-             dict containing at minimum:
-                 - 'markdown_path': Path to the markdown file
-                 - 'pdf_path': Path to the PDF file
-
-             Additional fields will be added by pipeline steps.
-         """
-         # Start with basic sample info
-         sample = self.samples[idx].copy()
-
-         # Apply pipeline steps
-         for step in self.pipeline_steps:
-             sample = step.process(sample)
-
+     def __call__(self, sample: Sample) -> Sample:
+         """Build prompt from base text."""
+         base_text = sample.get(self.base_text_field, '')
+         sample['prompt'] = build_finetuning_prompt(base_text)
        return sample


@@ -266,6 +267,9 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
if __name__ == "__main__":
    import argparse

+     # Set up logging for testing
+     logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+
    parser = argparse.ArgumentParser(description="Test MarkdownPDFDocumentDataset")
    parser.add_argument(
        "--root-dir",
@@ -293,8 +297,9 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
    pipeline_dataset = BaseMarkdownPDFDataset(
        args.root_dir,
        pipeline_steps=[
-             FrontMatterParser(PageResponse),
-             PDFRenderer(target_longest_image_dim=1024)
+             FrontMatterParser(front_matter_class=PageResponse),
+             PDFRenderer(target_longest_image_dim=1024),
+             PromptBuilder()
        ]
    )

@@ -305,6 +310,7 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
        print(f"  Front Matter: {sample['front_matter']}")
        print(f"  Image size: {sample['image'].size}")
        print(f"  Text preview: {sample['text'][:100]}...")
+         print(f"  Prompt preview: {sample.get('prompt', 'No prompt')[:200]}...")

    # Test the convenience dataset class
    print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")