|
4 | 4 |
|
5 | 5 | import argparse
|
6 | 6 | import logging
|
7 |
| -from pathlib import Path |
8 |
| -from pprint import pprint |
9 | 7 |
|
10 |
| -from transformers import AutoProcessor |
| 8 | +from transformers import ( |
| 9 | + AutoProcessor, |
| 10 | + Qwen2VLForConditionalGeneration, |
| 11 | + Trainer, |
| 12 | + TrainingArguments, |
| 13 | + EarlyStoppingCallback |
| 14 | +) |
| 15 | +import torch |
| 16 | +from torch.utils.data import ConcatDataset |
11 | 17 |
|
12 | 18 | from olmocr.train.config import Config
|
13 | 19 | from olmocr.train.dataloader import BaseMarkdownPDFDataset
|
|
21 | 27 | logger = logging.getLogger(__name__)
|
22 | 28 |
|
23 | 29 |
|
24 |
| -def print_sample(sample, dataset_name): |
25 |
| - """Pretty print a dataset sample.""" |
26 |
| - print(f"\n{'='*80}") |
27 |
| - print(f"Sample from: {dataset_name}") |
28 |
| - print(f"{'='*80}") |
29 |
| - |
30 |
| - # Print keys |
31 |
| - print(f"\nAvailable keys: {list(sample.keys())}") |
32 |
| - |
33 |
| - # Print path information |
34 |
| - if 'markdown_path' in sample: |
35 |
| - print(f"\nMarkdown path: {sample['markdown_path']}") |
36 |
| - if 'pdf_path' in sample: |
37 |
| - print(f"PDF path: {sample['pdf_path']}") |
38 |
| - |
39 |
| - # Print page data |
40 |
| - if 'page_data' in sample: |
41 |
| - print(f"\nPage data:") |
42 |
| - print(f" Primary language: {sample['page_data'].primary_language}") |
43 |
| - print(f" Is rotation valid: {sample['page_data'].is_rotation_valid}") |
44 |
| - print(f" Rotation correction: {sample['page_data'].rotation_correction}") |
45 |
| - print(f" Is table: {sample['page_data'].is_table}") |
46 |
| - print(f" Is diagram: {sample['page_data'].is_diagram}") |
47 |
| - print(f" Natural text preview: {sample['page_data'].natural_text[:200]}..." if sample['page_data'].natural_text else " Natural text: None") |
48 |
| - |
49 |
| - # Print image info |
50 |
| - if 'image' in sample: |
51 |
| - print(f"\nImage shape: {sample['image'].size}") |
52 |
| - |
53 |
| - # Print anchor text preview |
54 |
| - if 'anchor_text' in sample: |
55 |
| - print(f"\nAnchor text preview: {sample['anchor_text'][:200]}...") |
56 |
| - |
57 |
| - # Print instruction prompt preview |
58 |
| - if 'instruction_prompt' in sample: |
59 |
| - print(f"\nInstruction prompt preview: {sample['instruction_prompt'][:200]}...") |
60 |
| - |
61 |
| - # Print response preview |
62 |
| - if 'response' in sample: |
63 |
| - print(f"\nResponse preview: {sample['response'][:200]}...") |
64 |
| - |
65 |
| - # Print tokenization info |
66 |
| - if 'input_ids' in sample: |
67 |
| - print(f"\nTokenization info:") |
68 |
| - print(f" Input IDs shape: {sample['input_ids'].shape}") |
69 |
| - print(f" Attention mask shape: {sample['attention_mask'].shape}") |
70 |
| - print(f" Labels shape: {sample['labels'].shape}") |
71 |
| - if 'pixel_values' in sample: |
72 |
| - print(f" Pixel values shape: {sample['pixel_values'].shape}") |
73 |
| - if 'image_grid_thw' in sample: |
74 |
| - print(f" Image grid THW: {sample['image_grid_thw']}") |
| 30 | +def create_data_collator(): |
| 31 | + """Create a data collator for vision-language models.""" |
| 32 | + def collate_fn(examples): |
| 33 | + # Filter out None values and extract the fields we need |
| 34 | + batch = { |
| 35 | + 'input_ids': [], |
| 36 | + 'attention_mask': [], |
| 37 | + 'labels': [], |
| 38 | + 'pixel_values': [], |
| 39 | + 'image_grid_thw': [] |
| 40 | + } |
| 41 | + |
| 42 | + for example in examples: |
| 43 | + if example is not None: |
| 44 | + batch['input_ids'].append(example['input_ids']) |
| 45 | + batch['attention_mask'].append(example['attention_mask']) |
| 46 | + batch['labels'].append(example['labels']) |
| 47 | + batch['pixel_values'].append(example['pixel_values']) |
| 48 | + batch['image_grid_thw'].append(example['image_grid_thw']) |
| 49 | + |
| 50 | + # Convert lists to tensors with proper padding |
| 51 | + # Note: For Qwen2-VL, we typically handle variable length sequences |
| 52 | + # The model's processor should handle the padding internally |
| 53 | + return { |
| 54 | + 'input_ids': torch.stack(batch['input_ids']), |
| 55 | + 'attention_mask': torch.stack(batch['attention_mask']), |
| 56 | + 'labels': torch.stack(batch['labels']), |
| 57 | + 'pixel_values': batch['pixel_values'], # Keep as list for now |
| 58 | + 'image_grid_thw': torch.stack(batch['image_grid_thw']) |
| 59 | + } |
| 60 | + |
| 61 | + return collate_fn |
75 | 62 |
|
76 | 63 |
|
77 | 64 | def main():
|
78 |
| - parser = argparse.ArgumentParser(description="Test OlmOCR dataset loading") |
| 65 | + parser = argparse.ArgumentParser(description="Train OlmOCR model") |
79 | 66 | parser.add_argument(
|
80 | 67 | "--config",
|
81 | 68 | type=str,
|
@@ -103,45 +90,134 @@ def main():
|
103 | 90 | trust_remote_code=config.model.processor_trust_remote_code
|
104 | 91 | )
|
105 | 92 |
|
106 |
| - # Process training datasets |
107 |
| - print(f"\n{'='*80}") |
108 |
| - print("TRAINING DATASETS") |
109 |
| - print(f"{'='*80}") |
| 93 | + # Load model |
| 94 | + logger.info(f"Loading model: {config.model.name}") |
| 95 | + model = Qwen2VLForConditionalGeneration.from_pretrained( |
| 96 | + config.model.name, |
| 97 | + torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto", |
| 98 | + device_map=config.model.device_map, |
| 99 | + trust_remote_code=config.model.trust_remote_code, |
| 100 | + attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None, |
| 101 | + ) |
| 102 | + |
| 103 | + # Enable gradient checkpointing if configured |
| 104 | + if config.training.gradient_checkpointing: |
| 105 | + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs) |
110 | 106 |
|
| 107 | + # Create training datasets |
| 108 | + logger.info("Creating training datasets...") |
| 109 | + train_datasets = [] |
111 | 110 | for i, dataset_cfg in enumerate(config.dataset.train):
|
112 | 111 | root_dir = dataset_cfg['root_dir']
|
113 | 112 | pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
|
114 | 113 |
|
115 |
| - logger.info(f"\nCreating training dataset {i+1} from: {root_dir}") |
| 114 | + logger.info(f"Creating training dataset {i+1} from: {root_dir}") |
116 | 115 | dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
|
117 | 116 | logger.info(f"Found {len(dataset)} samples")
|
118 | 117 |
|
119 | 118 | if len(dataset) > 0:
|
120 |
| - # Get first sample |
121 |
| - sample = dataset[0] |
122 |
| - print_sample(sample, f"Training Dataset {i+1}: {Path(root_dir).name}") |
| 119 | + train_datasets.append(dataset) |
123 | 120 |
|
124 |
| - # Process evaluation datasets |
125 |
| - print(f"\n\n{'='*80}") |
126 |
| - print("EVALUATION DATASETS") |
127 |
| - print(f"{'='*80}") |
| 121 | + # Combine all training datasets |
| 122 | + train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0] |
| 123 | + logger.info(f"Total training samples: {len(train_dataset)}") |
128 | 124 |
|
| 125 | + # Create evaluation datasets |
| 126 | + logger.info("Creating evaluation datasets...") |
| 127 | + eval_datasets = [] |
129 | 128 | for i, dataset_cfg in enumerate(config.dataset.eval):
|
130 | 129 | root_dir = dataset_cfg['root_dir']
|
131 | 130 | pipeline_steps = config.get_pipeline_steps(dataset_cfg['pipeline'], processor)
|
132 | 131 |
|
133 |
| - logger.info(f"\nCreating evaluation dataset {i+1} from: {root_dir}") |
| 132 | + logger.info(f"Creating evaluation dataset {i+1} from: {root_dir}") |
134 | 133 | dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
|
135 | 134 | logger.info(f"Found {len(dataset)} samples")
|
136 | 135 |
|
137 | 136 | if len(dataset) > 0:
|
138 |
| - # Get first sample |
139 |
| - sample = dataset[0] |
140 |
| - print_sample(sample, f"Evaluation Dataset {i+1}: {Path(root_dir).name}") |
| 137 | + eval_datasets.append(dataset) |
| 138 | + |
| 139 | + # Combine all evaluation datasets |
| 140 | + eval_dataset = ConcatDataset(eval_datasets) if len(eval_datasets) > 1 else eval_datasets[0] |
| 141 | + logger.info(f"Total evaluation samples: {len(eval_dataset)}") |
| 142 | + |
| 143 | + # Set up training arguments |
| 144 | + training_args = TrainingArguments( |
| 145 | + output_dir=config.training.output_dir, |
| 146 | + num_train_epochs=config.training.num_train_epochs, |
| 147 | + per_device_train_batch_size=config.training.per_device_train_batch_size, |
| 148 | + per_device_eval_batch_size=config.training.per_device_eval_batch_size, |
| 149 | + gradient_accumulation_steps=config.training.gradient_accumulation_steps, |
| 150 | + learning_rate=config.training.learning_rate, |
| 151 | + lr_scheduler_type=config.training.lr_scheduler_type, |
| 152 | + warmup_ratio=config.training.warmup_ratio, |
| 153 | + warmup_steps=config.training.warmup_steps, |
| 154 | + optim=config.training.optim, |
| 155 | + adam_beta1=config.training.adam_beta1, |
| 156 | + adam_beta2=config.training.adam_beta2, |
| 157 | + adam_epsilon=config.training.adam_epsilon, |
| 158 | + weight_decay=config.training.weight_decay, |
| 159 | + max_grad_norm=config.training.max_grad_norm, |
| 160 | + fp16=config.training.fp16, |
| 161 | + bf16=config.training.bf16, |
| 162 | + tf32=config.training.tf32, |
| 163 | + eval_strategy=config.training.evaluation_strategy, |
| 164 | + eval_steps=config.training.eval_steps, |
| 165 | + save_strategy=config.training.save_strategy, |
| 166 | + save_steps=config.training.save_steps, |
| 167 | + save_total_limit=config.training.save_total_limit, |
| 168 | + load_best_model_at_end=config.training.load_best_model_at_end, |
| 169 | + metric_for_best_model=config.training.metric_for_best_model, |
| 170 | + greater_is_better=config.training.greater_is_better, |
| 171 | + logging_dir=config.training.logging_dir, |
| 172 | + logging_strategy=config.training.logging_strategy, |
| 173 | + logging_steps=config.training.logging_steps, |
| 174 | + logging_first_step=config.training.logging_first_step, |
| 175 | + report_to=config.training.report_to, |
| 176 | + seed=config.training.seed, |
| 177 | + data_seed=config.training.data_seed, |
| 178 | + push_to_hub=config.training.push_to_hub, |
| 179 | + hub_model_id=config.training.hub_model_id, |
| 180 | + hub_strategy=config.training.hub_strategy, |
| 181 | + resume_from_checkpoint=config.training.resume_from_checkpoint, |
| 182 | + deepspeed=config.training.deepspeed, |
| 183 | + dataloader_drop_last=config.training.dataloader_drop_last, |
| 184 | + dataloader_num_workers=config.training.dataloader_num_workers, |
| 185 | + remove_unused_columns=config.training.remove_unused_columns, |
| 186 | + run_name=config.run_name, |
| 187 | + ) |
| 188 | + |
| 189 | + # Set up callbacks |
| 190 | + callbacks = [] |
| 191 | + if config.training.use_early_stopping: |
| 192 | + callbacks.append( |
| 193 | + EarlyStoppingCallback( |
| 194 | + early_stopping_patience=config.training.early_stopping_patience, |
| 195 | + early_stopping_threshold=config.training.early_stopping_threshold |
| 196 | + ) |
| 197 | + ) |
| 198 | + |
| 199 | + # Initialize trainer |
| 200 | + logger.info("Initializing trainer...") |
| 201 | + trainer = Trainer( |
| 202 | + model=model, |
| 203 | + args=training_args, |
| 204 | + train_dataset=train_dataset, |
| 205 | + eval_dataset=eval_dataset, |
| 206 | + data_collator=create_data_collator(), |
| 207 | + callbacks=callbacks, |
| 208 | + ) |
| 209 | + |
| 210 | + # Start training |
| 211 | + logger.info("Starting training...") |
| 212 | + train_result = trainer.train(resume_from_checkpoint=config.training.resume_from_checkpoint) |
| 213 | + |
| 214 | + # Save the final model |
| 215 | + logger.info("Saving final model...") |
| 216 | + trainer.save_model() |
| 217 | + trainer.save_state() |
141 | 218 |
|
142 |
| - print(f"\n{'='*80}") |
143 |
| - print("Dataset loading test completed!") |
144 |
| - print(f"{'='*80}") |
| 219 | + # Log metrics |
| 220 | + logger.info(f"Training completed! Metrics: {train_result.metrics}") |
145 | 221 |
|
146 | 222 |
|
147 | 223 | if __name__ == "__main__":
|
|
0 commit comments