11
11
from abc import ABC , abstractmethod
12
12
13
13
from olmocr .data .renderpdf import render_pdf_to_base64png
14
+ from olmocr .prompts .prompts import PageResponse , build_finetuning_prompt
14
15
15
- @dataclass (frozen = True )
16
- class StandardFrontMatter :
17
- primary_language : Optional [str ]
18
- is_rotation_valid : bool
19
- rotation_correction : int
20
- is_table : bool
21
- is_diagram : bool
22
-
23
- def __post_init__ (self ):
24
- # Validate rotation_correction is one of the allowed values
25
- if self .rotation_correction not in {0 , 90 , 180 , 270 }:
26
- raise ValueError ("rotation_correction must be one of [0, 90, 180, 270]." )
27
-
28
- # Type checks
29
- if not isinstance (self .primary_language , (str , type (None ))):
30
- raise TypeError ("primary_language must be of type Optional[str]." )
31
- if not isinstance (self .is_rotation_valid , bool ):
32
- raise TypeError ("is_rotation_valid must be of type bool." )
33
- if not isinstance (self .rotation_correction , int ):
34
- raise TypeError ("rotation_correction must be of type int." )
35
- if not isinstance (self .is_table , bool ):
36
- raise TypeError ("is_table must be of type bool." )
37
- if not isinstance (self .is_diagram , bool ):
38
- raise TypeError ("is_diagram must be of type bool." )
16
+ # Import PageResponse from prompts.py instead of defining StandardFrontMatter here
39
17
40
18
41
19
class PipelineStep (ABC ):
@@ -84,7 +62,7 @@ def _parse_front_matter_string(self, front_matter_str: str) -> Dict[str, Any]:
84
62
85
63
return front_matter
86
64
87
- def _parse_front_matter (self , front_matter_dict : Dict [str , Any ]) -> Any :
65
+ def _parse_front_matter (self , front_matter_dict : Dict [str , Any ], text : str ) -> Any :
88
66
"""Parse front matter dictionary into dataclass instance if front_matter_class is specified."""
89
67
if not self .front_matter_class :
90
68
return front_matter_dict
@@ -95,6 +73,11 @@ def _parse_front_matter(self, front_matter_dict: Dict[str, Any]) -> Any:
95
73
# Validate and convert values
96
74
kwargs = {}
97
75
for field_name , field_type in field_info .items ():
76
+ # Special handling for natural_text field in PageResponse
77
+ if field_name == 'natural_text' and self .front_matter_class == PageResponse :
78
+ kwargs [field_name ] = text if text else None
79
+ continue
80
+
98
81
if field_name not in front_matter_dict :
99
82
raise ValueError (f"Missing required field '{ field_name } ' in front matter" )
100
83
@@ -110,8 +93,11 @@ def _parse_front_matter(self, front_matter_dict: Dict[str, Any]) -> Any:
110
93
else :
111
94
kwargs [field_name ] = value
112
95
113
- # Check for extra fields
114
- extra_fields = set (front_matter_dict .keys ()) - set (field_info .keys ())
96
+ # Check for extra fields (excluding natural_text if it's PageResponse)
97
+ expected_fields = set (field_info .keys ())
98
+ if self .front_matter_class == PageResponse :
99
+ expected_fields .discard ('natural_text' )
100
+ extra_fields = set (front_matter_dict .keys ()) - expected_fields
115
101
if extra_fields :
116
102
raise ValueError (f"Unexpected fields in front matter: { extra_fields } " )
117
103
@@ -129,7 +115,7 @@ def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
129
115
130
116
# Parse front matter to dataclass if specified
131
117
try :
132
- parsed_front_matter = self ._parse_front_matter (front_matter )
118
+ parsed_front_matter = self ._parse_front_matter (front_matter , text )
133
119
except Exception as e :
134
120
raise ValueError (f"Error parsing front matter for { sample ['markdown_path' ]} : { e } " )
135
121
@@ -307,7 +293,7 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
307
293
pipeline_dataset = BaseMarkdownPDFDataset (
308
294
args .root_dir ,
309
295
pipeline_steps = [
310
- FrontMatterParser (StandardFrontMatter ),
296
+ FrontMatterParser (PageResponse ),
311
297
PDFRenderer (target_longest_image_dim = 1024 )
312
298
]
313
299
)
@@ -325,7 +311,7 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
325
311
dataset = MarkdownPDFDocumentDataset (
326
312
args .root_dir ,
327
313
target_longest_image_dim = 1024 ,
328
- front_matter_class = StandardFrontMatter ,
314
+ front_matter_class = PageResponse ,
329
315
image_transform = None
330
316
)
331
317
@@ -345,4 +331,6 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
345
331
print (f"Image size: { first_sample ['image' ].size } " )
346
332
print (f"PDF Path: { first_sample ['pdf_path' ]} " )
347
333
print (f"Front Matter: { first_sample ['front_matter' ]} " )
348
- print (f"Text: { first_sample ['text' ][:200 ]} ..." )
334
+ print (f"Text (first 200 chars): { first_sample ['text' ][:200 ]} ..." )
335
+ if hasattr (first_sample ['front_matter' ], 'natural_text' ):
336
+ print (f"Natural Text from PageResponse: { first_sample ['front_matter' ].natural_text [:200 ] if first_sample ['front_matter' ].natural_text else 'None' } ..." )
0 commit comments