
Commit 9d49435

kylesayrs and dsikka authored
add utils (#998)
Signed-off-by: Kyle Sayers <[email protected]>
Co-authored-by: Dipika Sikka <[email protected]>
1 parent bb8660a commit 9d49435

File tree

3 files changed: +107 -1 lines changed


src/llmcompressor/utils/helpers.py

Lines changed: 33 additions & 0 deletions
@@ -4,6 +4,7 @@
 """
 
 import ast
+import contextlib
 import errno
 import fnmatch
 import glob
@@ -23,6 +24,7 @@
 
 import numpy
 import torch
+from compressed_tensors.quantization import disable_quantization, enable_quantization
 from loguru import logger
 
 __all__ = [
@@ -61,6 +63,8 @@
     "import_from_path",
     "getattr_chain",
     "DisableKVCache",
+    "DisableQuantization",
+    "calibration_forward_context",
 ]
 
 
@@ -1080,3 +1084,32 @@ def __enter__(self):
 
     def __exit__(self, _exc_type, _exc_val, _exc_tb):
         self.config.use_cache = self.restore_value
+
+
+@contextlib.contextmanager
+def DisableQuantization(model: torch.nn.Module):
+    """
+    Disable quantization from QuantizationModifier
+    """
+    model.apply(disable_quantization)
+    yield
+    model.apply(enable_quantization)
+
+
+@contextlib.contextmanager
+def calibration_forward_context(model: torch.nn.Module):
+    """
+    Context in which all calibration forward passes should occur.
+
+    - Remove gradient calculations
+    - Disable the KV cache
+    - Disable quantization from QuantizationModifier
+    """
+    model.eval()
+
+    with (
+        torch.no_grad(),
+        DisableKVCache(model),
+        DisableQuantization(model),
+    ):
+        yield
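
For orientation, a minimal usage sketch (not part of this commit; the model and input are illustrative stand-ins). It mirrors the unit tests added below: `DisableKVCache` toggles `model.config.use_cache`, so a bare module is given a minimal config.

from types import SimpleNamespace

import torch

from llmcompressor.utils import calibration_forward_context

# Stand-in for a real HF model; DisableKVCache flips config.use_cache,
# so the sketch attaches a minimal config, as the new tests do.
model = torch.nn.Linear(8, 8)
model.config = SimpleNamespace(use_cache=True)

with calibration_forward_context(model):
    # Inside the context: no gradients, no KV cache, and quantization
    # from QuantizationModifier is disabled for the calibration pass.
    _ = model(torch.randn(1, 8))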

src/llmcompressor/utils/metric_logging.py

Lines changed: 50 additions & 1 deletion
@@ -1,9 +1,11 @@
+import time
 from typing import List, Tuple
 
+import torch
 from loguru import logger
 from torch.nn import Module
 
-__all__ = ["get_GPU_memory_usage", "get_layer_size_mb"]
+__all__ = ["get_GPU_memory_usage", "get_layer_size_mb", "CompressionLogger"]
 
 
 def get_GPU_memory_usage() -> List[Tuple]:
@@ -51,3 +53,50 @@ def get_layer_size_mb(module: Module) -> float:
     total_size_mb = total_size / (1e6)  # Convert bytes to MB
 
     return total_size_mb
+
+
+class CompressionLogger:
+    """
+    Log metrics related to the compression algorithm
+
+    :param start_tick: time when the algorithm started
+    :param loss: loss resulting from the algorithm
+    """
+
+    def __init__(self, module: torch.nn.Module):
+        self.module = module
+        self.start_tick = None
+        self.loss = None
+
+    def set_loss(self, loss: float):
+        self.loss = loss
+
+    def __enter__(self) -> "CompressionLogger":
+        self.start_tick = time.time()
+        return self
+
+    def __exit__(self, _exc_type, _exc_val, _exc_tb):
+        stop_tick = time.time()
+        patch = logger.patch(lambda r: r.update(function="compress"))
+
+        if self.start_tick is not None:
+            duration = stop_tick - self.start_tick
+            patch.log("METRIC", f"time {duration:.2f}")
+        if self.loss is not None:
+            patch.log("METRIC", f"error {self.loss:.2f}")
+
+        gpu_usage = get_GPU_memory_usage()
+        if len(gpu_usage) > 0:
+            for i in range(len(gpu_usage)):
+                perc = gpu_usage[i][0] * 100
+                total_memory = int(gpu_usage[i][1])  # GB
+                patch.log(
+                    "METRIC",
+                    (
                        f"GPU {i} | usage: {perc:.2f}%"
                        f" | total memory: {total_memory} GB"
+                    ),
+                )
+
+        compressed_size = get_layer_size_mb(self.module)
+        patch.log("METRIC", f"Compressed module size: {compressed_size} MB")

tests/llmcompressor/utils/test_helpers.py

Lines changed: 24 additions & 0 deletions
@@ -1,9 +1,12 @@
 from types import SimpleNamespace
 
 import pytest
+import torch
 
 from llmcompressor.utils import (
     ALL_TOKEN,
+    DisableQuantization,
+    calibration_forward_context,
     convert_to_bool,
     flatten_iterable,
     getattr_chain,
@@ -124,3 +127,24 @@ def test_getattr_chain():
     assert getattr_chain(base, "b.d.dne", "default") == "default"
     with pytest.raises(AttributeError):
         getattr_chain(base, "b.d.dne")
+
+
+def test_DisableQuantization():
+    model = torch.nn.Linear(1, 1)
+    with DisableQuantization(model):
+        assert not model.quantization_enabled
+    assert model.quantization_enabled
+
+
+def test_calibration_forward_context():
+    model = torch.nn.Linear(1, 1)
+    model.config = SimpleNamespace()
+    model.config.use_cache = True
+
+    with calibration_forward_context(model):
+        assert not torch.is_grad_enabled()
+        assert not model.quantization_enabled
+        assert not model.config.use_cache
+    assert torch.is_grad_enabled()
+    assert model.quantization_enabled
+    assert model.config.use_cache
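
The tests assert a `quantization_enabled` attribute on the wrapped module. As an assumed sketch only (not the actual compressed_tensors implementation, which may differ), the imported helpers behave roughly like per-module flag toggles applied via `Module.apply`:

import torch

# Assumed behaviour: disable_quantization / enable_quantization are taken to
# set a `quantization_enabled` flag on each module they visit, which is the
# attribute the tests above check.
def disable_quantization(module: torch.nn.Module) -> None:
    module.quantization_enabled = False

def enable_quantization(module: torch.nn.Module) -> None:
    module.quantization_enabled = True

model = torch.nn.Linear(1, 1)
model.apply(disable_quantization)  # Module.apply calls the fn on every submodule
assert model.quantization_enabled is False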
