
Commit a5a0cd7

Trying a few more configs
1 parent 384a1b1 commit a5a0cd7

5 files changed: 342 additions & 0 deletions

olmocr/train/config.py

Lines changed: 11 additions & 0 deletions
@@ -83,6 +83,13 @@ class InstructUserMessagesConfig(PipelineStepConfig):
     name: str = "InstructUserMessages"
 
 
+@dataclass
+class LatexBracketNormalizerConfig(PipelineStepConfig):
+    """Configuration for LatexBracketNormalizer step."""
+
+    name: str = "LatexBracketNormalizer"
+
+
 @dataclass
 class TokenizerStepConfig(PipelineStepConfig):
     """Configuration for Tokenizer step."""

@@ -307,6 +314,7 @@ def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None):
             FrontMatterOutputFormat,
             FrontMatterParser,
             InstructUserMessages,
+            LatexBracketNormalizer,
             NewYamlFinetuningPromptWithAnchoring,
             NewYamlFinetuningPromptWithNoAnchoring,
             JSONOutputFormat,

@@ -356,6 +364,9 @@ def get_pipeline_steps(self, pipeline_config: List[Dict[str, Any]], processor=None):
             elif step_name == "InstructUserMessages":
                 steps.append(InstructUserMessages())
 
+            elif step_name == "LatexBracketNormalizer":
+                steps.append(LatexBracketNormalizer())
+
             elif step_name == "Tokenizer":
                 if processor is None:
                     raise ValueError("Processor must be provided for Tokenizer step")
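
For orientation, a rough sketch of how the newly registered step name flows from a pipeline config into get_pipeline_steps; the cfg object and the surrounding call site are assumptions for illustration, not part of the diff.

# Hypothetical usage sketch; cfg stands in for whatever loaded config object exposes get_pipeline_steps
pipeline_config = [
    {"name": "FrontMatterParser", "front_matter_class": "PageResponse"},
    {"name": "PDFRenderer", "target_longest_image_dim": 1280},
    {"name": "LatexBracketNormalizer"},  # dispatched by the new elif branch above
    {"name": "InstructUserMessages"},
]
steps = cfg.get_pipeline_steps(pipeline_config)  # processor is only required when a "Tokenizer" step is listed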
Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@

# Example OlmOCR Training Configuration

# Project metadata
project_name: olmocr-qwen-vl-training
run_name: qwen2.5-vl-7b-finetune-day3-yaml-1280-noanchor-latexnormalize

# Model configuration
model:
  name: Qwen/Qwen2.5-VL-7B-Instruct
  trust_remote_code: true
  torch_dtype: bfloat16
  use_flash_attention: true
  attn_implementation: flash_attention_2

  # LoRA settings (disabled by default)
  use_lora: false
  # lora_rank: 8
  # lora_alpha: 32
  # lora_dropout: 0.1
  # lora_target_modules:
  #   - q_proj
  #   - v_proj
  #   - k_proj
  #   - o_proj

# Dataset configuration
dataset:

  train:
    - name: processed_01_books_train_iabooks
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_train_iabooks/
      pipeline: &basic_pipeline
        - name: FrontMatterParser
          front_matter_class: PageResponse
        - name: PDFRenderer
          target_longest_image_dim: 1280
        - name: LatexBracketNormalizer
        - name: StaticLengthDocumentAnchoring
          target_anchor_text_len: -1
        - name: FinetuningPrompt
        - name: FrontMatterOutputFormat
        - name: InstructUserMessages
        - name: Tokenizer
          masking_index: -100
          end_of_message_token: "<|im_end|>"
    - name: processed_00_documents_train_s2pdf
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
      pipeline: *basic_pipeline

  eval:
    - name: processed_00_documents_eval_s2pdf
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
      pipeline: *basic_pipeline
    - name: processed_01_books_eval_iabooks
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_eval_iabooks/
      pipeline: *basic_pipeline


# Training configuration
training:
  output_dir: /weka/oe-data-default/jakep/olmocr-trainer/
  num_train_epochs: 1

  # Batch size and accumulation
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 32

  gradient_checkpointing: False

  collator_max_token_len: 8192

  # Learning rate
  learning_rate: 2e-5
  lr_scheduler_type: linear
  warmup_ratio: 0.1

  # Optimization
  optim: adamw_torch
  weight_decay: 0.01
  max_grad_norm: 1.0

  # Evaluation and checkpointing
  evaluation_strategy: steps
  eval_steps: 500
  save_strategy: steps
  save_steps: 500
  save_total_limit: 5
  load_best_model_at_end: false # Needs to be false because it has a problem restoring checkpoints for some reason
  metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss
  greater_is_better: false

  report_to:
    - wandb
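
A quick aside on the batch settings above, as a minimal sketch (the arithmetic is mine, not part of the commit): with per_device_train_batch_size 1 and 32 accumulation steps, each optimizer update sees 32 samples per device.

# Effective per-device batch size implied by the config above (illustrative only)
per_device_train_batch_size = 1
gradient_accumulation_steps = 32
effective_batch_per_device = per_device_train_batch_size * gradient_accumulation_steps  # 32 samples per optimizer step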
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@

# Example OlmOCR Training Configuration

# Project metadata
project_name: olmocr-qwen-vl-training
run_name: qwen2.5-vl-7b-finetune-day3-yaml-1280-noanchor-newprompt

# Model configuration
model:
  name: Qwen/Qwen2.5-VL-7B-Instruct
  trust_remote_code: true
  torch_dtype: bfloat16
  use_flash_attention: true
  attn_implementation: flash_attention_2

  # LoRA settings (disabled by default)
  use_lora: false
  # lora_rank: 8
  # lora_alpha: 32
  # lora_dropout: 0.1
  # lora_target_modules:
  #   - q_proj
  #   - v_proj
  #   - k_proj
  #   - o_proj

# Dataset configuration
dataset:

  train:
    - name: processed_01_books_train_iabooks
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_train_iabooks/
      pipeline: &basic_pipeline
        - name: FrontMatterParser
          front_matter_class: PageResponse
        - name: PDFRenderer
          target_longest_image_dim: 1280
        - name: NewYamlFinetuningPromptWithNoAnchoring
        - name: FrontMatterOutputFormat
        - name: InstructUserMessages
        - name: Tokenizer
          masking_index: -100
          end_of_message_token: "<|im_end|>"
    - name: processed_00_documents_train_s2pdf
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
      pipeline: *basic_pipeline

  eval:
    - name: processed_00_documents_eval_s2pdf
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
      pipeline: *basic_pipeline
    - name: processed_01_books_eval_iabooks
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_eval_iabooks/
      pipeline: *basic_pipeline


# Training configuration
training:
  output_dir: /weka/oe-data-default/jakep/olmocr-trainer/
  num_train_epochs: 1

  # Batch size and accumulation
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 32

  gradient_checkpointing: False

  collator_max_token_len: 8192

  # Learning rate
  learning_rate: 2e-5
  lr_scheduler_type: linear
  warmup_ratio: 0.1

  # Optimization
  optim: adamw_torch
  weight_decay: 0.01
  max_grad_norm: 1.0

  # Evaluation and checkpointing
  evaluation_strategy: steps
  eval_steps: 500
  save_strategy: steps
  save_steps: 500
  save_total_limit: 5
  load_best_model_at_end: false # Needs to be false because it has a problem restoring checkpoints for some reason
  metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss
  greater_is_better: false

  report_to:
    - wandb
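
On the &basic_pipeline / *basic_pipeline notation used throughout these configs: a tiny self-contained check (not from the commit; the dataset names are made up) showing that a YAML alias simply reuses the anchored pipeline list when the file is parsed.

# Minimal illustration of YAML anchor/alias reuse, as used by the dataset configs above
import yaml

doc = """
train:
  - name: set_a
    pipeline: &basic_pipeline
      - name: FrontMatterParser
      - name: InstructUserMessages
  - name: set_b
    pipeline: *basic_pipeline
"""
data = yaml.safe_load(doc)
assert data["train"][0]["pipeline"] == data["train"][1]["pipeline"]  # the alias expands to the same list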
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@

# Example OlmOCR Training Configuration

# Project metadata
project_name: olmocr-qwen-vl-training
run_name: qwen2-vl-7b-finetune-day3-yaml

# Model configuration
model:
  name: Qwen/Qwen2-VL-7B-Instruct
  trust_remote_code: true
  torch_dtype: bfloat16
  use_flash_attention: true
  attn_implementation: flash_attention_2

  # LoRA settings (disabled by default)
  use_lora: false
  # lora_rank: 8
  # lora_alpha: 32
  # lora_dropout: 0.1
  # lora_target_modules:
  #   - q_proj
  #   - v_proj
  #   - k_proj
  #   - o_proj

# Dataset configuration
dataset:

  train:
    - name: processed_01_books_train_iabooks
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_train_iabooks/
      pipeline: &basic_pipeline
        - name: FrontMatterParser
          front_matter_class: PageResponse
        - name: PDFRenderer
          target_longest_image_dim: 1280
        - name: StaticLengthDocumentAnchoring
          target_anchor_text_len: -1
        - name: FinetuningPrompt
        - name: FrontMatterOutputFormat
        - name: InstructUserMessages
        - name: Tokenizer
          masking_index: -100
          end_of_message_token: "<|im_end|>"
    - name: processed_00_documents_train_s2pdf
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_train_s2pdf/
      pipeline: *basic_pipeline

  eval:
    - name: processed_00_documents_eval_s2pdf
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_00_documents_eval_s2pdf/
      pipeline: *basic_pipeline
    - name: processed_01_books_eval_iabooks
      root_dir: /weka/oe-data-default/jakep/olmOCR-mix-0225/processed_01_books_eval_iabooks/
      pipeline: *basic_pipeline


# Training configuration
training:
  output_dir: /weka/oe-data-default/jakep/olmocr-trainer/
  num_train_epochs: 1

  # Batch size and accumulation
  per_device_train_batch_size: 1
  per_device_eval_batch_size: 1
  gradient_accumulation_steps: 32

  gradient_checkpointing: False

  collator_max_token_len: 8192

  # Learning rate
  learning_rate: 2e-5
  lr_scheduler_type: linear
  warmup_ratio: 0.1

  # Optimization
  optim: adamw_torch
  weight_decay: 0.01
  max_grad_norm: 1.0

  # Evaluation and checkpointing
  evaluation_strategy: steps
  eval_steps: 500
  save_strategy: steps
  save_steps: 500
  save_total_limit: 5
  load_best_model_at_end: false # Needs to be false because it has a problem restoring checkpoints for some reason
  metric_for_best_model: eval_processed_00_documents_eval_s2pdf_loss
  greater_is_better: false

  report_to:
    - wandb

olmocr/train/dataloader.py

Lines changed: 44 additions & 0 deletions
@@ -1,5 +1,6 @@
 import base64
 import logging
+import re
 from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from dataclasses import dataclass, fields

@@ -358,6 +359,49 @@ def __call__(self, sample: Sample) -> Sample:
 
         return sample
 
+@dataclass(frozen=True, slots=True)
+class LatexBracketNormalizer(PipelineStep):
+    """Normalizes LaTeX brackets in natural text field."""
+
+    def __call__(self, sample: Sample) -> Sample:
+        """Normalize LaTeX brackets in the natural text field."""
+        # Get the page_data object
+        if "page_data" not in sample:
+            return sample
+
+        page_data = sample["page_data"]
+        if not hasattr(page_data, "natural_text") or not page_data.natural_text:
+            return sample
+
+        text = page_data.natural_text
+
+        # Define patterns for LaTeX normalization
+        # Order matters: process display math first, then inline
+        patterns = [
+            (r"\$\$(.+?)\$\$", r"\[\1\]"),  # $$...$$ to \[...\]
+            (r"\$(.+?)\$", r"\(\1\)"),  # $...$ to \(...\)
+        ]
+
+        # Apply replacements
+        for pattern, replacement in patterns:
+            text = re.sub(pattern, replacement, text, flags=re.DOTALL)
+
+        # Update the page_data with normalized text
+        # Since PageResponse is frozen, we need to create a new instance
+        from olmocr.prompts.prompts import PageResponse
+        new_page_data = PageResponse(
+            primary_language=page_data.primary_language,
+            is_rotation_valid=page_data.is_rotation_valid,
+            rotation_correction=page_data.rotation_correction,
+            is_table=page_data.is_table,
+            is_diagram=page_data.is_diagram,
+            natural_text=text
+        )
+
+        sample["page_data"] = new_page_data
+        return sample
+
+
 @dataclass(frozen=True, slots=True)
 class InstructUserMessages(PipelineStep):
     """Creates instruction-following messages format for training."""
