Skip to content

Commit 98df1d5

Browse files
committed
Adding max length option
1 parent abdc907 commit 98df1d5

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

olmocr/train/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ class TrainingConfig:
185185
# Performance
186186
dataloader_drop_last: bool = True
187187
dataloader_num_workers: int = 4
188+
189+
# Data collator settings
190+
collator_max_token_len: Optional[int] = None
188191
remove_unused_columns: bool = False # Important for custom datasets
189192

190193
# Early stopping

olmocr/train/train.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
TrainingArguments,
1919
)
2020

21+
from typing import Optional
2122
from olmocr.train.config import Config
2223
from olmocr.train.dataloader import BaseMarkdownPDFDataset
2324

@@ -33,18 +34,29 @@
3334
class QwenDataCollator:
3435
"""Data collator for vision-language models that handles numpy arrays."""
3536

37+
def __init__(self, max_token_len: Optional[int] = None):
38+
self.max_token_len = max_token_len
39+
3640
def __call__(self, examples):
3741
# Filter out None values and extract the fields we need
3842
batch = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": [], "image_grid_thw": []}
3943

4044
for example in examples:
4145
if example is not None:
4246
# Convert numpy arrays to tensors
43-
batch["input_ids"].append(torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"])
44-
batch["attention_mask"].append(
45-
torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"]
46-
)
47-
batch["labels"].append(torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"])
47+
input_ids = torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"]
48+
attention_mask = torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"]
49+
labels = torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"]
50+
51+
# Trim to max_token_len if specified
52+
if self.max_token_len is not None:
53+
input_ids = input_ids[:self.max_token_len]
54+
attention_mask = attention_mask[:self.max_token_len]
55+
labels = labels[:self.max_token_len]
56+
57+
batch["input_ids"].append(input_ids)
58+
batch["attention_mask"].append(attention_mask)
59+
batch["labels"].append(labels)
4860

4961
# Handle pixel_values which might be numpy array or already a tensor
5062
pixel_values = example["pixel_values"]
@@ -236,7 +248,7 @@ def main():
236248
args=training_args,
237249
train_dataset=train_dataset,
238250
eval_dataset=eval_datasets,
239-
data_collator=QwenDataCollator(),
251+
data_collator=QwenDataCollator(max_token_len=config.training.collator_max_token_len),
240252
callbacks=callbacks,
241253
)
242254

0 commit comments

Comments
 (0)