Commit fcd373d

Calibration stuff
1 parent 2218bf8 commit fcd373d

1 file changed: olmocr/train/compress_checkpoint.py

Lines changed: 50 additions & 43 deletions
@@ -6,12 +6,13 @@
 3. Saves compressed model to destination (local or S3)
 
 Usage:
-    python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N]
+    python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N] [--calibration-pdfs PDF1+PDF2+...]
 
     source_path: Path to checkpoint (local or S3)
     destination_path: Where to save compressed checkpoint (local or S3)
     recipe_path: Path to quantization config YAML file
-    num_calibration_samples: Number of calibration samples to use (default: 100)
+    num_calibration_samples: Number of calibration samples to use (default: 100, set to 0 to disable)
+    calibration_pdfs: '+' separated list of PDF paths to use for calibration (required when num_calibration_samples > 0)
 """
 
 import argparse
@@ -36,43 +37,38 @@
 
 
 s3_client = boto3.client("s3")
-CALIBRATION_S3_PATH = "s3://ai2-oe-data/jakep/olmocr/olmOCR-mix-0225/benchmark_set"
 
 
-def download_calibration_pdfs(num_samples: int) -> List[str]:
-    """Download calibration PDFs from S3 and return local paths."""
-    bucket, prefix = parse_s3_path(CALIBRATION_S3_PATH)
+def get_calibration_pdfs(num_samples: int, pdf_paths: List[str]) -> List[str]:
+    """Get calibration PDFs from provided paths.
 
-    # Create temporary directory for PDFs
-    temp_dir = tempfile.mkdtemp()
-    print(f"Downloading calibration PDFs to {temp_dir}...")
-
-    # List all PDFs in the calibration dataset
-    paginator = s3_client.get_paginator("list_objects_v2")
-    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+    Args:
+        num_samples: Number of samples to use
+        pdf_paths: List of local PDF paths
+
+    Returns:
+        List of valid PDF paths
+    """
+    print(f"Using {len(pdf_paths)} provided calibration PDFs")
 
-    pdf_keys = []
-    for page in pages:
-        for obj in page.get("Contents", []):
-            key = obj["Key"]
-            if key.endswith(".pdf"):
-                pdf_keys.append(key)
+    # If more PDFs provided than needed, randomly sample
+    if len(pdf_paths) > num_samples:
+        pdf_paths = random.sample(pdf_paths, num_samples)
+        print(f"Randomly sampled {num_samples} PDFs from provided paths")
 
-    # Randomly sample PDFs
-    if len(pdf_keys) > num_samples:
-        pdf_keys = random.sample(pdf_keys, num_samples)
+    # Verify all PDFs exist
+    valid_paths = []
+    for path in pdf_paths:
+        if os.path.exists(path) and path.endswith('.pdf'):
+            valid_paths.append(path)
+        else:
+            print(f"  Warning: Skipping invalid path: {path}")
 
-    # Download the PDFs
-    local_paths = []
-    for key in pdf_keys:
-        filename = os.path.basename(key)
-        local_path = os.path.join(temp_dir, filename)
-        s3_client.download_file(bucket, key, local_path)
-        local_paths.append(local_path)
-        print(f"  Downloaded {filename}")
+    if not valid_paths:
+        raise ValueError("No valid PDF paths found in the provided list")
 
-    print(f"Downloaded {len(local_paths)} calibration PDFs")
-    return local_paths, temp_dir
+    print(f"Using {len(valid_paths)} valid calibration PDFs")
+    return valid_paths
 
 
 async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> List[dict]:
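A note on the new selection logic (a minimal sketch with hypothetical paths, not part of the commit): random sampling runs before path validation, so invalid paths that land in the sample are only dropped afterwards, possibly leaving fewer than num_samples PDFs.

    # Hypothetical inputs; only existing files ending in .pdf pass validation.
    pdfs = ["/data/pdfs/doc1.pdf", "/data/pdfs/doc2.pdf", "/data/pdfs/notes.txt"]
    # Samples at most 2 paths, then filters on os.path.exists() and a '.pdf'
    # suffix check; raises ValueError if no valid paths remain.
    valid = get_calibration_pdfs(num_samples=2, pdf_paths=pdfs)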
@@ -243,7 +239,7 @@ def data_collator(batch):
     return {key: torch.tensor(value) for key, value in batch[0].items()}
 
 
-def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 100) -> None:
+def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 100, calibration_pdfs: Optional[List[str]] = None) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
@@ -257,16 +253,18 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_
 
     # Prepare calibration dataset if requested
     dataset = None
-    temp_pdf_dir = None
 
     if num_calibration_samples > 0:
+        if not calibration_pdfs:
+            raise ValueError("Calibration PDFs must be provided when num_calibration_samples > 0. Use --calibration-pdfs argument.")
+
         print(f"\nPreparing calibration dataset with {num_calibration_samples} samples...")
 
         # Load processor for the model
         processor = AutoProcessor.from_pretrained(source_path if not temp_source_dir else temp_source_dir)
 
-        # Download PDFs
-        pdf_paths, temp_pdf_dir = download_calibration_pdfs(num_calibration_samples)
+        # Get calibration PDFs from provided paths
+        pdf_paths = get_calibration_pdfs(num_calibration_samples, calibration_pdfs)
 
         # Prepare dataset
         dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))
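For callers importing the module, the new keyword argument slots in as sketched below (paths and sample count are placeholders, not taken from the commit):

    # num_calibration_samples > 0 without calibration_pdfs now raises ValueError.
    compress_checkpoint(
        "/path/to/checkpoint",
        "/path/to/compressed",
        "recipe.yaml",
        num_calibration_samples=50,
        calibration_pdfs=["/data/pdfs/doc1.pdf", "/data/pdfs/doc2.pdf"],
    )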
@@ -321,11 +319,6 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_
         print(f"Cleaning up temporary directory {temp_source_dir}...")
         shutil.rmtree(temp_source_dir)
 
-    # Clean up temporary PDF directory if needed
-    if temp_pdf_dir:
-        print(f"Cleaning up temporary PDF directory {temp_pdf_dir}...")
-        shutil.rmtree(temp_pdf_dir)
-
     # Free up GPU memory
     del model
     torch.cuda.empty_cache()
@@ -348,18 +341,32 @@ def main():
 
     # Local to S3
     python compress_checkpoint.py /path/to/checkpoint s3://bucket/compressed --recipe train/quantization_configs/qwen2vl_w8a8_fp8.yaml
+
+    # Using local calibration PDFs
+    python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/pdfs/doc1.pdf+/data/pdfs/doc2.pdf+/data/pdfs/doc3.pdf"
+
+    # Using a glob pattern with shell expansion
+    python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "$(ls /data/pdfs/*.pdf | tr '\n' '+')"
     """
     )
     parser.add_argument("source", help="Source checkpoint path (local or S3)")
     parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
     parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
     parser.add_argument("--num-calibration-samples", type=int, default=100,
-                        help="Number of calibration samples to use from benchmark set (default: 100, set to 0 to disable)")
+                        help="Number of calibration samples to use (default: 100, set to 0 to disable)")
+    parser.add_argument("--calibration-pdfs", type=str, default=None,
+                        help="'+' separated list of calibration PDF paths (e.g., '/path/to/pdf1.pdf+/path/to/pdf2.pdf'). Required when num-calibration-samples > 0.")
 
     args = parser.parse_args()
 
+    # Parse calibration PDFs if provided
+    calibration_pdfs = None
+    if args.calibration_pdfs:
+        calibration_pdfs = args.calibration_pdfs.split('+')
+        print(f"Parsed {len(calibration_pdfs)} calibration PDF paths")
+
     try:
-        compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples)
+        compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples, calibration_pdfs)
     except Exception as e:
         print(f"\n❌ Error: {e}")
         return 1
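Because main() simply splits the argument on '+', the list can also be assembled in Python instead of the shell pipeline shown in the help text (a sketch; the glob pattern is an assumption, and paths that themselves contain '+' would break the split):

    import glob

    # Build the '+'-separated value expected by --calibration-pdfs.
    arg = "+".join(sorted(glob.glob("/data/pdfs/*.pdf")))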
