Commit 3f9fc8b

Better compressor hopefully
1 parent 287c827

3 files changed: +24 −18 lines


olmocr/train/compress_checkpoint.py

Lines changed: 12 additions & 18 deletions

@@ -6,10 +6,11 @@
 3. Saves compressed model to destination (local or S3)
 
 Usage:
-    python compress_checkpoint.py <source_path> <destination_path>
+    python compress_checkpoint.py <source_path> <destination_path> [--recipe <recipe_path>]
 
     source_path: Path to checkpoint (local or S3)
     destination_path: Where to save compressed checkpoint (local or S3)
+    recipe_path: Optional path to quantization config YAML file
 """
 
 import argparse
@@ -22,7 +23,6 @@
 import boto3
 import torch
 from llmcompressor import oneshot
-from llmcompressor.modifiers.quantization import QuantizationModifier
 from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
 
 from olmocr.s3_utils import parse_s3_path
@@ -150,7 +150,7 @@ def copy_additional_files(source_path: str, dest_path: str, temp_source_dir: Opt
         shutil.copy2(source_file, dest_file)
 
 
-def compress_checkpoint(source_path: str, dest_path: str) -> None:
+def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
     # Load model and tokenizer
    model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
@@ -162,16 +162,9 @@ def compress_checkpoint(source_path: str, dest_path: str) -> None:
         print(f"{name}: shape={list(param.shape)}, dtype={param.dtype}")
     print("=========================\n")
 
-    # Configure FP8 dynamic quantization
-    print("\nApplying FP8 dynamic quantization...")
-    recipe = QuantizationModifier(
-        targets="Linear",
-        scheme="FP8_DYNAMIC",
-        ignore=["re:.*lm_head", "re:visual.*"],
-    )
-
-    # Apply the quantization
-    oneshot(model=model, recipe=recipe)
+    # Apply quantization using provided recipe
+    print(f"\nApplying quantization using recipe: {recipe_path}")
+    oneshot(model=model, recipe=recipe_path)
     print("✓ Quantization completed successfully")
 
     # Save the compressed model
@@ -218,25 +211,26 @@ def main():
         epilog="""
 Examples:
   # Local to local
-  python compress_checkpoint.py /path/to/checkpoint /path/to/compressed
+  python compress_checkpoint.py /path/to/checkpoint /path/to/compressed --recipe train/quantization_configs/qwen2_5vl_w8a8_fp8.yaml
 
   # S3 to S3
-  python compress_checkpoint.py s3://bucket/checkpoint s3://bucket/compressed
+  python compress_checkpoint.py s3://bucket/checkpoint s3://bucket/compressed --recipe train/quantization_configs/qwen2vl_w8a8_fp8.yaml
 
   # S3 to local
-  python compress_checkpoint.py s3://bucket/checkpoint /path/to/compressed
+  python compress_checkpoint.py s3://bucket/checkpoint /path/to/compressed --recipe train/quantization_configs/qwen2_5vl_w8a8_fp8.yaml
 
   # Local to S3
-  python compress_checkpoint.py /path/to/checkpoint s3://bucket/compressed
+  python compress_checkpoint.py /path/to/checkpoint s3://bucket/compressed --recipe train/quantization_configs/qwen2vl_w8a8_fp8.yaml
 """
     )
     parser.add_argument("source", help="Source checkpoint path (local or S3)")
     parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
+    parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
 
     args = parser.parse_args()
 
     try:
-        compress_checkpoint(args.source, args.destination)
+        compress_checkpoint(args.source, args.destination, args.recipe)
     except Exception as e:
         print(f"\n❌ Error: {e}")
         return 1
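
For context, a minimal sketch of the recipe-driven flow this change enables (not part of the commit): llmcompressor's oneshot() accepts a recipe given as a YAML file path, which is what lets the hard-coded QuantizationModifier be dropped from the script. The checkpoint and output paths below are illustrative assumptions.

    # Sketch only: recipe-driven FP8 quantization via llmcompressor.
    from llmcompressor import oneshot
    from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration

    source = "/path/to/checkpoint"  # hypothetical local checkpoint
    recipe = "olmocr/train/quantization_configs/qwen2_5vl_w8a8_fp8.yaml"

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(source, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(source)

    # FP8_DYNAMIC is data-free, so no calibration dataset is passed.
    oneshot(model=model, recipe=recipe)

    # llmcompressor extends save_pretrained() with a save_compressed flag.
    model.save_pretrained("/path/to/compressed", save_compressed=True)
    tokenizer.save_pretrained("/path/to/compressed")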
olmocr/train/quantization_configs/qwen2_5vl_w8a8_fp8.yaml

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+default_stage:
+  default_modifiers:
+    QuantizationModifier:
+      targets: [Linear]
+      ignore: ['re:.*lm_head', 're:model.visual.*']
+      scheme: FP8_DYNAMIC
olmocr/train/quantization_configs/qwen2vl_w8a8_fp8.yaml

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+default_stage:
+  default_modifiers:
+    QuantizationModifier:
+      targets: [Linear]
+      ignore: ['re:.*lm_head', 're:visual.*']
+      scheme: FP8_DYNAMIC
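
The two new recipes differ only in their vision-tower ignore pattern, reflecting how the two model families name their modules: Qwen2-VL keeps the vision encoder at top-level visual.*, while Qwen2.5-VL (as loaded by this script) nests it under model.visual.*. A hedged sketch, approximating llmcompressor's "re:" matching with plain re.match against illustrative module names (assumed for illustration, not dumped from real checkpoints):

    import re

    # Illustrative module names for each family (assumed, not exhaustive).
    qwen2_vl_names = ["visual.blocks.0.mlp.fc1", "model.layers.0.mlp.gate_proj", "lm_head"]
    qwen2_5_vl_names = ["model.visual.blocks.0.mlp.gate_proj", "model.layers.0.mlp.gate_proj", "lm_head"]

    def ignored(name: str, patterns: list[str]) -> bool:
        # "re:" prefix already stripped; re.match anchors at the start of the name.
        return any(re.match(p, name) for p in patterns)

    print([n for n in qwen2_vl_names if ignored(n, [r".*lm_head", r"visual.*"])])
    # -> ['visual.blocks.0.mlp.fc1', 'lm_head']

    print([n for n in qwen2_5_vl_names if ignored(n, [r".*lm_head", r"model.visual.*"])])
    # -> ['model.visual.blocks.0.mlp.gate_proj', 'lm_head']

Because re.match anchors at the start of the module name, r"visual.*" would never match a name prefixed with model., which is why each family gets its own config rather than one shared ignore list.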
