Skip to content

Commit 9f50bda

Browse files
committed
More refactoring
1 parent 6a360fa commit 9f50bda

File tree

1 file changed

+134
-128
lines changed

1 file changed

+134
-128
lines changed

olmocr/train/dataloader.py

Lines changed: 134 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from os import PathLike
22
from pathlib import Path
3-
from typing import Dict, Any, Optional, Type, List, Callable
3+
from typing import Dict, Any, Optional, Type, List, Callable, TypeAlias
44
import base64
55
from io import BytesIO
6+
from functools import reduce
7+
import logging
8+
import yaml
69
from PIL import Image
710
from torch.utils.data import Dataset
811
from pypdf import PdfReader
@@ -13,54 +16,128 @@
1316
from olmocr.data.renderpdf import render_pdf_to_base64png
1417
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
1518

16-
# Import PageResponse from prompts.py instead of defining StandardFrontMatter here
19+
# Type alias for samples
20+
Sample: TypeAlias = Dict[str, Any]
1721

22+
# Configure logging
23+
logger = logging.getLogger(__name__)
1824

25+
26+
@dataclass(frozen=True, slots=True)
1927
class PipelineStep(ABC):
2028
"""Abstract base class for pipeline steps."""
2129

2230
@abstractmethod
23-
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
31+
def __call__(self, sample: Sample) -> Sample:
2432
"""Process a sample and return the modified sample."""
25-
pass
33+
...
2634

2735

28-
class FrontMatterParser(PipelineStep):
29-
"""Pipeline step that parses front matter from markdown content."""
30-
31-
def __init__(self, front_matter_class: Optional[Type] = None):
32-
self.front_matter_class = front_matter_class
36+
class BaseMarkdownPDFDataset(Dataset):
37+
"""Base dataset class that loads and verifies markdown-PDF pairs."""
3338

34-
def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[str, str]:
35-
"""Extract raw front matter string and text from markdown content."""
36-
if markdown_content.startswith('---\n'):
37-
parts = markdown_content.split('---\n', 2)
38-
if len(parts) >= 3:
39-
return parts[1].strip(), parts[2].strip()
39+
def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
40+
"""
41+
Initialize the dataset by finding all markdown files with corresponding PDFs.
4042
41-
return '', markdown_content
42-
43-
def _parse_front_matter_string(self, front_matter_str: str) -> Dict[str, Any]:
44-
"""Parse front matter string into a dictionary."""
45-
front_matter = {}
43+
Args:
44+
root_dir: Path to the root folder containing processed markdown and PDF files
45+
pipeline_steps: Optional list of pipeline steps to apply to each sample
46+
"""
47+
self.root_dir = Path(root_dir)
48+
self.pipeline_steps = pipeline_steps or []
49+
self.samples = []
50+
51+
# Find all markdown files recursively
52+
logger.info(f"Scanning for markdown files in {self.root_dir}...")
53+
md_files = list(self.root_dir.rglob("*.md"))
4654

47-
if not front_matter_str:
48-
return front_matter
55+
# Verify each markdown file has a corresponding PDF
56+
valid_count = 0
57+
invalid_pdfs = []
58+
59+
logger.info(f"Validating {len(md_files)} markdown-PDF pairs...")
60+
for md_path in tqdm(md_files, desc="Validating PDFs"):
61+
# Look for PDF with same stem (filename without extension)
62+
pdf_path = md_path.with_suffix('.pdf')
4963

50-
for line in front_matter_str.split('\n'):
51-
if ': ' in line:
52-
key, value = line.split(': ', 1)
53-
# Simple type inference
54-
if value.lower() == 'true':
55-
front_matter[key] = True
56-
elif value.lower() == 'false':
57-
front_matter[key] = False
58-
elif value.isdigit():
59-
front_matter[key] = int(value)
60-
else:
61-
front_matter[key] = value
64+
if pdf_path.exists() or pdf_path.is_symlink():
65+
# Resolve symlink if it is one
66+
if pdf_path.is_symlink():
67+
pdf_path = pdf_path.resolve()
68+
69+
# Verify the resolved path exists
70+
if pdf_path.exists():
71+
# Validate PDF - check it loads and has exactly one page
72+
try:
73+
reader = PdfReader(str(pdf_path))
74+
num_pages = len(reader.pages)
75+
76+
if num_pages != 1:
77+
invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
78+
continue
79+
80+
self.samples.append({
81+
'markdown_path': md_path,
82+
'pdf_path': pdf_path
83+
})
84+
valid_count += 1
85+
86+
except Exception as e:
87+
invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
88+
89+
logger.info(f"Found {valid_count} valid markdown-PDF pairs")
6290

63-
return front_matter
91+
if invalid_pdfs:
92+
logger.warning(f"{len(invalid_pdfs)} invalid PDFs found:")
93+
for pdf_path, reason in invalid_pdfs[:5]: # Show first 5
94+
logger.warning(f" - {pdf_path.name}: {reason}")
95+
if len(invalid_pdfs) > 5:
96+
logger.warning(f" ... and {len(invalid_pdfs) - 5} more")
97+
98+
def __len__(self) -> int:
99+
return len(self.samples)
100+
101+
def __getitem__(self, idx: int) -> Dict[str, Any]:
102+
"""
103+
Get a single sample from the dataset.
104+
105+
Returns:
106+
dict containing at minimum:
107+
- 'markdown_path': Path to the markdown file
108+
- 'pdf_path': Path to the PDF file
109+
110+
Additional fields will be added by pipeline steps.
111+
"""
112+
# Start with basic sample info
113+
sample = self.samples[idx].copy()
114+
115+
# Apply pipeline steps using reduce
116+
return reduce(lambda s, f: f(s), self.pipeline_steps, sample)
117+
118+
119+
@dataclass(frozen=True, slots=True)
120+
class FrontMatterParser(PipelineStep):
121+
"""Pipeline step that parses YAML front matter from markdown content."""
122+
front_matter_class: Optional[Type] = None
123+
124+
def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[Dict[str, Any], str]:
125+
"""Extract YAML front matter and text from markdown content."""
126+
if markdown_content.startswith('---\n'):
127+
try:
128+
# Find the closing --- delimiter
129+
end_index = markdown_content.find('\n---\n', 4)
130+
if end_index != -1:
131+
front_matter_str = markdown_content[4:end_index]
132+
text = markdown_content[end_index + 5:].strip()
133+
134+
# Parse YAML
135+
front_matter = yaml.safe_load(front_matter_str) or {}
136+
return front_matter, text
137+
except yaml.YAMLError as e:
138+
logger.warning(f"Failed to parse YAML front matter: {e}")
139+
140+
return {}, markdown_content.strip()
64141

65142
def _parse_front_matter(self, front_matter_dict: Dict[str, Any], text: str) -> Any:
66143
"""Parse front matter dictionary into dataclass instance if front_matter_class is specified."""
@@ -103,15 +180,14 @@ def _parse_front_matter(self, front_matter_dict: Dict[str, Any], text: str) -> A
103180

104181
return self.front_matter_class(**kwargs)
105182

106-
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
183+
def __call__(self, sample: Sample) -> Sample:
107184
"""Parse front matter from markdown content."""
108185
# Read markdown content if not already loaded
109186
if 'markdown_content' not in sample:
110187
sample['markdown_content'] = sample['markdown_path'].read_text(encoding='utf-8')
111188

112189
# Extract and parse front matter
113-
front_matter_str, text = self._extract_front_matter_and_text(sample['markdown_content'])
114-
front_matter = self._parse_front_matter_string(front_matter_str)
190+
front_matter, text = self._extract_front_matter_and_text(sample['markdown_content'])
115191

116192
# Parse front matter to dataclass if specified
117193
try:
@@ -126,14 +202,13 @@ def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
126202
return sample
127203

128204

205+
@dataclass(frozen=True, slots=True)
129206
class PDFRenderer(PipelineStep):
130207
"""Pipeline step that renders PDF to image."""
208+
target_longest_image_dim: int
209+
image_transform: Optional[Callable] = None
131210

132-
def __init__(self, target_longest_image_dim: int, image_transform: Optional[Callable] = None):
133-
self.target_longest_image_dim = target_longest_image_dim
134-
self.image_transform = image_transform
135-
136-
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
211+
def __call__(self, sample: Sample) -> Sample:
137212
"""Render PDF to image."""
138213
# Render PDF to image
139214
base64_png = render_pdf_to_base64png(
@@ -152,91 +227,17 @@ def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
152227
sample['image'] = image
153228

154229
return sample
155-
156-
157-
class BaseMarkdownPDFDataset(Dataset):
158-
"""Base dataset class that loads and verifies markdown-PDF pairs."""
159-
160-
def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
161-
"""
162-
Initialize the dataset by finding all markdown files with corresponding PDFs.
163-
164-
Args:
165-
root_dir: Path to the root folder containing processed markdown and PDF files
166-
pipeline_steps: Optional list of pipeline steps to apply to each sample
167-
"""
168-
self.root_dir = Path(root_dir)
169-
self.pipeline_steps = pipeline_steps or []
170-
self.samples = []
171-
172-
# Find all markdown files recursively
173-
print(f"Scanning for markdown files in {self.root_dir}...")
174-
md_files = list(self.root_dir.rglob("*.md"))
175-
176-
# Verify each markdown file has a corresponding PDF
177-
valid_count = 0
178-
invalid_pdfs = []
179-
180-
print(f"Validating {len(md_files)} markdown-PDF pairs...")
181-
for md_path in tqdm(md_files, desc="Validating PDFs"):
182-
# Look for PDF with same stem (filename without extension)
183-
pdf_path = md_path.with_suffix('.pdf')
184-
185-
if pdf_path.exists() or pdf_path.is_symlink():
186-
# Resolve symlink if it is one
187-
if pdf_path.is_symlink():
188-
pdf_path = pdf_path.resolve()
189-
190-
# Verify the resolved path exists
191-
if pdf_path.exists():
192-
# Validate PDF - check it loads and has exactly one page
193-
try:
194-
reader = PdfReader(str(pdf_path))
195-
num_pages = len(reader.pages)
196-
197-
if num_pages != 1:
198-
invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
199-
continue
200-
201-
self.samples.append({
202-
'markdown_path': md_path,
203-
'pdf_path': pdf_path
204-
})
205-
valid_count += 1
206-
207-
except Exception as e:
208-
invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
209-
210-
print(f"Found {valid_count} valid markdown-PDF pairs")
211-
212-
if invalid_pdfs:
213-
print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
214-
for pdf_path, reason in invalid_pdfs[:5]: # Show first 5
215-
print(f" - {pdf_path.name}: {reason}")
216-
if len(invalid_pdfs) > 5:
217-
print(f" ... and {len(invalid_pdfs) - 5} more")
218230

219-
def __len__(self) -> int:
220-
return len(self.samples)
231+
232+
@dataclass(frozen=True, slots=True)
233+
class PromptBuilder(PipelineStep):
234+
"""Pipeline step that builds prompts using the finetuning prompt template."""
235+
base_text_field: str = 'text'
221236

222-
def __getitem__(self, idx: int) -> Dict[str, Any]:
223-
"""
224-
Get a single sample from the dataset.
225-
226-
Returns:
227-
dict containing at minimum:
228-
- 'markdown_path': Path to the markdown file
229-
- 'pdf_path': Path to the PDF file
230-
231-
Additional fields will be added by pipeline steps.
232-
"""
233-
# Start with basic sample info
234-
sample = self.samples[idx].copy()
235-
236-
# Apply pipeline steps
237-
for step in self.pipeline_steps:
238-
sample = step.process(sample)
239-
237+
def __call__(self, sample: Sample) -> Sample:
238+
"""Build prompt from base text."""
239+
base_text = sample.get(self.base_text_field, '')
240+
sample['prompt'] = build_finetuning_prompt(base_text)
240241
return sample
241242

242243

@@ -266,6 +267,9 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
266267
if __name__ == "__main__":
267268
import argparse
268269

270+
# Set up logging for testing
271+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
272+
269273
parser = argparse.ArgumentParser(description="Test MarkdownPDFDocumentDataset")
270274
parser.add_argument(
271275
"--root-dir",
@@ -293,8 +297,9 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
293297
pipeline_dataset = BaseMarkdownPDFDataset(
294298
args.root_dir,
295299
pipeline_steps=[
296-
FrontMatterParser(PageResponse),
297-
PDFRenderer(target_longest_image_dim=1024)
300+
FrontMatterParser(front_matter_class=PageResponse),
301+
PDFRenderer(target_longest_image_dim=1024),
302+
PromptBuilder()
298303
]
299304
)
300305

@@ -305,6 +310,7 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
305310
print(f" Front Matter: {sample['front_matter']}")
306311
print(f" Image size: {sample['image'].size}")
307312
print(f" Text preview: {sample['text'][:100]}...")
313+
print(f" Prompt preview: {sample.get('prompt', 'No prompt')[:200]}...")
308314

309315
# Test the convenience dataset class
310316
print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")

0 commit comments

Comments
 (0)