Skip to content

Commit c36b5df

Browse files
committed
Cleanup collator
1 parent 887190e commit c36b5df

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

olmocr/train/train.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import argparse
66
import logging
7+
import numpy as np
78

89
from transformers import (
910
AutoProcessor,
@@ -28,9 +29,10 @@
2829
logger = logging.getLogger(__name__)
2930

3031

31-
def create_data_collator():
32-
"""Create a data collator for vision-language models."""
33-
def collate_fn(examples):
32+
class QwenDataCollator:
33+
"""Data collator for vision-language models that handles numpy arrays."""
34+
35+
def __call__(self, examples):
3436
# Filter out None values and extract the fields we need
3537
batch = {
3638
'input_ids': [],
@@ -42,11 +44,22 @@ def collate_fn(examples):
4244

4345
for example in examples:
4446
if example is not None:
45-
batch['input_ids'].append(example['input_ids'])
46-
batch['attention_mask'].append(example['attention_mask'])
47-
batch['labels'].append(example['labels'])
48-
batch['pixel_values'].append(example['pixel_values'])
49-
batch['image_grid_thw'].append(example['image_grid_thw'])
47+
# Convert numpy arrays to tensors
48+
batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
49+
batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
50+
batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
51+
52+
# Handle pixel_values which might be numpy array or already a tensor
53+
pixel_values = example['pixel_values']
54+
if isinstance(pixel_values, np.ndarray):
55+
pixel_values = torch.from_numpy(pixel_values)
56+
batch['pixel_values'].append(pixel_values)
57+
58+
# Handle image_grid_thw
59+
image_grid_thw = example['image_grid_thw']
60+
if isinstance(image_grid_thw, np.ndarray):
61+
image_grid_thw = torch.from_numpy(image_grid_thw)
62+
batch['image_grid_thw'].append(image_grid_thw)
5063

5164
# Convert lists to tensors with proper padding
5265
# Note: For Qwen2-VL, we typically handle variable length sequences
@@ -58,8 +71,6 @@ def collate_fn(examples):
5871
'pixel_values': batch['pixel_values'], # Keep as list for now
5972
'image_grid_thw': torch.stack(batch['image_grid_thw'])
6073
}
61-
62-
return collate_fn
6374

6475

6576
def main():
@@ -215,7 +226,7 @@ def main():
215226
args=training_args,
216227
train_dataset=train_dataset,
217228
eval_dataset=eval_datasets,
218-
data_collator=create_data_collator(),
229+
data_collator=QwenDataCollator(),
219230
callbacks=callbacks,
220231
)
221232

0 commit comments

Comments (0)