12
12
destination_path: Where to save compressed checkpoint (local or S3)
13
13
recipe_path: Path to quantization config YAML file
14
14
num_calibration_samples: Number of calibration samples to use (default: 100, set to 0 to disable)
15
- calibration_pdfs: '+' separated list of PDF paths to use for calibration (required when num_calibration_samples > 0)
15
+ calibration_pdfs: Glob pattern for PDF paths to use for calibration (required when num_calibration_samples > 0)
16
16
"""
17
17
18
18
import argparse
19
19
import asyncio
20
20
import base64
21
+ import glob
21
22
import json
22
23
import os
23
24
import random
24
25
import shutil
25
26
import tempfile
26
27
from io import BytesIO
28
+ from pathlib import Path
27
29
from typing import Optional , Tuple , Union , List
28
30
29
31
import boto3
@@ -354,11 +356,11 @@ def main():
354
356
# Local to S3
355
357
python compress_checkpoint.py /path/to/checkpoint s3://bucket/compressed --recipe train/quantization_configs/qwen2vl_w8a8_fp8.yaml
356
358
357
- # Using local calibration PDFs (with glob)
358
- python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/pdfs/doc1.pdf+/data/pdfs/doc2.pdf+/data/pdfs/doc3.pdf"
359
+ # Using glob pattern for calibration PDFs
360
+ python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/pdfs/*.pdf"
359
361
360
- # Using glob pattern with shell expansion
361
- python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "$(ls /data/pdfs/*.pdf | tr '\n' '+')"
362
+ # Using recursive glob pattern
363
+ python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe recipe.yaml --calibration-pdfs "/data/**/*.pdf"
362
364
"""
363
365
)
364
366
parser .add_argument ("source" , help = "Source checkpoint path (local or S3)" )
@@ -367,15 +369,35 @@ def main():
367
369
parser .add_argument ("--num-calibration-samples" , type = int , default = 100 ,
368
370
help = "Number of calibration samples to use (default: 100, set to 0 to disable)" )
369
371
parser .add_argument ("--calibration-pdfs" , type = str , default = None ,
370
- help = "'+' separated list of calibration PDF paths (e.g., '/path/to/pdf1 .pdf+/path/to/pdf2 .pdf'). Required when num-calibration-samples > 0." )
372
+ help = "Glob pattern for calibration PDF paths (e.g., '/path/to/pdfs/* .pdf' or '/data/**/* .pdf'). Required when num-calibration-samples > 0." )
371
373
372
374
args = parser .parse_args ()
373
375
374
376
# Parse calibration PDFs if provided
375
377
calibration_pdfs = None
376
378
if args .calibration_pdfs :
377
- calibration_pdfs = args .calibration_pdfs .split ('+' )
378
- print (f"Parsed { len (calibration_pdfs )} calibration PDF paths" )
379
+ # Use pathlib for better glob handling
380
+ pattern = args .calibration_pdfs
381
+
382
+ # Handle both absolute and relative paths with recursive glob
383
+ if '**' in pattern :
384
+ # For recursive patterns, we need to handle them specially
385
+ if pattern .startswith ('/' ):
386
+ # Absolute path with **
387
+ parts = pattern .split ('**' )
388
+ base_path = Path (parts [0 ])
389
+ glob_pattern = '**' + parts [1 ] if len (parts ) > 1 else '**/*.pdf'
390
+ calibration_pdfs = list (base_path .glob (glob_pattern .lstrip ('/' )))
391
+ else :
392
+ # Relative path with **
393
+ calibration_pdfs = list (Path ('.' ).glob (pattern ))
394
+ calibration_pdfs = [str (p .absolute ()) for p in calibration_pdfs if p .suffix .lower () == '.pdf' ]
395
+ else :
396
+ # Use standard glob for non-recursive patterns
397
+ calibration_pdfs = glob .glob (pattern )
398
+ calibration_pdfs = [p for p in calibration_pdfs if p .lower ().endswith ('.pdf' )]
399
+
400
+ print (f"Found { len (calibration_pdfs )} PDF files matching pattern: { args .calibration_pdfs } " )
379
401
380
402
381
403
compress_checkpoint (args .source , args .destination , args .recipe , args .num_calibration_samples , calibration_pdfs )
0 commit comments