Skip to content

Commit 6a360fa

Browse files
committed
Cleanup
1 parent d17bef8 commit 6a360fa

File tree

1 file changed

+19
-31
lines changed

1 file changed

+19
-31
lines changed

olmocr/train/dataloader.py

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,9 @@
1111
from abc import ABC, abstractmethod
1212

1313
from olmocr.data.renderpdf import render_pdf_to_base64png
14+
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
1415

15-
@dataclass(frozen=True)
class StandardFrontMatter:
    """Immutable container for the validated front-matter fields.

    Every field is validated in ``__post_init__``; construction fails rather
    than producing a partially valid instance.

    Raises:
        TypeError: If any field has the wrong type. ``rotation_correction``
            must be a real ``int``; ``bool`` values are rejected.
        ValueError: If ``rotation_correction`` is not one of 0, 90, 180, 270.
    """

    primary_language: Optional[str]
    is_rotation_valid: bool
    rotation_correction: int
    is_table: bool
    is_diagram: bool

    def __post_init__(self):
        # Type checks come first so that a wrongly typed value (e.g. "90")
        # is reported as a TypeError instead of tripping the value check below.
        if not isinstance(self.primary_language, (str, type(None))):
            raise TypeError("primary_language must be of type Optional[str].")
        if not isinstance(self.is_rotation_valid, bool):
            raise TypeError("is_rotation_valid must be of type bool.")
        # bool is a subclass of int and False == 0, so without the explicit
        # bool rejection `rotation_correction=False` would pass validation.
        if isinstance(self.rotation_correction, bool) or not isinstance(self.rotation_correction, int):
            raise TypeError("rotation_correction must be of type int.")
        if not isinstance(self.is_table, bool):
            raise TypeError("is_table must be of type bool.")
        if not isinstance(self.is_diagram, bool):
            raise TypeError("is_diagram must be of type bool.")

        # Validate rotation_correction is one of the allowed values
        if self.rotation_correction not in {0, 90, 180, 270}:
            raise ValueError("rotation_correction must be one of [0, 90, 180, 270].")
16+
# PageResponse (imported above from olmocr.prompts.prompts) replaces the StandardFrontMatter dataclass that used to be defined here.
3917

4018

4119
class PipelineStep(ABC):
@@ -84,7 +62,7 @@ def _parse_front_matter_string(self, front_matter_str: str) -> Dict[str, Any]:
8462

8563
return front_matter
8664

87-
def _parse_front_matter(self, front_matter_dict: Dict[str, Any]) -> Any:
65+
def _parse_front_matter(self, front_matter_dict: Dict[str, Any], text: str) -> Any:
8866
"""Parse front matter dictionary into dataclass instance if front_matter_class is specified."""
8967
if not self.front_matter_class:
9068
return front_matter_dict
@@ -95,6 +73,11 @@ def _parse_front_matter(self, front_matter_dict: Dict[str, Any]) -> Any:
9573
# Validate and convert values
9674
kwargs = {}
9775
for field_name, field_type in field_info.items():
76+
# Special handling for natural_text field in PageResponse
77+
if field_name == 'natural_text' and self.front_matter_class == PageResponse:
78+
kwargs[field_name] = text if text else None
79+
continue
80+
9881
if field_name not in front_matter_dict:
9982
raise ValueError(f"Missing required field '{field_name}' in front matter")
10083

@@ -110,8 +93,11 @@ def _parse_front_matter(self, front_matter_dict: Dict[str, Any]) -> Any:
11093
else:
11194
kwargs[field_name] = value
11295

113-
# Check for extra fields
114-
extra_fields = set(front_matter_dict.keys()) - set(field_info.keys())
96+
# Check for extra fields (excluding natural_text if it's PageResponse)
97+
expected_fields = set(field_info.keys())
98+
if self.front_matter_class == PageResponse:
99+
expected_fields.discard('natural_text')
100+
extra_fields = set(front_matter_dict.keys()) - expected_fields
115101
if extra_fields:
116102
raise ValueError(f"Unexpected fields in front matter: {extra_fields}")
117103

@@ -129,7 +115,7 @@ def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
129115

130116
# Parse front matter to dataclass if specified
131117
try:
132-
parsed_front_matter = self._parse_front_matter(front_matter)
118+
parsed_front_matter = self._parse_front_matter(front_matter, text)
133119
except Exception as e:
134120
raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
135121

@@ -307,7 +293,7 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
307293
pipeline_dataset = BaseMarkdownPDFDataset(
308294
args.root_dir,
309295
pipeline_steps=[
310-
FrontMatterParser(StandardFrontMatter),
296+
FrontMatterParser(PageResponse),
311297
PDFRenderer(target_longest_image_dim=1024)
312298
]
313299
)
@@ -325,7 +311,7 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
325311
dataset = MarkdownPDFDocumentDataset(
326312
args.root_dir,
327313
target_longest_image_dim=1024,
328-
front_matter_class=StandardFrontMatter,
314+
front_matter_class=PageResponse,
329315
image_transform=None
330316
)
331317

@@ -345,4 +331,6 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
345331
print(f"Image size: {first_sample['image'].size}")
346332
print(f"PDF Path: {first_sample['pdf_path']}")
347333
print(f"Front Matter: {first_sample['front_matter']}")
348-
print(f"Text: {first_sample['text'][:200]}...")
334+
print(f"Text (first 200 chars): {first_sample['text'][:200]}...")
335+
if hasattr(first_sample['front_matter'], 'natural_text'):
336+
print(f"Natural Text from PageResponse: {first_sample['front_matter'].natural_text[:200] if first_sample['front_matter'].natural_text else 'None'}...")

0 commit comments

Comments
 (0)