 
 import argparse
 import logging
+import numpy as np
 
 from transformers import (
     AutoProcessor,
@@ -28,9 +29,10 @@
 logger = logging.getLogger(__name__)
 
 
-def create_data_collator():
-    """Create a data collator for vision-language models."""
-    def collate_fn(examples):
+class QwenDataCollator:
+    """Data collator for vision-language models that handles numpy arrays."""
+
+    def __call__(self, examples):
         # Filter out None values and extract the fields we need
         batch = {
             'input_ids': [],
@@ -42,11 +44,22 @@ def collate_fn(examples):
 
         for example in examples:
             if example is not None:
-                batch['input_ids'].append(example['input_ids'])
-                batch['attention_mask'].append(example['attention_mask'])
-                batch['labels'].append(example['labels'])
-                batch['pixel_values'].append(example['pixel_values'])
-                batch['image_grid_thw'].append(example['image_grid_thw'])
+                # Convert numpy arrays to tensors
+                batch['input_ids'].append(torch.from_numpy(example['input_ids']) if isinstance(example['input_ids'], np.ndarray) else example['input_ids'])
+                batch['attention_mask'].append(torch.from_numpy(example['attention_mask']) if isinstance(example['attention_mask'], np.ndarray) else example['attention_mask'])
+                batch['labels'].append(torch.from_numpy(example['labels']) if isinstance(example['labels'], np.ndarray) else example['labels'])
+
+                # Handle pixel_values, which might be a numpy array or already a tensor
+                pixel_values = example['pixel_values']
+                if isinstance(pixel_values, np.ndarray):
+                    pixel_values = torch.from_numpy(pixel_values)
+                batch['pixel_values'].append(pixel_values)
+
+                # Handle image_grid_thw
+                image_grid_thw = example['image_grid_thw']
+                if isinstance(image_grid_thw, np.ndarray):
+                    image_grid_thw = torch.from_numpy(image_grid_thw)
+                batch['image_grid_thw'].append(image_grid_thw)
 
         # Convert lists to tensors with proper padding
         # Note: For Qwen2-VL, we typically handle variable length sequences
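The additions in the hunk above apply one small pattern to every field: if the value is a numpy array, convert it with torch.from_numpy; if it is already a tensor, pass it through unchanged. A minimal standalone sketch of that pattern follows; the to_tensor helper is purely illustrative and is not part of this commit.

    import numpy as np
    import torch

    def to_tensor(value):
        # Same check the collator applies per field: numpy arrays are converted
        # (torch.from_numpy shares memory with the source array), anything that
        # is already a tensor (or any other type) is passed through untouched.
        return torch.from_numpy(value) if isinstance(value, np.ndarray) else value

    print(to_tensor(np.array([1, 2, 3])))      # tensor([1, 2, 3])
    print(to_tensor(torch.tensor([4, 5, 6])))  # already a tensor, returned as-is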
@@ -58,8 +71,6 @@ def collate_fn(examples):
             'pixel_values': batch['pixel_values'],  # Keep as list for now
             'image_grid_thw': torch.stack(batch['image_grid_thw'])
         }
-
-    return collate_fn
 
 
 def main():
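The "proper padding" note above refers to lines that are unchanged by this commit, so the actual padding logic is collapsed out of the diff. Purely as a hedged illustration of how variable-length sequences like these are commonly batched, the sketch below uses torch.nn.utils.rnn.pad_sequence; this is a generic approach, not necessarily what this file does.

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Toy variable-length sequences standing in for the per-example tensors
    # the collator has accumulated at this point (values are made up).
    input_ids = [torch.tensor([101, 7, 8, 102]), torch.tensor([101, 9, 102])]
    labels = [torch.tensor([-100, 7, 8, -100]), torch.tensor([-100, 9, -100])]

    batch_input_ids = pad_sequence(input_ids, batch_first=True, padding_value=0)
    batch_labels = pad_sequence(labels, batch_first=True, padding_value=-100)  # -100 is ignored by the loss

    print(batch_input_ids.shape)  # torch.Size([2, 4])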
@@ -215,7 +226,7 @@ def main():
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_datasets,
-        data_collator=create_data_collator(),
+        data_collator=QwenDataCollator(),
         callbacks=callbacks,
     )
 