
Commit 91e7b5c

Claude generated train script

1 parent 0ebc35c commit 91e7b5c

1 file changed: +150 -74 lines

olmocr/train/train.py

Lines changed: 150 additions & 74 deletions
@@ -4,10 +4,16 @@

 import argparse
 import logging
-from pathlib import Path
-from pprint import pprint

-from transformers import AutoProcessor
+from transformers import (
+    AutoProcessor,
+    Qwen2VLForConditionalGeneration,
+    Trainer,
+    TrainingArguments,
+    EarlyStoppingCallback
+)
+import torch
+from torch.utils.data import ConcatDataset

 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
@@ -21,61 +27,42 @@
 logger = logging.getLogger(__name__)


-def print_sample(sample, dataset_name):
-    """Pretty print a dataset sample."""
-    print(f"\n{'='*80}")
-    print(f"Sample from: {dataset_name}")
-    print(f"{'='*80}")
-
-    # Print keys
-    print(f"\nAvailable keys: {list(sample.keys())}")
-
-    # Print path information
-    if 'markdown_path' in sample:
-        print(f"\nMarkdown path: {sample['markdown_path']}")
-    if 'pdf_path' in sample:
-        print(f"PDF path: {sample['pdf_path']}")
-
-    # Print page data
-    if 'page_data' in sample:
-        print(f"\nPage data:")
-        print(f"  Primary language: {sample['page_data'].primary_language}")
-        print(f"  Is rotation valid: {sample['page_data'].is_rotation_valid}")
-        print(f"  Rotation correction: {sample['page_data'].rotation_correction}")
-        print(f"  Is table: {sample['page_data'].is_table}")
-        print(f"  Is diagram: {sample['page_data'].is_diagram}")
-        print(f"  Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else "  Natural text: None")
-
-    # Print image info
-    if 'image' in sample:
-        print(f"\nImage shape: {sample['image'].size}")
-
-    # Print anchor text preview
-    if 'anchor_text' in sample:
-        print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...")
-
-    # Print instruction prompt preview
-    if 'instruction_prompt' in sample:
-        print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...")
-
-    # Print response preview
-    if 'response' in sample:
-        print(f"\nResponse preview: {sample['response'][:200]}...")
-
-    # Print tokenization info
-    if 'input_ids' in sample:
-        print(f"\nTokenization info:")
-        print(f"  Input IDs shape: {sample['input_ids'].shape}")
-        print(f"  Attention mask shape: {sample['attention_mask'].shape}")
-        print(f"  Labels shape: {sample['labels'].shape}")
-        if 'pixel_values' in sample:
-            print(f"  Pixel values shape: {sample['pixel_values'].shape}")
-        if 'image_grid_thw' in sample:
-            print(f"  Image grid THW: {sample['image_grid_thw']}")
+def create_data_collator():
+    """Create a data collator for vision-language models."""
+    def collate_fn(examples):
+        # Filter out None values and extract the fields we need
+        batch = {
+            'input_ids': [],
+            'attention_mask': [],
+            'labels': [],
+            'pixel_values': [],
+            'image_grid_thw': []
+        }
+
+        for example in examples:
+            if example is not None:
+                batch['input_ids'].append(example['input_ids'])
+                batch['attention_mask'].append(example['attention_mask'])
+                batch['labels'].append(example['labels'])
+                batch['pixel_values'].append(example['pixel_values'])
+                batch['image_grid_thw'].append(example['image_grid_thw'])
+
+        # Convert lists to tensors with proper padding
+        # Note: For Qwen2-VL, we typically handle variable length sequences
+        # The model's processor should handle the padding internally
+        return {
+            'input_ids': torch.stack(batch['input_ids']),
+            'attention_mask': torch.stack(batch['attention_mask']),
+            'labels': torch.stack(batch['labels']),
+            'pixel_values': batch['pixel_values'],  # Keep as list for now
+            'image_grid_thw': torch.stack(batch['image_grid_thw'])
+        }
+
+    return collate_fn


 def main():
-    parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading")
+    parser = argparse.ArgumentParser(description="Train OlmOCR model")
     parser.add_argument(
         "--config",
         type=str,
@@ -103,45 +90,134 @@ def main():
         trust_remote_code=config.model.processor_trust_remote_code
     )

-    # Process training datasets
-    print(f"\n{'='*80}")
-    print("TRAINING DATASETS")
-    print(f"{'='*80}")
+    # Load model
+    logger.info(f"Loading model: {config.model.name}")
+    model = Qwen2VLForConditionalGeneration.from_pretrained(
+        config.model.name,
+        torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
+        device_map=config.model.device_map,
+        trust_remote_code=config.model.trust_remote_code,
+        attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
+    )
+
+    # Enable gradient checkpointing if configured
+    if config.training.gradient_checkpointing:
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)

+    # Create training datasets
+    logger.info("Creating training datasets...")
+    train_datasets = []
     for i, dataset_cfg in enumerate(config.dataset.train):
         root_dir = dataset_cfg['root_dir']
         pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)

-        logger.info(f"\nCreating training dataset {i+1} from: {root_dir}")
+        logger.info(f"Creating training dataset {i+1} from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")

         if len(dataset) > 0:
-            # Get first sample
-            sample = dataset[0]
-            print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}")
+            train_datasets.append(dataset)

-    # Process evaluation datasets
-    print(f"\n\n{'='*80}")
-    print("EVALUATION DATASETS")
-    print(f"{'='*80}")
+    # Combine all training datasets
+    train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
+    logger.info(f"Total training samples: {len(train_dataset)}")

+    # Create evaluation datasets
+    logger.info("Creating evaluation datasets...")
+    eval_datasets = []
     for i, dataset_cfg in enumerate(config.dataset.eval):
         root_dir = dataset_cfg['root_dir']
         pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)

-        logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}")
+        logger.info(f"Creating evaluation dataset {i+1} from: {root_dir}")
         dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
         logger.info(f"Found {len(dataset)} samples")

         if len(dataset) > 0:
-            # Get first sample
-            sample = dataset[0]
-            print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}")
+            eval_datasets.append(dataset)
+
+    # Combine all evaluation datasets
+    eval_dataset = ConcatDataset(eval_datasets) if len(eval_datasets) > 1 else eval_datasets[0]
+    logger.info(f"Total evaluation samples: {len(eval_dataset)}")
+
+    # Set up training arguments
+    training_args = TrainingArguments(
+        output_dir=config.training.output_dir,
+        num_train_epochs=config.training.num_train_epochs,
+        per_device_train_batch_size=config.training.per_device_train_batch_size,
+        per_device_eval_batch_size=config.training.per_device_eval_batch_size,
+        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
+        learning_rate=config.training.learning_rate,
+        lr_scheduler_type=config.training.lr_scheduler_type,
+        warmup_ratio=config.training.warmup_ratio,
+        warmup_steps=config.training.warmup_steps,
+        optim=config.training.optim,
+        adam_beta1=config.training.adam_beta1,
+        adam_beta2=config.training.adam_beta2,
+        adam_epsilon=config.training.adam_epsilon,
+        weight_decay=config.training.weight_decay,
+        max_grad_norm=config.training.max_grad_norm,
+        fp16=config.training.fp16,
+        bf16=config.training.bf16,
+        tf32=config.training.tf32,
+        eval_strategy=config.training.evaluation_strategy,
+        eval_steps=config.training.eval_steps,
+        save_strategy=config.training.save_strategy,
+        save_steps=config.training.save_steps,
+        save_total_limit=config.training.save_total_limit,
+        load_best_model_at_end=config.training.load_best_model_at_end,
+        metric_for_best_model=config.training.metric_for_best_model,
+        greater_is_better=config.training.greater_is_better,
+        logging_dir=config.training.logging_dir,
+        logging_strategy=config.training.logging_strategy,
+        logging_steps=config.training.logging_steps,
+        logging_first_step=config.training.logging_first_step,
+        report_to=config.training.report_to,
+        seed=config.training.seed,
+        data_seed=config.training.data_seed,
+        push_to_hub=config.training.push_to_hub,
+        hub_model_id=config.training.hub_model_id,
+        hub_strategy=config.training.hub_strategy,
+        resume_from_checkpoint=config.training.resume_from_checkpoint,
+        deepspeed=config.training.deepspeed,
+        dataloader_drop_last=config.training.dataloader_drop_last,
+        dataloader_num_workers=config.training.dataloader_num_workers,
+        remove_unused_columns=config.training.remove_unused_columns,
+        run_name=config.run_name,
+    )
+
+    # Set up callbacks
+    callbacks = []
+    if config.training.use_early_stopping:
+        callbacks.append(
+            EarlyStoppingCallback(
+                early_stopping_patience=config.training.early_stopping_patience,
+                early_stopping_threshold=config.training.early_stopping_threshold
+            )
+        )
+
+    # Initialize trainer
+    logger.info("Initializing trainer...")
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        data_collator=create_data_collator(),
+        callbacks=callbacks,
+    )
+
+    # Start training
+    logger.info("Starting training...")
+    train_result = trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint)
+
+    # Save the final model
+    logger.info("Saving final model...")
+    trainer.save_model()
+    trainer.save_state()

-    print(f"\n{'='*80}")
-    print("Dataset loading test completed!")
-    print(f"{'='*80}")
+    # Log metrics
+    logger.info(f"Training completed! Metrics: {train_result.metrics}")


 if __name__ == "__main__":
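
A note on the added collate_fn: it stacks input_ids, attention_mask, and labels directly, which only works when every example in a batch already has the same sequence length; the in-code comment defers padding to the processor. For comparison, a minimal sketch of a padding-aware collator is shown below. It is not part of this commit, and the pad values plus the pixel_values/image_grid_thw handling are assumptions about Qwen2-VL-style inputs, not taken from the repository.

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_with_padding(examples, pad_token_id=0, label_pad_id=-100):
    """Illustrative collator that pads variable-length sequences before batching."""
    examples = [ex for ex in examples if ex is not None]
    return {
        # Right-pad each 1-D tensor to the longest sequence in the batch
        "input_ids": pad_sequence([ex["input_ids"] for ex in examples],
                                  batch_first=True, padding_value=pad_token_id),
        "attention_mask": pad_sequence([ex["attention_mask"] for ex in examples],
                                       batch_first=True, padding_value=0),
        # -100 is the label index ignored by the cross-entropy loss in transformers
        "labels": pad_sequence([ex["labels"] for ex in examples],
                               batch_first=True, padding_value=label_pad_id),
        # Assumed convention: flattened image patches concatenated along dim 0
        "pixel_values": torch.cat([ex["pixel_values"] for ex in examples], dim=0),
        "image_grid_thw": torch.cat([ex["image_grid_thw"] for ex in examples], dim=0),
    }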

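On the EarlyStoppingCallback wiring: in the Hugging Face Trainer, early stopping only takes effect when metric_for_best_model is set, evaluation runs periodically, and load_best_model_at_end is enabled, all of which this script reads from the config. A sketch of the TrainingArguments fields the callback depends on (values here are illustrative, not taken from the repo config):

from transformers import TrainingArguments

# Minimal fields EarlyStoppingCallback relies on (illustrative values)
training_args = TrainingArguments(
    output_dir="outputs",
    eval_strategy="steps",            # evaluate periodically so the metric exists
    eval_steps=500,
    save_strategy="steps",            # should line up with eval for best-model tracking
    save_steps=500,
    load_best_model_at_end=True,      # required by EarlyStoppingCallback
    metric_for_best_model="eval_loss",
    greater_is_better=False,          # lower eval_loss is better
)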
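Finally, invocation: main() accepts only a --config argument, and everything else comes from the olmocr.train.config.Config object. A run would therefore look roughly like the line below; the config path and the YAML extension are placeholders, since the config format is defined in olmocr/train/config.py, which is not part of this diff.

python -m olmocr.train.train --config path/to/train_config.yaml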