 from tqdm import tqdm
 from dataclasses import dataclass, fields
 from abc import ABC, abstractmethod

 from olmocr.data.renderpdf import render_pdf_to_base64png
 from olmocr.prompts.prompts import PageResponse, build_finetuning_prompt
 from olmocr.prompts.anchor import get_anchor_text

+# numpy is an optional dependency, needed only by the Tokenizer step below
+try:
+    import numpy as np
+except ImportError:
+    np = None
+
 # Type alias for samples
 Sample: TypeAlias = Dict[str, Any]
@@ -299,6 +305,67 @@ def __call__(self, sample: Sample) -> Sample:
         return sample


+@dataclass(frozen=True, slots=True)
+class Tokenizer(PipelineStep):
+    """Tokenizes messages and creates training labels with proper masking."""
+
+    processor: Any  # The model processor (e.g., AutoProcessor)
+    masking_index: int = -100
+
+    def __call__(self, sample: Sample) -> Sample:
+        """Tokenize messages and create labels for training."""
+        if np is None:
+            raise ImportError("numpy is required for the Tokenizer step")
+
+        messages = sample["messages"]
+        main_image = sample["image"]
+
+        # Apply the chat template to the full conversation
+        text = self.processor.apply_chat_template(
+            messages,
+            tokenize=False,
+            add_generation_prompt=False,  # the assistant response is already in messages
+        )
+
+        # Process text and image together
+        inputs = self.processor(
+            text=[text],
+            images=[main_image],
+            padding=True,
+            return_tensors="np",
+        )
+
+        # Create labels by copying input_ids, then mask out the prompt portion
+        labels = inputs.input_ids.copy()
+
+        # Find where the assistant response starts. This assumes the chat
+        # template inserts an "assistant" role marker between the user and
+        # assistant turns; adjust for your specific template if needed.
+        assistant_token = self.processor.tokenizer.encode("assistant", add_special_tokens=False)[0]
+        assistant_start_idx = np.where(inputs.input_ids[0] == assistant_token)[0]
+
+        if len(assistant_start_idx) > 0:
+            # Mask everything before the assistant's actual response content;
+            # there are usually a few template tokens after the role marker
+            response_start = assistant_start_idx[-1] + 2  # adjust offset as needed
+            labels[0, :response_start] = self.masking_index
+        else:
+            raise ValueError("Could not find assistant tokens in the input")
+
+        # Add tokenized data to the sample
+        sample["input_ids"] = inputs.input_ids[0]
+        sample["attention_mask"] = inputs.attention_mask[0]
+        sample["labels"] = labels[0]
+
+        # Add image-related tensors if present
+        if hasattr(inputs, "pixel_values"):
+            sample["pixel_values"] = inputs.pixel_values
+        if hasattr(inputs, "image_grid_thw"):
+            sample["image_grid_thw"] = inputs.image_grid_thw[0]
+
+        return sample
+
+
 class MarkdownPDFDocumentDataset(BaseMarkdownPDFDataset):
     """Dataset that includes front matter parsing and PDF rendering by default."""
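A note on the masking heuristic in `Tokenizer.__call__` above: searching for the literal "assistant" token and skipping a fixed two-token offset is tied to one particular chat template. A sturdier alternative is to tokenize the prompt-only conversation and mask exactly that many positions. The sketch below is illustrative and not part of this diff; it assumes `messages[:-1]` holds every turn except the assistant response, that the template's generation prompt is a prefix of the full rendered conversation (true for Qwen-style ChatML templates), and it reuses `main_image` and `labels` from the method body. `prompt_text`, `prompt_inputs`, and `prompt_len` are hypothetical names.

# Illustrative sketch -- would replace the assistant-token search above
prompt_text = self.processor.apply_chat_template(
    messages[:-1],               # every turn except the assistant response
    tokenize=False,
    add_generation_prompt=True,  # rendered prompt ends where the response starts
)
# Run the processor with the image so image-placeholder expansion matches
# the tokenization of the full conversation
prompt_inputs = self.processor(
    text=[prompt_text],
    images=[main_image],
    return_tensors="np",
)
prompt_len = prompt_inputs.input_ids.shape[1]
labels[0, :prompt_len] = self.masking_index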
@@ -326,6 +393,7 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
         super().__init__(root_dir, pipeline_steps)


+
 if __name__ == "__main__":
     import argparse
     from pathlib import Path
@@ -399,5 +467,94 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
         print(f"PDF: {Path(first_sample['pdf_path']).name}")
         print(f"Image size: {first_sample['image'].size}")
         print(f"Page data: {first_sample['page_data']}")
+
+        # Test with actual Qwen2.5-VL tokenization
+        print("\n\n=== Testing with Qwen2.5-VL-7B-Instruct Tokenization ===")
+
+        try:
+            from transformers import AutoProcessor
+
+            print("Loading Qwen2.5-VL processor...")
+            processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+
+            # Create a pipeline that ends with the real tokenizer
+            tokenized_dataset = BaseMarkdownPDFDataset(
+                args.root_dir,
+                pipeline_steps=[
+                    FrontMatterParser(front_matter_class=PageResponse),
+                    PDFRenderer(target_longest_image_dim=512),
+                    StaticLengthDocumentAnchoring(target_anchor_text_len=1000),
+                    FinetuningPrompt(),
+                    FrontMatterOutputFormat(),
+                    InstructMessages(),
+                    Tokenizer(processor),
+                ],
+            )
+
+            if len(tokenized_dataset) > 0:
+                print("\nProcessing first sample with Qwen2.5-VL...")
+                tokenized_sample = tokenized_dataset[0]
+
+                print("\nTokenized output:")
+                print(f"  Keys: {list(tokenized_sample.keys())}")
+                print(f"  Input IDs shape: {tokenized_sample['input_ids'].shape}")
+                print(f"  Labels shape: {tokenized_sample['labels'].shape}")
+                print(f"  Attention mask shape: {tokenized_sample['attention_mask'].shape}")
+
+                if "pixel_values" in tokenized_sample:
+                    print(f"  Pixel values shape: {tokenized_sample['pixel_values'].shape}")
+                if "image_grid_thw" in tokenized_sample:
+                    print(f"  Image grid THW: {tokenized_sample['image_grid_thw']}")
+
+                # Show how the labels are masked
+                print("\nLabel masking analysis:")
+                labels = tokenized_sample["labels"]
+                masked_count = np.sum(labels == -100)
+                total_count = len(labels)
+                print(f"  Total tokens: {total_count}")
+                print(f"  Masked tokens: {masked_count} ({masked_count / total_count * 100:.1f}%)")
+                print(f"  Unmasked tokens: {total_count - masked_count} ({(total_count - masked_count) / total_count * 100:.1f}%)")
+
+                # Find the first transition from masked to unmasked
+                transition_idx = None
+                for i in range(len(labels) - 1):
+                    if labels[i] == -100 and labels[i + 1] != -100:
+                        transition_idx = i + 1
+                        break
+
+                if transition_idx is not None:
+                    print(f"  Transition from masked to unmasked at position: {transition_idx}")
+
+                # Print every token with its label
+                input_ids = tokenized_sample["input_ids"]
+                print(f"\nAll tokens ({len(input_ids)} total):")
+                print("Format: [index] Token (repr) | Label | Token ID")
+                print("-" * 80)
+
+                for i in range(len(input_ids)):
+                    token = processor.tokenizer.decode([input_ids[i]])
+                    token_repr = repr(token)
+                    label = labels[i] if i < len(labels) else "N/A"
+                    token_id = input_ids[i]
+
+                    # Mark special positions
+                    marker = ""
+                    if transition_idx is not None and i == transition_idx:
+                        marker = " <-- TRANSITION (first unmasked)"
+                    elif i == 0:
+                        marker = " <-- START"
+                    elif label != -100 and i > 0 and labels[i - 1] == -100:
+                        marker = " <-- response begins"
+
+                    print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")
+
+        except ImportError as e:
+            print(f"\nCould not import transformers: {e}")
+            print("Install with: pip install transformers")
+        except Exception as e:
+            print(f"\nError during tokenization test: {e}")
+            import traceback
+            traceback.print_exc()
+
     else:
         raise AssertionError("Expected some data to be created at this point")
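One thing the smoke test above does not exercise is batching: the Tokenizer step emits variable-length arrays, so a training DataLoader needs a collator that pads `input_ids` with the tokenizer's pad token, `attention_mask` with 0, and `labels` with -100 so padded positions stay out of the loss. A minimal sketch follows, assuming PyTorch plus the `processor` and `tokenized_dataset` built above; `collate_samples` is an illustrative name, not part of this diff, and image tensors such as `pixel_values` would need analogous handling.

import torch

def collate_samples(batch, pad_token_id):
    """Pad a list of tokenized samples to the batch max length (sketch only)."""
    max_len = max(len(s["input_ids"]) for s in batch)

    def pad(seq, fill):
        out = torch.full((max_len,), fill, dtype=torch.long)
        out[: len(seq)] = torch.as_tensor(seq, dtype=torch.long)
        return out

    return {
        "input_ids": torch.stack([pad(s["input_ids"], pad_token_id) for s in batch]),
        "attention_mask": torch.stack([pad(s["attention_mask"], 0) for s in batch]),
        "labels": torch.stack([pad(s["labels"], -100) for s in batch]),
    }

# Example wiring with the dataset built above:
# loader = torch.utils.data.DataLoader(
#     tokenized_dataset,
#     batch_size=4,
#     collate_fn=lambda b: collate_samples(b, processor.tokenizer.pad_token_id),
# )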