class QwenDataCollator:
-    """Data collator for vision-language models that handles numpy arrays."""
+    """Data collator for vision-language models that handles numpy arrays and variable-length sequences."""
+
+    def __init__(self, pad_token_id=0):
+        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        # Filter out None values and extract the fields we need
@@ -58,15 +61,43 @@ def __call__(self, examples):
            image_grid_thw = torch.from_numpy(image_grid_thw)
            batch["image_grid_thw"].append(image_grid_thw)

-        # Convert lists to tensors with proper padding
-        # Note: For Qwen2-VL, we typically handle variable length sequences
-        # The model's processor should handle the padding internally
+        # Find the maximum sequence length in the batch
+        max_length = max(ids.shape[0] for ids in batch["input_ids"])
+
+        # Pad sequences to the maximum length
+        padded_input_ids = []
+        padded_attention_mask = []
+        padded_labels = []
+
+        for i in range(len(batch["input_ids"])):
+            input_ids = batch["input_ids"][i]
+            attention_mask = batch["attention_mask"][i]
+            labels = batch["labels"][i]
+
+            # Calculate padding needed
+            padding_length = max_length - input_ids.shape[0]
+
+            if padding_length > 0:
+                # Pad input_ids with pad_token_id
+                input_ids = torch.cat([input_ids, torch.full((padding_length,), self.pad_token_id, dtype=input_ids.dtype)])
+
+                # Pad attention_mask with zeros (indicating padded positions)
+                attention_mask = torch.cat([attention_mask, torch.zeros(padding_length, dtype=attention_mask.dtype)])
+
+                # Pad labels with -100 (ignored in loss computation)
+                labels = torch.cat([labels, torch.full((padding_length,), -100, dtype=labels.dtype)])
+
+            padded_input_ids.append(input_ids)
+            padded_attention_mask.append(attention_mask)
+            padded_labels.append(labels)
+
+        # Stack all sequences now that they have the same length
        return {
-            "input_ids": torch.stack(batch["input_ids"]),
-            "attention_mask": torch.stack(batch["attention_mask"]),
-            "labels": torch.stack(batch["labels"]),
-            "pixel_values": torch.stack(batch["pixel_values"]),  # Stack into tensor
-            "image_grid_thw": torch.stack(batch["image_grid_thw"]),
+            "input_ids": torch.stack(padded_input_ids),
+            "attention_mask": torch.stack(padded_attention_mask),
+            "labels": torch.stack(padded_labels),
+            "pixel_values": torch.stack(batch["pixel_values"]),  # Assuming these are already same size
+            "image_grid_thw": torch.stack(batch["image_grid_thw"]),  # Assuming these are already same size
        }
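
For reference, here is a minimal, self-contained sketch of the padding scheme the new collator implements (the toy tensors and the default pad id of 0 are assumptions for illustration; in the collator above, the attention mask comes from each example rather than being rebuilt):

import torch

# Two variable-length sequences, as the collator would receive them.
input_ids = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
labels = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

max_length = max(ids.shape[0] for ids in input_ids)  # 3

padded_ids, padded_labels, masks = [], [], []
for ids, lab in zip(input_ids, labels):
    pad = max_length - ids.shape[0]
    # Right-pad ids with the pad token, labels with -100 (ignored by the loss),
    # and mark padded positions with 0 in the attention mask.
    padded_ids.append(torch.cat([ids, torch.full((pad,), 0, dtype=ids.dtype)]))
    padded_labels.append(torch.cat([lab, torch.full((pad,), -100, dtype=lab.dtype)]))
    masks.append(torch.cat([torch.ones(ids.shape[0], dtype=torch.long), torch.zeros(pad, dtype=torch.long)]))

print(torch.stack(padded_ids))     # tensor([[5, 6, 7], [8, 9, 0]])
print(torch.stack(padded_labels))  # tensor([[5, 6, 7], [8, 9, -100]])
print(torch.stack(masks))          # tensor([[1, 1, 1], [1, 1, 0]])
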
@@ -200,12 +231,13 @@ def main():
        data_seed=config.training.data_seed,
        push_to_hub=False,
        resume_from_checkpoint=config.training.resume_from_checkpoint,
-        deepspeed=config.training.deepspeed,
        dataloader_drop_last=config.training.dataloader_drop_last,
        dataloader_num_workers=config.training.dataloader_num_workers,
        remove_unused_columns=config.training.remove_unused_columns,
        eval_on_start=True,
        run_name=config.run_name,
+        torch_compile=True,
+        torch_compile_backend="inductor"
    )

    # Set up callbacks
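
The two new flags route the model through torch.compile inside the Trainer. The effect is roughly the one-liner below (a sketch of what the flags enable, not the Trainer internals verbatim):

# Approximately what torch_compile=True / torch_compile_backend="inductor" sets up:
model = torch.compile(model, backend="inductor")

One caveat: with the variable sequence lengths produced by the collator above, batches of new shapes may trigger recompilations.
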
@@ -224,7 +256,7 @@ def main():
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_datasets,
-        data_collator=QwenDataCollator(),
+        data_collator=QwenDataCollator(pad_token_id=processor.tokenizer.pad_token_id or 0),
        callbacks=callbacks,
    )
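
Note the pad_token_id wiring: some tokenizers leave pad_token_id unset, so the "or 0" falls back to token id 0 in that case. The same guard written out explicitly (processor is assumed to be the loaded Qwen processor):

pad_id = processor.tokenizer.pad_token_id  # may be None for some tokenizers
collator = QwenDataCollator(pad_token_id=pad_id if pad_id is not None else 0)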