|
18 | 18 | TrainingArguments,
|
19 | 19 | )
|
20 | 20 |
|
| 21 | +from typing import Optional |
21 | 22 | from olmocr.train.config import Config
|
22 | 23 | from olmocr.train.dataloader import BaseMarkdownPDFDataset
|
23 | 24 |
|
|
33 | 34 | class QwenDataCollator:
|
34 | 35 | """Data collator for vision-language models that handles numpy arrays."""
|
35 | 36 |
|
| 37 | + def __init__(self, max_token_len: Optional[int] = None): |
| 38 | + self.max_token_len = max_token_len |
| 39 | + |
36 | 40 | def __call__(self, examples):
|
37 | 41 | # Filter out None values and extract the fields we need
|
38 | 42 | batch = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": [], "image_grid_thw": []}
|
39 | 43 |
|
40 | 44 | for example in examples:
|
41 | 45 | if example is not None:
|
42 | 46 | # Convert numpy arrays to tensors
|
43 |
| - batch["input_ids"].append(torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"]) |
44 |
| - batch["attention_mask"].append( |
45 |
| - torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"] |
46 |
| - ) |
47 |
| - batch["labels"].append(torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"]) |
| 47 | + input_ids = torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"] |
| 48 | + attention_mask = torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"] |
| 49 | + labels = torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"] |
| 50 | + |
| 51 | + # Trim to max_token_len if specified |
| 52 | + if self.max_token_len is not None: |
| 53 | + input_ids = input_ids[:self.max_token_len] |
| 54 | + attention_mask = attention_mask[:self.max_token_len] |
| 55 | + labels = labels[:self.max_token_len] |
| 56 | + |
| 57 | + batch["input_ids"].append(input_ids) |
| 58 | + batch["attention_mask"].append(attention_mask) |
| 59 | + batch["labels"].append(labels) |
48 | 60 |
|
49 | 61 | # Handle pixel_values which might be numpy array or already a tensor
|
50 | 62 | pixel_values = example["pixel_values"]
|
@@ -236,7 +248,7 @@ def main():
|
236 | 248 | args=training_args,
|
237 | 249 | train_dataset=train_dataset,
|
238 | 250 | eval_dataset=eval_datasets,
|
239 |
| - data_collator=QwenDataCollator(), |
| 251 | + data_collator=QwenDataCollator(max_token_len=config.training.collator_max_token_len), |
240 | 252 | callbacks=callbacks,
|
241 | 253 | )
|
242 | 254 |
|
|
0 commit comments