Skip to content

Commit d0df380

Browse files
committed
Cleaning data loader
1 parent 5bbc1ff commit d0df380

File tree

1 file changed

+19
-37
lines changed

1 file changed

+19
-37
lines changed

olmocr/train/dataloader.py

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@ class StandardFrontMatter:
1919
is_table: bool
2020
is_diagram: bool
2121

22+
def __post_init__(self):
    """Validate the front matter fields once the instance has been populated.

    The allowed-value check on ``rotation_correction`` runs first, followed by
    per-field type checks, mirroring the order callers already observe.

    Raises:
        ValueError: if ``rotation_correction`` is not one of 0, 90, 180, 270.
        TypeError: if any field does not have its declared type.
    """
    # Validate rotation_correction is one of the allowed values
    if self.rotation_correction not in {0, 90, 180, 270}:
        raise ValueError("rotation_correction must be one of [0, 90, 180, 270].")

    # Type checks
    if not isinstance(self.primary_language, (str, type(None))):
        raise TypeError("primary_language must be of type Optional[str].")
    if not isinstance(self.is_rotation_valid, bool):
        raise TypeError("is_rotation_valid must be of type bool.")
    # bool is a subclass of int and False == 0, so without the explicit bool
    # exclusion a value of False would pass both the membership test above and
    # a plain isinstance(..., int) check, and be accepted as a rotation of 0.
    rotation_is_bool = isinstance(self.rotation_correction, bool)
    if rotation_is_bool or not isinstance(self.rotation_correction, int):
        raise TypeError("rotation_correction must be of type int.")
    if not isinstance(self.is_table, bool):
        raise TypeError("is_table must be of type bool.")
    if not isinstance(self.is_diagram, bool):
        raise TypeError("is_diagram must be of type bool.")
38+
2239

2340
class MarkdownPDFDocumentDataset(Dataset):
2441
def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, image_transform=None, front_matter_class=None):
@@ -207,7 +224,7 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
207224

208225
# Test dataset initialization
209226
print(f"\nTesting dataset with root directory: {args.root_dir}")
210-
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=None)
227+
dataset = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, front_matter_class=StandardFrontMatter, image_transform=None)
211228

212229
print(f"\nDataset length: {len(dataset)}")
213230

@@ -225,39 +242,4 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
225242
print(f"Image size: {first_sample['image'].size}")
226243
print(f"PDF Path: {first_sample['pdf_path']}")
227244
print(f"Front Matter: {first_sample['front_matter']}")
228-
print(f"Text preview (first 200 chars): {first_sample['text'][:200]}...")
229-
230-
# Test with transforms
231-
print("\nTesting with torchvision transforms:")
232-
import torchvision.transforms as transforms
233-
234-
transform = transforms.Compose([
235-
transforms.Resize((1024, 1024)),
236-
transforms.ToTensor(),
237-
])
238-
239-
dataset_with_transform = MarkdownPDFDocumentDataset(args.root_dir, target_longest_image_dim=1024, image_transform=transform)
240-
transformed_sample = dataset_with_transform[0]
241-
print(f"Transformed image type: {type(transformed_sample['image'])}")
242-
print(f"Transformed image shape: {transformed_sample['image'].shape}")
243-
244-
# Test with front matter validation
245-
print("\n\nTesting with front matter validation:")
246-
dataset_with_validation = MarkdownPDFDocumentDataset(
247-
args.root_dir,
248-
target_longest_image_dim=1024,
249-
front_matter_class=StandardFrontMatter
250-
)
251-
252-
validated_sample = dataset_with_validation[0]
253-
print(f"Front matter type: {type(validated_sample['front_matter'])}")
254-
print(f"Front matter: {validated_sample['front_matter']}")
255-
256-
# Access fields directly
257-
fm = validated_sample['front_matter']
258-
print(f"\nAccessing fields:")
259-
print(f" primary_language: {fm.primary_language}")
260-
print(f" is_rotation_valid: {fm.is_rotation_valid}")
261-
print(f" rotation_correction: {fm.rotation_correction}")
262-
print(f" is_table: {fm.is_table}")
263-
print(f" is_diagram: {fm.is_diagram}")
245+
print(f"Text: {first_sample['text']}...")

0 commit comments

Comments
 (0)