3. Saves compressed model to destination (local or S3)

Usage:
-    python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N]
+    python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N] [--calibration-pdfs PDF1+PDF2+...]

    source_path: Path to checkpoint (local or S3)
    destination_path: Where to save compressed checkpoint (local or S3)
    recipe_path: Path to quantization config YAML file
-    num_calibration_samples: Number of calibration samples to use (default: 100)
+    num_calibration_samples: Number of calibration samples to use (default: 100, set to 0 to disable)
+    calibration_pdfs: '+' separated list of PDF paths to use for calibration (required when num_calibration_samples > 0)
"""
import argparse


s3_client = boto3.client("s3")
-CALIBRATION_S3_PATH = "s3://ai2-oe-data/jakep/olmocr/olmOCR-mix-0225/benchmark_set"


-def download_calibration_pdfs(num_samples: int) -> List[str]:
-    """Download calibration PDFs from S3 and return local paths."""
-    bucket, prefix = parse_s3_path(CALIBRATION_S3_PATH)
+def get_calibration_pdfs(num_samples: int, pdf_paths: List[str]) -> List[str]:
+    """Get calibration PDFs from provided paths.

-    # Create temporary directory for PDFs
-    temp_dir = tempfile.mkdtemp()
-    print(f"Downloading calibration PDFs to {temp_dir}...")
-
-    # List all PDFs in the calibration dataset
-    paginator = s3_client.get_paginator("list_objects_v2")
-    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+    Args:
+        num_samples: Number of samples to use
+        pdf_paths: List of local PDF paths
+
+    Returns:
+        List of valid PDF paths
+    """
+    print(f"Using {len(pdf_paths)} provided calibration PDFs")

-    pdf_keys = []
-    for page in pages:
-        for obj in page.get("Contents", []):
-            key = obj["Key"]
-            if key.endswith(".pdf"):
-                pdf_keys.append(key)
+    # If more PDFs provided than needed, randomly sample
+    if len(pdf_paths) > num_samples:
+        pdf_paths = random.sample(pdf_paths, num_samples)
+        print(f"Randomly sampled {num_samples} PDFs from provided paths")

-    # Randomly sample PDFs
-    if len(pdf_keys) > num_samples:
-        pdf_keys = random.sample(pdf_keys, num_samples)
+    # Verify all PDFs exist
+    valid_paths = []
+    for path in pdf_paths:
+        if os.path.exists(path) and path.endswith('.pdf'):
+            valid_paths.append(path)
+        else:
+            print(f"  Warning: Skipping invalid path: {path}")

-    # Download the PDFs
-    local_paths = []
-    for key in pdf_keys:
-        filename = os.path.basename(key)
-        local_path = os.path.join(temp_dir, filename)
-        s3_client.download_file(bucket, key, local_path)
-        local_paths.append(local_path)
-        print(f"  Downloaded {filename}")
+    if not valid_paths:
+        raise ValueError("No valid PDF paths found in the provided list")

-    print(f"Downloaded {len(local_paths)} calibration PDFs")
-    return local_paths, temp_dir
+    print(f"Using {len(valid_paths)} valid calibration PDFs")
+    return valid_paths


async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> List[dict]:
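A usage sketch for the new helper, assuming the listed files exist locally (paths hypothetical). Note that sampling happens before validation, so fewer than num_samples valid paths can come back, and the helper raises ValueError if none survive:

    # Illustrative only; not part of the patch.
    pdfs = get_calibration_pdfs(
        num_samples=2,
        pdf_paths=[
            "/data/pdfs/doc1.pdf",   # hypothetical, must exist on disk
            "/data/pdfs/doc2.pdf",   # hypothetical, must exist on disk
            "/data/pdfs/notes.txt",  # skipped with a warning: not a .pdf
        ],
    )
    assert len(pdfs) <= 2 and all(p.endswith(".pdf") for p in pdfs)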
@@ -243,7 +239,7 @@ def data_collator(batch):
    return {key: torch.tensor(value) for key, value in batch[0].items()}


-def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 100) -> None:
+def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 100, calibration_pdfs: Optional[List[str]] = None) -> None:
    """Compress OlmOCR checkpoint using FP8 quantization."""
    # Load model and tokenizer
    model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
@@ -257,16 +253,18 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_

    # Prepare calibration dataset if requested
    dataset = None
-    temp_pdf_dir = None

    if num_calibration_samples > 0:
+        if not calibration_pdfs:
+            raise ValueError("Calibration PDFs must be provided when num_calibration_samples > 0. Use --calibration-pdfs argument.")
+
        print(f"\nPreparing calibration dataset with {num_calibration_samples} samples...")

        # Load processor for the model
        processor = AutoProcessor.from_pretrained(source_path if not temp_source_dir else temp_source_dir)

-        # Download PDFs
-        pdf_paths, temp_pdf_dir = download_calibration_pdfs(num_calibration_samples)
+        # Get calibration PDFs from provided paths
+        pdf_paths = get_calibration_pdfs(num_calibration_samples, calibration_pdfs)

        # Prepare dataset
        dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))
@@ -321,11 +319,6 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_
        print(f"Cleaning up temporary directory {temp_source_dir}...")
        shutil.rmtree(temp_source_dir)

-    # Clean up temporary PDF directory if needed
-    if temp_pdf_dir:
-        print(f"Cleaning up temporary PDF directory {temp_pdf_dir}...")
-        shutil.rmtree(temp_pdf_dir)
-
    # Free up GPU memory
    del model
    torch.cuda.empty_cache()
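With the new keyword argument, the entry point can also be driven programmatically; a minimal sketch, assuming hypothetical local paths (the recipe path is the one from the CLI examples below):

    # Illustrative only; equivalent to the CLI usage shown in main()'s epilog.
    compress_checkpoint(
        source_path="/ckpts/olmocr",               # hypothetical
        dest_path="s3://bucket/olmocr-fp8",        # hypothetical
        recipe_path="train/quantization_configs/qwen2vl_w8a8_fp8.yaml",
        num_calibration_samples=100,
        calibration_pdfs=["/data/pdfs/doc1.pdf"],  # required when samples > 0
    )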
@@ -348,18 +341,32 @@ def main():

    # Local to S3
    python compress_checkpoint.py /path/to/checkpoint s3://bucket/compressed --recipe train/quantization_configs/qwen2vl_w8a8_fp8.yaml
+
+    # Using an explicit '+'-separated list of local calibration PDFs
+    python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/pdfs/doc1.pdf+/data/pdfs/doc2.pdf+/data/pdfs/doc3.pdf"
+
+    # Using a glob pattern with shell expansion
+    python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "$(ls /data/pdfs/*.pdf | tr '\n' '+')"
    """
    )
    parser.add_argument("source", help="Source checkpoint path (local or S3)")
    parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
    parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
    parser.add_argument("--num-calibration-samples", type=int, default=100,
-                        help="Number of calibration samples to use from benchmark set (default: 100, set to 0 to disable)")
+                        help="Number of calibration samples to use (default: 100, set to 0 to disable)")
+    parser.add_argument("--calibration-pdfs", type=str, default=None,
+                        help="'+' separated list of calibration PDF paths (e.g., '/path/to/pdf1.pdf+/path/to/pdf2.pdf'). Required when num-calibration-samples > 0.")

    args = parser.parse_args()

+    # Parse calibration PDFs if provided
+    calibration_pdfs = None
+    if args.calibration_pdfs:
+        calibration_pdfs = args.calibration_pdfs.split('+')
+        print(f"Parsed {len(calibration_pdfs)} calibration PDF paths")
+
    try:
-        compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples)
+        compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples, calibration_pdfs)
    except Exception as e:
        print(f"\n❌ Error: {e}")
        return 1