-# TODO Overall, this code will read in a config YAML file with OmegaConf
-# From that config, we are going to use the HuggingFace Trainer to train a model
-# TODOS:
-# DONE Build a script to convert olmocr-mix to a new dataloader format
-# DONE Write a new dataloader and collator, with tests, that brings in everything; only needs to support batch size 1 for this first version
-# Get a basic config YAML file system working
-# Get a basic Hugging Face Trainer running, supporting Qwen2.5VL for now
-# Saving and restoring training checkpoints
-# Converting training checkpoints to vLLM-compatible checkpoints
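The new script below expects a YAML config of roughly the following shape. This is a minimal sketch inferred from the fields the script reads (`model.name`, `model.processor_trust_remote_code`, and `dataset.train` / `dataset.eval` lists whose entries carry `root_dir` and `pipeline`); the model id, paths, and pipeline step name are hypothetical placeholders, not the repo's actual values.

    model:
      name: Qwen/Qwen2.5-VL-7B-Instruct      # hypothetical model id
      processor_trust_remote_code: true

    dataset:
      train:
        - root_dir: /data/olmocr-mix/train   # hypothetical path
          pipeline:
            - name: FrontMatterParser        # hypothetical step name
      eval:
        - root_dir: /data/olmocr-mix/eval    # hypothetical path
          pipeline:
            - name: FrontMatterParser        # hypothetical step name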
+"""
+Simple script to test OlmOCR dataset loading with YAML configuration.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from pprint import pprint
+
+from transformers import AutoProcessor
+
+from olmocr.train.config import Config
+from olmocr.train.dataloader import BaseMarkdownPDFDataset
+
+# Configure logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%m/%d/%Y %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+def print_sample(sample, dataset_name):
+    """Pretty print a dataset sample."""
+    print(f"\n{'=' * 80}")
+    print(f"Sample from: {dataset_name}")
+    print(f"{'=' * 80}")
+
+    # Print keys
+    print(f"\nAvailable keys: {list(sample.keys())}")
+
+    # Print path information
+    if 'markdown_path' in sample:
+        print(f"\nMarkdown path: {sample['markdown_path']}")
+    if 'pdf_path' in sample:
+        print(f"PDF path: {sample['pdf_path']}")
+
+    # Print page data
+    if 'page_data' in sample:
+        print("\nPage data:")
+        print(f"  Primary language: {sample['page_data'].primary_language}")
+        print(f"  Is rotation valid: {sample['page_data'].is_rotation_valid}")
+        print(f"  Rotation correction: {sample['page_data'].rotation_correction}")
+        print(f"  Is table: {sample['page_data'].is_table}")
+        print(f"  Is diagram: {sample['page_data'].is_diagram}")
+        print(f"  Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else "  Natural text: None")
+
+    # Print image info
+    if 'image' in sample:
+        print(f"\nImage size: {sample['image'].size}")
+
+    # Print anchor text preview
+    if 'anchor_text' in sample:
+        print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
+
+    # Print instruction prompt preview
+    if 'instruction_prompt' in sample:
+        print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
+
+    # Print response preview
+    if 'response' in sample:
+        print(f"\nResponse preview: {sample['response'][:200]}...")
+
+    # Print tokenization info
+    if 'input_ids' in sample:
+        print("\nTokenization info:")
+        print(f"  Input IDs shape: {sample['input_ids'].shape}")
+        print(f"  Attention mask shape: {sample['attention_mask'].shape}")
+        print(f"  Labels shape: {sample['labels'].shape}")
+        if 'pixel_values' in sample:
+            print(f"  Pixel values shape: {sample['pixel_values'].shape}")
+        if 'image_grid_thw' in sample:
+            print(f"  Image grid THW: {sample['image_grid_thw']}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
+    parser.add_argument(
+        "--config",
+        type=str,
+        default="olmocr/train/configs/example_config.yaml",
+        help="Path to YAML configuration file",
+    )
+
+    args = parser.parse_args()
+
+    # Load configuration
+    logger.info(f"Loading configuration from: {args.config}")
+    config = Config.from_yaml(args.config)
+
+    # Validate configuration
+    try:
+        config.validate()
+    except ValueError as e:
+        logger.error(f"Configuration validation failed: {e}")
+        return
+
+    # Load processor for tokenization
+    logger.info(f"Loading processor: {config.model.name}")
+    processor = AutoProcessor.from_pretrained(
+        config.model.name,
+        trust_remote_code=config.model.processor_trust_remote_code,
+    )
+
+    # Process training datasets
+    print(f"\n{'=' * 80}")
+    print("TRAINING DATASETS")
+    print(f"{'=' * 80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.train):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating training dataset {i + 1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Training Dataset {i + 1}: {Path(root_dir).name}")
+
+    # Process evaluation datasets
+    print(f"\n\n{'=' * 80}")
+    print("EVALUATION DATASETS")
+    print(f"{'=' * 80}")
+
+    for i, dataset_cfg in enumerate(config.dataset.eval):
+        root_dir = dataset_cfg['root_dir']
+        pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
+
+        logger.info(f"\nCreating evaluation dataset {i + 1} from: {root_dir}")
+        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
+        logger.info(f"Found {len(dataset)} samples")
+
+        if len(dataset) > 0:
+            # Get first sample
+            sample = dataset[0]
+            print_sample(sample, f"Evaluation Dataset {i + 1}: {Path(root_dir).name}")
+
+    print(f"\n{'=' * 80}")
+    print("Dataset loading test completed!")
+    print(f"{'=' * 80}")
+
+
+if __name__ == "__main__":
+    main()
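Assuming the default config exists at `olmocr/train/configs/example_config.yaml`, the script can be run directly from the repo root; the filename below is a placeholder, since the diff does not show the file's path:

    python <this_script>.py --config olmocr/train/configs/example_config.yaml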