Skip to content

Commit 5a4a836

Browse files
committed
Calibration
1 parent 9115a02 commit 5a4a836

File tree

1 file changed

+30
-8
lines changed

1 file changed

+30
-8
lines changed

olmocr/train/compress_checkpoint.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@
1212
destination_path: Where to save compressed checkpoint (local or S3)
1313
recipe_path: Path to quantization config YAML file
1414
num_calibration_samples: Number of calibration samples to use (default: 100, set to 0 to disable)
15-
calibration_pdfs: '+' separated list of PDF paths to use for calibration (required when num_calibration_samples > 0)
15+
calibration_pdfs: Glob pattern for PDF paths to use for calibration (required when num_calibration_samples > 0)
1616
"""
1717

1818
import argparse
1919
import asyncio
2020
import base64
21+
import glob
2122
import json
2223
import os
2324
import random
2425
import shutil
2526
import tempfile
2627
from io import BytesIO
28+
from pathlib import Path
2729
from typing import Optional, Tuple, Union, List
2830

2931
import boto3
@@ -354,11 +356,11 @@ def main():
354356
# Local to S3
355357
python compress_checkpoint.py /path/to/checkpoint s3://bucket/compressed --recipe train/quantization_configs/qwen2vl_w8a8_fp8.yaml
356358
357-
# Using local calibration PDFs (with glob)
358-
python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/pdfs/doc1.pdf+/data/pdfs/doc2.pdf+/data/pdfs/doc3.pdf"
359+
# Using glob pattern for calibration PDFs
360+
python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/pdfs/*.pdf"
359361
360-
# Using glob pattern with shell expansion
361-
python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "$(ls /data/pdfs/*.pdf | tr '\n' '+')"
362+
# Using recursive glob pattern
363+
python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/**/*.pdf"
362364
"""
363365
)
364366
parser.add_argument("source", help="Source checkpoint path (local or S3)")
@@ -367,15 +369,35 @@ def main():
367369
parser.add_argument("--num-calibration-samples", type=int, default=100,
368370
help="Number of calibration samples to use (default: 100, set to 0 to disable)")
369371
parser.add_argument("--calibration-pdfs", type=str, default=None,
370-
help="'+' separated list of calibration PDF paths (e.g., '/path/to/pdf1.pdf+/path/to/pdf2.pdf'). Required when num-calibration-samples > 0.")
372+
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.")
371373

372374
args = parser.parse_args()
373375

374376
# Parse calibration PDFs if provided
375377
calibration_pdfs = None
376378
if args.calibration_pdfs:
377-
calibration_pdfs = args.calibration_pdfs.split('+')
378-
print(f"Parsed {len(calibration_pdfs)} calibration PDF paths")
379+
# Use pathlib for better glob handling
380+
pattern = args.calibration_pdfs
381+
382+
# Handle both absolute and relative paths with recursive glob
383+
if '**' in pattern:
384+
# For recursive patterns, we need to handle them specially
385+
if pattern.startswith('/'):
386+
# Absolute path with **
387+
parts = pattern.split('**')
388+
base_path = Path(parts[0])
389+
glob_pattern = '**' + parts[1] if len(parts) > 1 else '**/*.pdf'
390+
calibration_pdfs = list(base_path.glob(glob_pattern.lstrip('/')))
391+
else:
392+
# Relative path with **
393+
calibration_pdfs = list(Path('.').glob(pattern))
394+
calibration_pdfs = [str(p.absolute()) for p in calibration_pdfs if p.suffix.lower() == '.pdf']
395+
else:
396+
# Use standard glob for non-recursive patterns
397+
calibration_pdfs = glob.glob(pattern)
398+
calibration_pdfs = [p for p in calibration_pdfs if p.lower().endswith('.pdf')]
399+
400+
print(f"Found {len(calibration_pdfs)} PDF files matching pattern: {args.calibration_pdfs}")
379401

380402

381403
compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples, calibration_pdfs)

0 commit comments

Comments
 (0)