Commit 0ebc35c

Basic train config loader for datasets
1 parent b93c262 commit 0ebc35c

4 files changed: +193 −48 lines

olmocr/train/config.py

Lines changed: 38 additions & 30 deletions
```diff
@@ -64,9 +64,20 @@ class TokenizerStepConfig(PipelineStepConfig):
 
 
 @dataclass
+class DatasetItemConfig:
+    """Configuration for a single dataset item."""
+    root_dir: str
+    pipeline: List[Dict[str, Any]] = field(default_factory=list)
+
+    # Optional sampling
+    max_samples: Optional[int] = None
+
+
+@dataclass
 class DatasetConfig:
     """Configuration for dataset and data loading."""
-    root_dir: str
+    train: List[Dict[str, Any]] = field(default_factory=list)
+    eval: List[Dict[str, Any]] = field(default_factory=list)
 
     # DataLoader configuration
     batch_size: int = 1
```
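The dataset section is now a pair of per-split lists rather than a single `root_dir`. A minimal sketch of populating the new layout in code, assuming only the field names shown in this hunk — entries are plain dicts, matching the `List[Dict[str, Any]]` annotations, and the paths are placeholders:

```python
# Sketch of the new nested dataset layout; paths are hypothetical.
from olmocr.train.config import DatasetConfig

dataset = DatasetConfig(
    train=[
        {
            "root_dir": "/data/docs_train/",
            "pipeline": [{"name": "FrontMatterParser", "front_matter_class": "PageResponse"}],
        },
    ],
    eval=[
        {"root_dir": "/data/docs_eval/", "pipeline": []},
    ],
)
# Each split is a list of per-dataset dicts, each carrying its own pipeline.
print(len(dataset.train), len(dataset.eval))
```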
```diff
@@ -76,32 +87,12 @@ class DatasetConfig:
     pin_memory: bool = True
     prefetch_factor: int = 2
 
-    # Pipeline steps configuration
-    pipeline_steps: List[Dict[str, Any]] = field(default_factory=lambda: [
-        {"name": "FrontMatterParser", "use_page_response_class": True},
-        {"name": "PDFRenderer", "target_longest_image_dim": 1024},
-        {"name": "StaticLengthDocumentAnchoring", "target_anchor_text_len": 6000},
-        {"name": "FinetuningPrompt"},
-        {"name": "FrontMatterOutputFormat"},
-        {"name": "InstructUserMessages"},
-        {"name": "Tokenizer", "masking_index": -100, "end_of_message_token": "<|im_end|>"}
-    ])
-
-    # Optional dataset sampling
-    max_samples: Optional[int] = None
-    validation_split: float = 0.1
+    # Global seed
     seed: int = 42
 
-    # Train/validation split
-    train_indices: Optional[List[int]] = None
-    val_indices: Optional[List[int]] = None
-
     # Caching
     cache_dir: Optional[str] = None
     use_cache: bool = False
-
-    # Data augmentation (future extension)
-    augmentation: Dict[str, Any] = field(default_factory=dict)
 
 
 @dataclass
@@ -280,9 +271,14 @@ def to_yaml(self, yaml_path: Union[str, Path]) -> None:
 
     def validate(self) -> None:
         """Validate configuration values."""
-        # Dataset validation
-        if not os.path.exists(self.dataset.root_dir):
-            raise ValueError(f"Dataset root directory does not exist: {self.dataset.root_dir}")
+        # Dataset validation - check all train and eval datasets
+        for split_name, datasets in [("train", self.dataset.train), ("eval", self.dataset.eval)]:
+            for i, dataset_cfg in enumerate(datasets):
+                root_dir = dataset_cfg.get('root_dir')
+                if not root_dir:
+                    raise ValueError(f"Missing root_dir for {split_name} dataset {i}")
+                if not os.path.exists(root_dir):
+                    raise ValueError(f"Dataset root directory does not exist: {root_dir}")
 
         # Training validation
         if self.training.warmup_steps is not None and self.training.warmup_ratio > 0:
```
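With per-dataset validation, a malformed entry fails fast before any training setup runs. A hypothetical illustration of the new error path, assuming `ModelConfig` and `TrainingConfig` construct cleanly with defaults (as `create_default_config` below suggests):

```python
# Hypothetical snippet: an eval entry with no root_dir should now raise
# with a message naming the split and index.
from olmocr.train.config import Config, DatasetConfig, ModelConfig, TrainingConfig

config = Config(
    model=ModelConfig(),
    dataset=DatasetConfig(eval=[{"pipeline": []}]),  # missing root_dir
    training=TrainingConfig(),
)
try:
    config.validate()
except ValueError as e:
    print(e)  # -> "Missing root_dir for eval dataset 0"
```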
```diff
@@ -303,8 +299,16 @@ def validate(self) -> None:
             self.training.logging_dir = os.path.join(self.training.output_dir, "logs")
         Path(self.training.logging_dir).mkdir(parents=True, exist_ok=True)
 
-    def get_pipeline_steps(self, processor=None):
-        """Create actual pipeline step instances from configuration."""
+    def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None):
+        """Create actual pipeline step instances from pipeline configuration.
+
+        Args:
+            pipeline_config: List of pipeline step configurations
+            processor: The model processor (required for Tokenizer step)
+
+        Returns:
+            List of initialized pipeline step instances
+        """
         from olmocr.train.dataloader import (
             FrontMatterParser,
             PDFRenderer,
@@ -317,14 +321,18 @@ def get_pipeline_steps(self, processor=None):
         from olmocr.prompts.prompts import PageResponse
 
         steps = []
-        for step_config in self.dataset.pipeline_steps:
+        for step_config in pipeline_config:
             if not step_config.get('enabled', True):
                 continue
 
             step_name = step_config['name']
 
             if step_name == 'FrontMatterParser':
-                front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
+                # Handle both old and new config format
+                if 'front_matter_class' in step_config:
+                    front_matter_class = PageResponse if step_config['front_matter_class'] == 'PageResponse' else None
+                else:
+                    front_matter_class = PageResponse if step_config.get('use_page_response_class', True) else None
                 steps.append(FrontMatterParser(front_matter_class=front_matter_class))
 
             elif step_name == 'PDFRenderer':
```
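The `FrontMatterParser` branch keeps backward compatibility: the old boolean flag and the new class-name string resolve to the same step. For example:

```python
# Both step formats resolve to the same parser, per the branch above:
old_style = {"name": "FrontMatterParser", "use_page_response_class": True}
new_style = {"name": "FrontMatterParser", "front_matter_class": "PageResponse"}
# In either case get_pipeline_steps builds
# FrontMatterParser(front_matter_class=PageResponse); any other value for
# 'front_matter_class' (or use_page_response_class=False) yields None.
```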
```diff
@@ -365,7 +373,7 @@ def create_default_config() -> Config:
     """Create a default configuration."""
     return Config(
         model=ModelConfig(),
-        dataset=DatasetConfig(root_dir="/path/to/dataset"),
+        dataset=DatasetConfig(),
         training=TrainingConfig()
     )
 
```
olmocr/train/configs/example_config.yaml

Lines changed: 7 additions & 8 deletions
```diff
@@ -27,8 +27,8 @@ model:
 dataset:
 
   train:
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
-      pipeline *basic_pipeline:
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
+      pipeline: &basic_pipeline
         - name: FrontMatterParser
           front_matter_class: PageResponse
         - name: PDFRenderer
@@ -41,15 +41,14 @@ dataset:
         - name: Tokenizer
           masking_index: -100
           end_of_message_token: "<|im_end|>"
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_train_s2pdf/
-      pipeline: *reuse basic_pipeline above*
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline
 
   eval:
     - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
-      pipeline: *reuse basic_pipeline above*
-    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_00_books_eval_s2pdf/
-      pipeline: *reuse basic_pipeline above*
-
+      pipeline: *basic_pipeline
+    - root_dir: /home/ubuntu/olmOCR-mix-0225/processed_01_books_eval_iabooks/
+      pipeline: *basic_pipeline
 
 
 # Training configuration
```
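The placeholder prose (`*reuse basic_pipeline above*`) is replaced with real YAML anchor/alias syntax: `&basic_pipeline` names the first pipeline list and each `*basic_pipeline` alias reuses it. A quick standalone check of that semantics, with hypothetical paths:

```python
# The alias expands to the same pipeline list for every dataset entry.
import yaml

doc = yaml.safe_load("""
train:
  - root_dir: /data/a
    pipeline: &basic_pipeline
      - name: FrontMatterParser
  - root_dir: /data/b
    pipeline: *basic_pipeline
""")
assert doc["train"][0]["pipeline"] == doc["train"][1]["pipeline"]
```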

olmocr/train/train.py

Lines changed: 148 additions & 9 deletions
```diff
@@ -1,9 +1,148 @@
-# TODO Overall, this code will read in a config yaml file with omega conf
-# From that config, we are going to use HuggingFace Trainer to train a model
-# TODOS:
-# DONE Build a script to convert olmocr-mix to a new dataloader format
-# DONE Write a new dataloader and collator, with tests that brings in everything, only needs to support batch size 1 for this first version
-# Get a basic config yaml file system working
-# Get a basic hugging face trainer running, supporting Qwen2.5VL for now
-# Saving and restoring training checkpoints
-# Converting training checkpoints to vllm compatible checkpoinst
+"""
+Simple script to test OlmOCR dataset loading with YAML configuration.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from pprint import pprint
+
+from transformers import AutoProcessor
+
+from olmocr.train.config import Config
+from olmocr.train.dataloader import BaseMarkdownPDFDataset
+
+# Configure logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+def print_sample(sample, dataset_name):
+    """Pretty print a dataset sample."""
+    print(f"\n{'='*80}")
+    print(f"Sample from: {dataset_name}")
+    print(f"{'='*80}")
+
+    # Print keys
+    print(f"\nAvailable keys: {list(sample.keys())}")
+
+    # Print path information
+    if 'markdown_path' in sample:
+        print(f"\nMarkdown path: {sample['markdown_path']}")
+    if 'pdf_path' in sample:
+        print(f"PDF path: {sample['pdf_path']}")
+
+    # Print page data
+    if 'page_data' in sample:
+        print(f"\nPage data:")
+        print(f"  Primary language: {sample['page_data'].primary_language}")
+        print(f"  Is rotation valid: {sample['page_data'].is_rotation_valid}")
+        print(f"  Rotation correction: {sample['page_data'].rotation_correction}")
+        print(f"  Is table: {sample['page_data'].is_table}")
+        print(f"  Is diagram: {sample['page_data'].is_diagram}")
+        print(f"  Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else "  Natural text: None")
+
+    # Print image info
+    if 'image' in sample:
+        print(f"\nImage shape: {sample['image'].size}")
+
+    # Print anchor text preview
+    if 'anchor_text' in sample:
+        print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
+
+    # Print instruction prompt preview
+    if 'instruction_prompt' in sample:
+        print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
+
+    # Print response preview
+    if 'response' in sample:
+        print(f"\nResponse preview: {sample['response'][:200]}...")
+
+    # Print tokenization info
+    if 'input_ids' in sample:
+        print(f"\nTokenization info:")
+        print(f"  Input IDs shape: {sample['input_ids'].shape}")
+        print(f"  Attention mask shape: {sample['attention_mask'].shape}")
+        print(f"  Labels shape: {sample['labels'].shape}")
+        if 'pixel_values' in sample:
+            print(f"  Pixel values shape: {sample['pixel_values'].shape}")
+        if 'image_grid_thw' in sample:
+            print(f"  Image grid THW: {sample['image_grid_thw']}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="olmocr/train/configs/example_config.yaml",
+        help="Path to YAML configuration file"
+    )
+
+    args = parser.parse_args()
+
+    # Load configuration
+    logger.info(f"Loading configuration from: {args.config}")
+    config = Config.from_yaml(args.config)
+
+    # Validate configuration
+    try:
+        config.validate()
+    except ValueError as e:
+        logger.error(f"Configuration validation failed: {e}")
+        return
+
+    # Load processor for tokenization
+    logger.info(f"Loading processor: {config.model.name}")
+    processor = AutoProcessor.from_pretrained(
+        config.model.name,
+        trust_remote_code=config.model.processor_trust_remote_code
+    )
+
+    # Process training datasets
+    print(f"\n{'='*80}")
+    print("TRAINING DATASETS")
+    print(f"{'='*80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.train):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating training dataset {i+1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}")
+
+    # Process evaluation datasets
+    print(f"\n\n{'='*80}")
+    print("EVALUATION DATASETS")
+    print(f"{'='*80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.eval):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}")
+
+    print(f"\n{'='*80}")
+    print("Dataset loading test completed!")
+    print(f"{'='*80}")
+
+
+if __name__ == "__main__":
+    main()
```
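The rewritten `train.py` is a smoke test for the config-driven loaders rather than a trainer. As a hedged sketch (not part of this commit), the same pieces can feed a torch `DataLoader` at batch size 1, the only batch size the old TODO list targets; the identity `collate_fn` here is a placeholder, not a collator from the repo:

```python
# Sketch: one config-driven dataset through a DataLoader at batch size 1.
from torch.utils.data import DataLoader
from transformers import AutoProcessor

from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset

config = Config.from_yaml("olmocr/train/configs/example_config.yaml")
processor = AutoProcessor.from_pretrained(
    config.model.name,
    trust_remote_code=config.model.processor_trust_remote_code,
)

cfg = config.dataset.train[0]
dataset = BaseMarkdownPDFDataset(
    cfg["root_dir"], config.get_pipeline_steps(cfg["pipeline"], processor)
)

# batch_size=1 with an identity collate keeps each sample dict intact.
loader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch[0])
sample = next(iter(loader))
print(sample["input_ids"].shape)
```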

pyproject.toml

Lines changed: 0 additions & 1 deletion
```diff
@@ -105,7 +105,6 @@ train = [
     "s3fs",
     "necessary",
     "einops",
-    "transformers>=4.45.1"
 ]
 
 elo = [
```
