Skip to content

Commit 0aa7479

Browse files
committed
More calibration samples by default
1 parent 6e48012 commit 0aa7479

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

olmocr/train/compress_checkpoint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
source_path: Path to checkpoint (local or S3)
1212
destination_path: Where to save compressed checkpoint (local or S3)
1313
recipe_path: Path to quantization config YAML file
14-
num_calibration_samples: Number of calibration samples to use (default: 100, set to 0 to disable)
14+
num_calibration_samples: Number of calibration samples to use (default: 256, set to 0 to disable)
1515
calibration_pdfs: Glob pattern for PDF paths to use for calibration (required when num_calibration_samples > 0)
1616
"""
1717

@@ -253,7 +253,7 @@ def data_collator(batch):
253253
return {key: torch.tensor(value) for key, value in batch[0].items()}
254254

255255

256-
def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 100, calibration_pdfs: Optional[List[str]] = None) -> None:
256+
def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 256, calibration_pdfs: Optional[List[str]] = None) -> None:
257257
"""Compress OlmOCR checkpoint using FP8 quantization."""
258258
# Load model and tokenizer
259259
model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
@@ -366,8 +366,8 @@ def main():
366366
parser.add_argument("source", help="Source checkpoint path (local or S3)")
367367
parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
368368
parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
369-
parser.add_argument("--num-calibration-samples", type=int, default=100,
370-
help="Number of calibration samples to use (default: 100, set to 0 to disable)")
369+
parser.add_argument("--num-calibration-samples", type=int, default=256,
370+
help="Number of calibration samples to use (default: 256s, set to 0 to disable)")
371371
parser.add_argument("--calibration-pdfs", type=str, default=None,
372372
help="Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/*.pdf' or '/data/**/*.pdf'). Required when num-calibration-samples > 0.")
373373

0 commit comments

Comments
 (0)