@@ -292,6 +292,7 @@ class Tokenizer(PipelineStep):
    """Tokenizes messages and creates training labels with proper masking."""
    processor: Any  # The model processor (e.g., AutoProcessor)
    masking_index: int = -100
+   end_of_message_token: str = "<|im_end|>"  # Configurable, defaults to Qwen format

    def __call__(self, sample: Sample) -> Sample:
        """Tokenize messages and create labels for training."""
@@ -323,17 +324,17 @@ def __call__(self, sample: Sample) -> Sample:
        # Get labels by tokenizing the output text
        labels = self.processor(text=[response], padding=True, return_tensors="np")

-       # Append <|im_end|>\n to the labels
-       im_end_tokens = self.processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
-       im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype)
+       # Append end-of-message token to the labels
+       end_tokens = self.processor.tokenizer(self.end_of_message_token, add_special_tokens=False)["input_ids"]
+       end_tokens = np.array(end_tokens, dtype=inputs.input_ids.dtype)

        # Handle the case where labels['input_ids'] is empty
        if labels["input_ids"].shape[1] == 0:
            labels_input_ids_0 = np.array([], dtype=inputs.input_ids.dtype)
        else:
            labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype)

-       labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens])
+       labels["input_ids"] = np.concatenate([labels_input_ids_0, end_tokens])
        labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0)

        # Concatenate input_ids and labels
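One detail in the hunk above: the end-of-message token is tokenized with add_special_tokens=False so that no extra special tokens leak into the labels. A rough illustration (assuming a Qwen2-style tokenizer; exact ids vary by checkpoint):

    end_ids = processor.tokenizer("<|im_end|>", add_special_tokens=False)["input_ids"]
    # e.g. [151645] for Qwen2 tokenizers; with add_special_tokens=True some tokenizers
    # (Llama-style ones, for instance) would also prepend a BOS token, which must not
    # be appended to the label sequence.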
@@ -519,6 +520,29 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag

            print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d} {marker}")

+       # Calculate and show token statistics after the table
+       print(f"\nToken statistics:")
+
+       # Count consecutive high-value tokens that represent the image
+       # Qwen uses tokens like 151859, 151860, etc. for image patches
+       image_token_threshold = 151000  # Typical threshold for Qwen image tokens
+       image_token_count = np.sum(input_ids > image_token_threshold)
+
+       # Calculate prompt tokens (everything masked)
+       prompt_token_count = masked_count
+
+       # Calculate output tokens (everything not masked)
+       output_token_count = total_count - masked_count
+
+       # Calculate non-image prompt tokens
+       non_image_prompt_tokens = prompt_token_count - image_token_count
+
+       print(f"  Image tokens: {image_token_count}")
+       print(f"  Prompt tokens (total): {prompt_token_count}")
+       print(f"  Prompt tokens (non-image): {non_image_prompt_tokens}")
+       print(f"  Output tokens: {output_token_count}")
+       print(f"  Total sequence length: {total_count}")
+
    except ImportError as e:
        print(f"\nCould not import transformers: {e}")
        print("Install with: pip install transformers")