Commit c11120a

Trying to do batch size > 1
1 parent 5c2d69a commit c11120a

File tree

3 files changed: +45 -16 lines changed

olmocr/train/config.py

Lines changed: 0 additions & 3 deletions
@@ -171,9 +171,6 @@ class TrainingConfig:
     # Resume from checkpoint
     resume_from_checkpoint: Optional[str] = None
 
-    # DeepSpeed
-    deepspeed: Optional[str] = None
-
     # Performance
     dataloader_drop_last: bool = True
     dataloader_num_workers: int = 4

olmocr/train/configs/example_config.yaml

Lines changed: 2 additions & 2 deletions
@@ -62,8 +62,8 @@ training:
   num_train_epochs: 1
 
   # Batch size and accumulation
-  per_device_train_batch_size: 1
-  per_device_eval_batch_size: 1
+  per_device_train_batch_size: 2
+  per_device_eval_batch_size: 8
   gradient_accumulation_steps: 8
 
   gradient_checkpointing: False
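For context: with the Hugging Face Trainer, the effective optimizer batch size is per_device_train_batch_size × gradient_accumulation_steps × number of devices, so this change doubles it from 8 to 16 sequences per device. A minimal sketch of the arithmetic, assuming a single-GPU run:

```python
# Effective optimizer batch size implied by the config above
# (a sketch; world_size = 1 assumes a single-GPU run).
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
world_size = 1

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
print(effective_batch_size)  # 16 (previously 1 * 8 * 1 = 8)
```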

olmocr/train/train.py

Lines changed: 43 additions & 11 deletions
@@ -31,7 +31,10 @@
 
 
 class QwenDataCollator:
-    """Data collator for vision-language models that handles numpy arrays."""
+    """Data collator for vision-language models that handles numpy arrays and variable-length sequences."""
+
+    def __init__(self, pad_token_id=0):
+        self.pad_token_id = pad_token_id
 
     def __call__(self, examples):
         # Filter out None values and extract the fields we need
@@ -58,15 +61,43 @@ def __call__(self, examples):
             image_grid_thw = torch.from_numpy(image_grid_thw)
             batch["image_grid_thw"].append(image_grid_thw)
 
-        # Convert lists to tensors with proper padding
-        # Note: For Qwen2-VL, we typically handle variable length sequences
-        # The model's processor should handle the padding internally
+        # Find the maximum sequence length in the batch
+        max_length = max(ids.shape[0] for ids in batch["input_ids"])
+
+        # Pad sequences to the maximum length
+        padded_input_ids = []
+        padded_attention_mask = []
+        padded_labels = []
+
+        for i in range(len(batch["input_ids"])):
+            input_ids = batch["input_ids"][i]
+            attention_mask = batch["attention_mask"][i]
+            labels = batch["labels"][i]
+
+            # Calculate padding needed
+            padding_length = max_length - input_ids.shape[0]
+
+            if padding_length > 0:
+                # Pad input_ids with pad_token_id
+                input_ids = torch.cat([input_ids, torch.full((padding_length,), self.pad_token_id, dtype=input_ids.dtype)])
+
+                # Pad attention_mask with zeros (indicating padded positions)
+                attention_mask = torch.cat([attention_mask, torch.zeros(padding_length, dtype=attention_mask.dtype)])
+
+                # Pad labels with -100 (ignored in loss computation)
+                labels = torch.cat([labels, torch.full((padding_length,), -100, dtype=labels.dtype)])
+
+            padded_input_ids.append(input_ids)
+            padded_attention_mask.append(attention_mask)
+            padded_labels.append(labels)
+
+        # Stack all sequences now that they have the same length
         return {
-            "input_ids": torch.stack(batch["input_ids"]),
-            "attention_mask": torch.stack(batch["attention_mask"]),
-            "labels": torch.stack(batch["labels"]),
-            "pixel_values": torch.stack(batch["pixel_values"]),  # Stack into tensor
-            "image_grid_thw": torch.stack(batch["image_grid_thw"]),
+            "input_ids": torch.stack(padded_input_ids),
+            "attention_mask": torch.stack(padded_attention_mask),
+            "labels": torch.stack(padded_labels),
+            "pixel_values": torch.stack(batch["pixel_values"]),  # Assuming these are already same size
+            "image_grid_thw": torch.stack(batch["image_grid_thw"]),  # Assuming these are already same size
         }
 
 
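The loop above right-pads every sequence to the batch maximum: input_ids with the pad token, attention_mask with zeros, and labels with -100 so padded positions are ignored by the loss. The same behavior can be sketched more compactly with torch.nn.utils.rnn.pad_sequence; pad_batch below is a hypothetical helper for illustration, not the code the commit uses:

```python
# A compact sketch of the same padding via torch.nn.utils.rnn.pad_sequence,
# which right-pads a list of 1-D tensors in a single call (illustrative only;
# the commit keeps the explicit loop shown above).
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_batch(batch, pad_token_id=0):
    return {
        "input_ids": pad_sequence(batch["input_ids"], batch_first=True, padding_value=pad_token_id),
        "attention_mask": pad_sequence(batch["attention_mask"], batch_first=True, padding_value=0),
        # -100 is the index PyTorch's cross-entropy loss ignores
        "labels": pad_sequence(batch["labels"], batch_first=True, padding_value=-100),
        "pixel_values": torch.stack(batch["pixel_values"]),      # assumed same size per example
        "image_grid_thw": torch.stack(batch["image_grid_thw"]),  # assumed same size per example
    }
```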

@@ -200,12 +231,13 @@ def main():
         data_seed=config.training.data_seed,
         push_to_hub=False,
         resume_from_checkpoint=config.training.resume_from_checkpoint,
-        deepspeed=config.training.deepspeed,
         dataloader_drop_last=config.training.dataloader_drop_last,
         dataloader_num_workers=config.training.dataloader_num_workers,
         remove_unused_columns=config.training.remove_unused_columns,
         eval_on_start=True,
         run_name=config.run_name,
+        torch_compile=True,
+        torch_compile_backend="inductor"
     )
 
     # Set up callbacks
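The two new arguments ask the Trainer to wrap the model with torch.compile using the TorchInductor backend before training. Roughly, and stripped of the Trainer machinery, that amounts to the sketch below (requires PyTorch 2.x; the nn.Linear stand-in is illustrative):

```python
# Roughly what torch_compile=True / torch_compile_backend="inductor" enables:
# the model is wrapped with torch.compile before training (PyTorch 2.x).
import torch
import torch.nn as nn

model = nn.Linear(8, 8)  # illustrative stand-in for the actual VLM
compiled_model = torch.compile(model, backend="inductor")
out = compiled_model(torch.randn(2, 8))  # first call triggers compilation
```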
@@ -224,7 +256,7 @@ def main():
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_datasets,
-        data_collator=QwenDataCollator(),
+        data_collator=QwenDataCollator(pad_token_id=processor.tokenizer.pad_token_id or 0),
         callbacks=callbacks,
     )
 