Skip to content

Commit d17bef8

Browse files
committed
Working on a more pipeliney thing
1 parent d0df380 commit d17bef8

File tree

1 file changed

+200
-97
lines changed

1 file changed

+200
-97
lines changed

olmocr/train/dataloader.py

Lines changed: 200 additions & 97 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,14 @@
11
from os import PathLike
22
from pathlib import Path
3-
from typing import Dict, Any, Optional, Type
3+
from typing import Dict, Any, Optional, Type, List, Callable
44
import base64
55
from io import BytesIO
66
from PIL import Image
77
from torch.utils.data import Dataset
88
from pypdf import PdfReader
99
from tqdm import tqdm
1010
from dataclasses import dataclass, fields
11+
from abc import ABC, abstractmethod
1112

1213
from olmocr.data.renderpdf import render_pdf_to_base64png
1314

@@ -35,71 +36,22 @@ def __post_init__(self):
3536
raise TypeError("is_table must be of type bool.")
3637
if not isinstance(self.is_diagram, bool):
3738
raise TypeError("is_diagram must be of type bool.")
38-
3939

40-
class MarkdownPDFDocumentDataset(Dataset):
41-
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
42-
"""
43-
Initialize the dataset by finding all markdown files with corresponding PDFs.
44-
45-
Args:
46-
root_dir: Path to the root folder containing processed markdown and PDF files
47-
target_longest_image_dim: Target dimension for the longest side of the image
48-
image_transform: Optional transform to apply to the PDF images
49-
front_matter_class: Optional dataclass type to validate front matter against
50-
"""
51-
self.root_dir = Path(root_dir)
52-
self.image_transform = image_transform
53-
self.target_longest_image_dim = target_longest_image_dim
40+
41+
class PipelineStep(ABC):
    """Abstract base class for a single stage of a sample-processing pipeline.

    A dataset applies its steps in order, feeding each step the sample dict
    returned by the previous one. Subclasses must implement ``process``;
    the class itself cannot be instantiated.
    """

    @abstractmethod
    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Process a sample and return the modified sample.

        Args:
            sample: Mapping of field names to the values gathered so far.

        Returns:
            The sample dict to hand to the next pipeline step.
        """
48+
49+
50+
class FrontMatterParser(PipelineStep):
51+
"""Pipeline step that parses front matter from markdown content."""
52+
53+
def __init__(self, front_matter_class: Optional[Type] = None):
5454
self.front_matter_class = front_matter_class
55-
self.samples = []
56-
57-
# Find all markdown files recursively
58-
print(f"Scanning for markdown files in {self.root_dir}...")
59-
md_files = list(self.root_dir.rglob("*.md"))
60-
61-
# Verify each markdown file has a corresponding PDF
62-
valid_count = 0
63-
invalid_pdfs = []
64-
65-
print(f"Validating {len(md_files)} markdown-PDF pairs...")
66-
for md_path in tqdm(md_files, desc="Validating PDFs"):
67-
# Look for PDF with same stem (filename without extension)
68-
pdf_path = md_path.with_suffix('.pdf')
69-
70-
if pdf_path.exists() or pdf_path.is_symlink():
71-
# Resolve symlink if it is one
72-
if pdf_path.is_symlink():
73-
pdf_path = pdf_path.resolve()
74-
75-
# Verify the resolved path exists
76-
if pdf_path.exists():
77-
# Validate PDF - check it loads and has exactly one page
78-
try:
79-
reader = PdfReader(str(pdf_path))
80-
num_pages = len(reader.pages)
81-
82-
if num_pages != 1:
83-
invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
84-
continue
85-
86-
self.samples.append({
87-
'markdown_path': md_path,
88-
'pdf_path': pdf_path
89-
})
90-
valid_count += 1
91-
92-
except Exception as e:
93-
invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
94-
95-
print(f"Found {valid_count} valid markdown-PDF pairs")
96-
97-
if invalid_pdfs:
98-
print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
99-
for pdf_path, reason in invalid_pdfs[:5]: # Show first 5
100-
print(f" - {pdf_path.name}: {reason}")
101-
if len(invalid_pdfs) > 5:
102-
print(f" ... and {len(invalid_pdfs) - 5} more")
10355

10456
def _extract_front_matter_and_text(self, markdown_content: str) -> tuple[str, str]:
10557
"""Extract raw front matter string and text from markdown content."""
@@ -165,6 +117,119 @@ def _parse_front_matter(self, front_matter_dict: Dict[str, Any]) -> Any:
165117

166118
return self.front_matter_class(**kwargs)
167119

120+
def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
    """Parse front matter from markdown content.

    The markdown is taken from ``sample['markdown_content']`` when present;
    otherwise it is read (UTF-8) from ``sample['markdown_path']`` and cached
    on the sample under ``'markdown_content'``.

    Args:
        sample: Sample dict holding 'markdown_content' and/or
            'markdown_path' (an object exposing ``read_text``).

    Returns:
        The same dict, updated in place with 'text' (the markdown without
        its front matter) and 'front_matter' (the parsed front matter).

    Raises:
        ValueError: If the front matter cannot be converted; the original
            exception is chained as ``__cause__``.
    """
    # Read markdown content if not already loaded
    if 'markdown_content' not in sample:
        sample['markdown_content'] = sample['markdown_path'].read_text(encoding='utf-8')

    # Extract and parse front matter
    front_matter_str, text = self._extract_front_matter_and_text(sample['markdown_content'])
    front_matter = self._parse_front_matter_string(front_matter_str)

    # Parse front matter to dataclass if specified
    try:
        parsed_front_matter = self._parse_front_matter(front_matter)
    except Exception as e:
        # A sample built from preloaded content may carry no path; indexing
        # it here would replace the real error with a KeyError.
        source = sample.get('markdown_path', '<in-memory markdown>')
        raise ValueError(f"Error parsing front matter for {source}: {e}") from e

    # Update sample
    sample['text'] = text
    sample['front_matter'] = parsed_front_matter

    return sample
141+
142+
143+
class PDFRenderer(PipelineStep):
    """Pipeline step that renders PDF to image."""

    def __init__(self, target_longest_image_dim: int, image_transform: Optional[Callable] = None):
        """Store the rendering configuration.

        Args:
            target_longest_image_dim: Forwarded to ``render_pdf_to_base64png``
                as ``target_longest_image_dim``.
            image_transform: Optional callable applied to the decoded image;
                its return value replaces the image on the sample.
        """
        self.target_longest_image_dim = target_longest_image_dim
        self.image_transform = image_transform

    def process(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        """Render PDF to image.

        Args:
            sample: Sample dict containing 'pdf_path'.

        Returns:
            The same dict, updated in place with 'image' (the decoded PNG, or
            whatever ``image_transform`` returned for it).
        """
        # Render PDF to image
        base64_png = render_pdf_to_base64png(
            str(sample['pdf_path']),
            page_num=1,
            target_longest_image_dim=self.target_longest_image_dim,
        )
        png_bytes = base64.b64decode(base64_png)
        image = Image.open(BytesIO(png_bytes))

        # Compare against None explicitly: a callable object may define
        # __bool__/__len__ and evaluate falsy even though it was supplied.
        if self.image_transform is not None:
            image = self.image_transform(image)

        # Update sample
        sample['image'] = image

        return sample
169+
170+
171+
class BaseMarkdownPDFDataset(Dataset):
172+
"""Base dataset class that loads and verifies markdown-PDF pairs."""
173+
174+
def __init__(self, root_dir: str | PathLike, pipeline_steps: Optional[List[PipelineStep]] = None):
    """
    Initialize the dataset by finding all markdown files with corresponding PDFs.

    Every ``*.md`` file under ``root_dir`` is paired with the ``.pdf`` of the
    same stem. A pair is kept only if the PDF loads and has exactly one
    page; rejected PDFs are summarized on stdout (first five shown).

    Args:
        root_dir: Path to the root folder containing processed markdown and PDF files
        pipeline_steps: Optional list of pipeline steps to apply to each sample
    """
    self.root_dir = Path(root_dir)
    self.pipeline_steps = pipeline_steps or []
    self.samples = []

    # Find all markdown files recursively. rglob yields in filesystem
    # order, so sort to keep sample indices reproducible across runs.
    print(f"Scanning for markdown files in {self.root_dir}...")
    md_files = sorted(self.root_dir.rglob("*.md"))

    # (path, reason) for every PDF that was found but rejected
    invalid_pdfs = []

    print(f"Validating {len(md_files)} markdown-PDF pairs...")
    for md_path in tqdm(md_files, desc="Validating PDFs"):
        # Look for PDF with same stem (filename without extension)
        pdf_path = md_path.with_suffix('.pdf')

        if pdf_path.is_symlink():
            link_path = pdf_path
            pdf_path = link_path.resolve()
            if not pdf_path.exists():
                # Dangling link: report it instead of dropping it silently
                invalid_pdfs.append((link_path, "Broken symlink"))
                continue
        elif not pdf_path.exists():
            # Markdown file without a PDF next to it is not a pair
            continue

        # Validate PDF - check it loads and has exactly one page
        try:
            num_pages = len(PdfReader(str(pdf_path)).pages)
        except Exception as e:
            # Best-effort scan: any load failure marks the PDF invalid
            invalid_pdfs.append((pdf_path, f"Failed to load: {str(e)}"))
            continue

        if num_pages != 1:
            invalid_pdfs.append((pdf_path, f"Expected 1 page, found {num_pages}"))
            continue

        self.samples.append({
            'markdown_path': md_path,
            'pdf_path': pdf_path,
        })

    print(f"Found {len(self.samples)} valid markdown-PDF pairs")

    if invalid_pdfs:
        print(f"\nWarning: {len(invalid_pdfs)} invalid PDFs found:")
        for pdf_path, reason in invalid_pdfs[:5]:  # Show first 5
            print(f"  - {pdf_path.name}: {reason}")
        if len(invalid_pdfs) > 5:
            print(f"  ... and {len(invalid_pdfs) - 5} more")
232+
168233
def __len__(self) -> int:
169234
return len(self.samples)
170235

@@ -173,40 +238,43 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
173238
Get a single sample from the dataset.
174239
175240
Returns:
176-
dict containing:
177-
- 'image': PIL Image of the rendered PDF page
241+
dict containing at minimum:
242+
- 'markdown_path': Path to the markdown file
178243
- 'pdf_path': Path to the PDF file
179-
- 'text': Text content without front matter
180-
- 'front_matter': Dict with parsed front matter
244+
245+
Additional fields will be added by pipeline steps.
181246
"""
182-
sample = self.samples[idx]
247+
# Start with basic sample info
248+
sample = self.samples[idx].copy()
183249

184-
# Read and parse markdown file
185-
markdown_content = sample['markdown_path'].read_text(encoding='utf-8')
186-
front_matter_str, text = self._extract_front_matter_and_text(markdown_content)
187-
front_matter = self._parse_front_matter_string(front_matter_str)
250+
# Apply pipeline steps
251+
for step in self.pipeline_steps:
252+
sample = step.process(sample)
188253

189-
# Render PDF to image
190-
base64_png = render_pdf_to_base64png(str(sample['pdf_path']), page_num=1, target_longest_image_dim=self.target_longest_image_dim)
191-
png_bytes = base64.b64decode(base64_png)
192-
image = Image.open(BytesIO(png_bytes))
193-
194-
# Apply transform if provided
195-
if self.image_transform:
196-
image = self.image_transform(image)
254+
return sample
255+
256+
257+
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
258+
"""Dataset that includes front matter parsing and PDF rendering by default."""
259+
260+
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
261+
"""
262+
Initialize the dataset with default pipeline steps.
197263
198-
# Parse front matter to dataclass if specified
199-
try:
200-
parsed_front_matter = self._parse_front_matter(front_matter)
201-
except Exception as e:
202-
raise ValueError(f"Error parsing front matter for {sample['markdown_path']}: {e}")
264+
Args:
265+
root_dir: Path to the root folder containing processed markdown and PDF files
266+
target_longest_image_dim: Target dimension for the longest side of the image
267+
image_transform: Optional transform to apply to the PDF images
268+
front_matter_class: Optional dataclass type to validate front matter against
269+
"""
270+
# Create default pipeline steps
271+
pipeline_steps = [
272+
FrontMatterParser(front_matter_class),
273+
PDFRenderer(target_longest_image_dim, image_transform)
274+
]
203275

204-
return {
205-
'image': image,
206-
'pdf_path': str(sample['pdf_path']),
207-
'text': text,
208-
'front_matter': parsed_front_matter
209-
}
276+
# Initialize base class with pipeline
277+
super().__init__(root_dir, pipeline_steps)
210278

211279

212280
if __name__ == "__main__":
@@ -222,11 +290,46 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
222290

223291
args = parser.parse_args()
224292

225-
# Test dataset initialization
226-
print(f"\nTesting dataset with root directory: {args.root_dir}")
227-
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, front_matter_class=StandardFrontMatter, image_transform=None)
293+
# Test base dataset without any pipeline steps
294+
print(f"\n=== Testing base dataset without pipeline steps ===")
295+
base_dataset = BaseMarkdownPDFDataset(args.root_dir)
296+
print(f"Dataset length: {len(base_dataset)}")
297+
298+
if len(base_dataset) > 0:
299+
print("\nFirst sample (no pipeline):")
300+
sample = base_dataset[0]
301+
print(f" Keys: {list(sample.keys())}")
302+
print(f" Markdown: {sample['markdown_path'].name}")
303+
print(f" PDF: {sample['pdf_path'].name}")
304+
305+
# Test with individual pipeline steps
306+
print(f"\n=== Testing with individual pipeline steps ===")
307+
pipeline_dataset = BaseMarkdownPDFDataset(
308+
args.root_dir,
309+
pipeline_steps=[
310+
FrontMatterParser(StandardFrontMatter),
311+
PDFRenderer(target_longest_image_dim=1024)
312+
]
313+
)
314+
315+
if len(pipeline_dataset) > 0:
316+
print("\nFirst sample (with pipeline):")
317+
sample = pipeline_dataset[0]
318+
print(f" Keys: {list(sample.keys())}")
319+
print(f" Front Matter: {sample['front_matter']}")
320+
print(f" Image size: {sample['image'].size}")
321+
print(f" Text preview: {sample['text'][:100]}...")
322+
323+
# Test the convenience dataset class
324+
print(f"\n=== Testing MarkdownPDFDocumentDataset (convenience class) ===")
325+
dataset = MarkdownPDFDocumentDataset(
326+
args.root_dir,
327+
target_longest_image_dim=1024,
328+
front_matter_class=StandardFrontMatter,
329+
image_transform=None
330+
)
228331

229-
print(f"\nDataset length: {len(dataset)}")
332+
print(f"Dataset length: {len(dataset)}")
230333

231334
if len(dataset) > 0:
232335
# Show first few samples
@@ -242,4 +345,4 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
242345
print(f"Image size: {first_sample['image'].size}")
243346
print(f"PDF Path: {first_sample['pdf_path']}")
244347
print(f"Front Matter: {first_sample['front_matter']}")
245-
print(f"Text: {first_sample['text']}...")
348+
print(f"Text: {first_sample['text'][:200]}...")

0 commit comments

Comments
 (0)