Skip to content

Commit cfe9aa1

Browse files
committed
Ok, dataloader from start to finish is running, now to write a trainer
1 parent 105d590 commit cfe9aa1

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed

olmocr/train/dataloader.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,17 @@
1212
from tqdm import tqdm
1313
from dataclasses import dataclass, fields
1414
from abc import ABC, abstractmethod
15+
import numpy as np
1516

1617
from olmocr.data.renderpdf import render_pdf_to_base64png
1718
from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
1819
from olmocr.prompts.anchor import get_anchor_text
1920

21+
try:
22+
import numpy as np
23+
except ImportError:
24+
np = None
25+
2026
# Type alias for samples
2127
Sample: TypeAlias = Dict[str, Any]
2228

@@ -299,6 +305,67 @@ def __call__(self, sample: Sample) -> Sample:
299305
return sample
300306

301307

308+
@dataclass(frozen=True, slots=True)
309+
class Tokenizer(PipelineStep):
310+
"""Tokenizes messages and creates training labels with proper masking."""
311+
processor: Any # The model processor (e.g., AutoProcessor)
312+
masking_index: int = -100
313+
314+
def __call__(self, sample: Sample) -> Sample:
315+
"""Tokenize messages and create labels for training."""
316+
if np is None:
317+
raise ImportError("numpy is required for Tokenizer step")
318+
319+
messages = sample["messages"]
320+
main_image = sample["image"]
321+
322+
# Apply chat template to full conversation
323+
text = self.processor.apply_chat_template(
324+
messages,
325+
tokenize=False,
326+
add_generation_prompt=False # Don't add prompt since we have the response
327+
)
328+
329+
# Process everything together
330+
inputs = self.processor(
331+
text=[text],
332+
images=[main_image],
333+
padding=True,
334+
return_tensors="np",
335+
)
336+
337+
# Create labels by copying input_ids and masking the prompt portion
338+
labels = inputs.input_ids.copy()
339+
340+
# Find where the assistant response starts
341+
# This assumes the processor adds some delimiter between user and assistant
342+
# You might need to adjust based on your specific chat template
343+
344+
assistant_token = self.processor.tokenizer.encode("assistant", add_special_tokens=False)[0]
345+
assistant_start_idx = np.where(inputs.input_ids[0] == assistant_token)[0]
346+
347+
if len(assistant_start_idx) > 0:
348+
# Mask everything before the assistant's actual response content
349+
# Usually there's a few tokens after "assistant" role marker
350+
response_start = assistant_start_idx[-1] + 2 # Adjust offset as needed
351+
labels[0, :response_start] = self.masking_index
352+
else:
353+
raise Exception("Could not find assistant tokens")
354+
355+
# Add tokenized data to sample
356+
sample["input_ids"] = inputs.input_ids[0]
357+
sample["attention_mask"] = inputs.attention_mask[0]
358+
sample["labels"] = labels[0]
359+
360+
# Add image-related tensors if present
361+
if hasattr(inputs, 'pixel_values'):
362+
sample["pixel_values"] = inputs.pixel_values
363+
if hasattr(inputs, 'image_grid_thw'):
364+
sample["image_grid_thw"] = inputs.image_grid_thw[0]
365+
366+
return sample
367+
368+
302369
class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
303370
"""Dataset that includes front matter parsing and PDF rendering by default."""
304371

@@ -326,6 +393,7 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
326393
super().__init__(root_dir, pipeline_steps)
327394

328395

396+
329397
if __name__ == "__main__":
330398
import argparse
331399
from pathlib import Path
@@ -399,5 +467,94 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
399467
print(f"PDF: {Path(first_sample['pdf_path']).name}")
400468
print(f"Image size: {first_sample['image'].size}")
401469
print(f"Page data: {first_sample['page_data']}")
470+
471+
# Test with actual Qwen2.5-VL tokenization
472+
print("\n\n=== Testing with Qwen2.5-VL-7B-Instruct Tokenization ===")
473+
474+
try:
475+
from transformers import AutoProcessor
476+
477+
print("Loading Qwen2.5-VL processor...")
478+
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
479+
480+
# Create pipeline with real tokenizer
481+
tokenized_dataset = BaseMarkdownPDFDataset(
482+
args.root_dir,
483+
pipeline_steps=[
484+
FrontMatterParser(front_matter_class=PageResponse),
485+
PDFRenderer(target_longest_image_dim=512),
486+
StaticLengthDocumentAnchoring(target_anchor_text_len=1000),
487+
FinetuningPrompt(),
488+
FrontMatterOutputFormat(),
489+
InstructMessages(),
490+
Tokenizer(processor),
491+
]
492+
)
493+
494+
if len(tokenized_dataset) > 0:
495+
print("\nProcessing first sample with Qwen2.5-VL...")
496+
tokenized_sample = tokenized_dataset[0]
497+
498+
print("\nTokenized output:")
499+
print(f" Keys: {list(tokenized_sample.keys())}")
500+
print(f" Input IDs shape: {tokenized_sample['input_ids'].shape}")
501+
print(f" Labels shape: {tokenized_sample['labels'].shape}")
502+
print(f" Attention mask shape: {tokenized_sample['attention_mask'].shape}")
503+
504+
if 'pixel_values' in tokenized_sample:
505+
print(f" Pixel values shape: {tokenized_sample['pixel_values'].shape}")
506+
if 'image_grid_thw' in tokenized_sample:
507+
print(f" Image grid THW: {tokenized_sample['image_grid_thw']}")
508+
509+
# Show label masking
510+
print(f"\nLabel masking analysis:")
511+
labels = tokenized_sample['labels']
512+
masked_count = np.sum(labels == -100)
513+
total_count = len(labels)
514+
print(f" Total tokens: {total_count}")
515+
print(f" Masked tokens: {masked_count} ({masked_count/total_count*100:.1f}%)")
516+
print(f" Unmasked tokens: {total_count - masked_count} ({(total_count - masked_count)/total_count*100:.1f}%)")
517+
518+
# Find the transition point
519+
transition_idx = None
520+
for i in range(len(labels) - 1):
521+
if labels[i] == -100 and labels[i + 1] != -100:
522+
transition_idx = i + 1
523+
break
524+
525+
if transition_idx:
526+
print(f" Transition from masked to unmasked at position: {transition_idx}")
527+
528+
# Print all tokens
529+
input_ids = tokenized_sample['input_ids']
530+
print(f"\nAll tokens ({len(input_ids)} total):")
531+
print("Format: [index] Token (repr) | Label | Token ID")
532+
print("-" * 80)
533+
534+
for i in range(len(input_ids)):
535+
token = processor.tokenizer.decode([input_ids[i]])
536+
token_repr = repr(token)
537+
label = labels[i] if i < len(labels) else "N/A"
538+
token_id = input_ids[i]
539+
540+
# Mark special positions
541+
marker = ""
542+
if transition_idx and i == transition_idx:
543+
marker = " <-- TRANSITION (first unmasked)"
544+
elif i == 0:
545+
marker = " <-- START"
546+
elif label != -100 and i > 0 and labels[i-1] == -100:
547+
marker = " <-- response begins"
548+
549+
print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")
550+
551+
except ImportError as e:
552+
print(f"\nCould not import transformers: {e}")
553+
print("Install with: pip install transformers")
554+
except Exception as e:
555+
print(f"\nError during tokenization test: {e}")
556+
import traceback
557+
traceback.print_exc()
558+
402559
else:
403560
raise AssertionError("Expected some data to be created at this point")

0 commit comments

Comments
 (0)