Commit c93ac4a

Cleaned up loader

1 parent: 6033881 · commit: c93ac4a

File tree: 1 file changed (+28, -4 lines)

olmocr/train/dataloader.py

Lines changed: 28 additions & 4 deletions
@@ -292,6 +292,7 @@ class Tokenizer(PipelineStep):
     """Tokenizes messages and creates training labels with proper masking."""
     processor: Any  # The model processor (e.g., AutoProcessor)
     masking_index: int = -100
+    end_of_message_token: str = "<|im_end|>"  # Configurable, defaults to Qwen format
 
     def __call__(self, sample: Sample) -> Sample:
         """Tokenize messages and create labels for training."""
@@ -323,17 +324,17 @@ def __call__(self, sample: Sample) -> Sample:
         # Get labels by tokenizing the output text
         labels = self.processor(text=[response], padding=True, return_tensors="np")
 
-        # Append <|im_end|>\n to the labels
-        im_end_tokens = self.processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
-        im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype)
+        # Append end-of-message token to the labels
+        end_tokens = self.processor.tokenizer(self.end_of_message_token, add_special_tokens=False)["input_ids"]
+        end_tokens = np.array(end_tokens, dtype=inputs.input_ids.dtype)
 
         # Handle the case where labels['input_ids'] is empty
         if labels["input_ids"].shape[1] == 0:
             labels_input_ids_0 = np.array([], dtype=inputs.input_ids.dtype)
         else:
             labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype)
 
-        labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens])
+        labels["input_ids"] = np.concatenate([labels_input_ids_0, end_tokens])
         labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0)
 
         # Concatenate input_ids and labels
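
To see what the empty-labels guard and the concatenation do, here is a self-contained sketch with made-up token ids. NumPy only; the helper name is ours, and 151645 is Qwen's usual id for <|im_end|> (an assumption, not verified by this diff).

    import numpy as np

    def append_end_tokens(label_ids, end_tokens):
        # Mirror of the hunk above: append end-of-message ids to the first
        # label row, tolerating an empty (zero-length) tokenized response.
        if label_ids.shape[1] == 0:
            row = np.array([], dtype=end_tokens.dtype)
        else:
            row = label_ids[0].astype(end_tokens.dtype)
        return np.expand_dims(np.concatenate([row, end_tokens]), axis=0)

    end = np.array([151645], dtype=np.int64)  # assumed <|im_end|> id
    print(append_end_tokens(np.zeros((1, 0), dtype=np.int64), end))  # [[151645]]
    print(append_end_tokens(np.array([[10, 11]]), end))              # [[10 11 151645]]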
@@ -519,6 +520,29 @@ def __init__(self, root_dir: str | PathLike, target_longest_image_dim: int, imag
 
                 print(f"[{i:4d}] {token_repr:20s} | {str(label):6s} | {token_id:6d}{marker}")
 
+            # Calculate and show token statistics after the table
+            print(f"\nToken statistics:")
+
+            # Count consecutive high-value tokens that represent the image
+            # Qwen uses tokens like 151859, 151860, etc. for image patches
+            image_token_threshold = 151000  # Typical threshold for Qwen image tokens
+            image_token_count = np.sum(input_ids > image_token_threshold)
+
+            # Calculate prompt tokens (everything masked)
+            prompt_token_count = masked_count
+
+            # Calculate output tokens (everything not masked)
+            output_token_count = total_count - masked_count
+
+            # Calculate non-image prompt tokens
+            non_image_prompt_tokens = prompt_token_count - image_token_count
+
+            print(f"  Image tokens: {image_token_count}")
+            print(f"  Prompt tokens (total): {prompt_token_count}")
+            print(f"  Prompt tokens (non-image): {non_image_prompt_tokens}")
+            print(f"  Output tokens: {output_token_count}")
+            print(f"  Total sequence length: {total_count}")
+
         except ImportError as e:
             print(f"\nCould not import transformers: {e}")
             print("Install with: pip install transformers")
