
Commit 07f940c

Enable cross-device Half-Quadratic Quantization for LLMs (#1597)
Signed-off-by: yiliu30 <[email protected]>
1 parent c1f23ce commit 07f940c

File tree

24 files changed: +1736 -7 lines changed


neural_compressor/torch/algorithms/weight_only/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,5 +15,6 @@
 from .rtn import rtn_quantize
 from .gptq import gptq_quantize
 from .awq import awq_quantize
+from .hqq import hqq_quantize
 from .modules import WeightOnlyLinear
 from .utility import *
neural_compressor/torch/algorithms/weight_only/hqq/__init__.py (new file)

Lines changed: 17 additions & 0 deletions
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .quantizer import HQQuantizer
from .config import HQQModuleConfig, QTensorConfig
from .quant_api import hqq_quantize
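
To see how these entry points fit together, here is a minimal usage sketch. The `QTensorConfig` fields, the `HQQModuleConfig(weight, scale, zero)` grouping, and the `hqq_quantize(model, configs_mapping)` signature are assumptions inferred from the import names above, not confirmed APIs.

# A minimal sketch, NOT the confirmed API: the config fields and the
# configs_mapping layout below are illustrative assumptions.
import torch

from neural_compressor.torch.algorithms.weight_only.hqq import (
    HQQModuleConfig,
    QTensorConfig,
    hqq_quantize,
)

model = torch.nn.Sequential(torch.nn.Linear(128, 128))

# Assumed fields: 4-bit weights; scales/zero-points themselves quantized to 8 bits.
q4 = QTensorConfig(nbits=4, group_size=64)
q8 = QTensorConfig(nbits=8)
cfg = HQQModuleConfig(weight=q4, scale=q8, zero=q8)

# Assumed mapping: module name -> per-module HQQ config.
model = hqq_quantize(model, configs_mapping={"0": cfg})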
neural_compressor/torch/algorithms/weight_only/hqq/auto_accelerator.py (new file)

Lines changed: 223 additions & 0 deletions
# Copyright (c) 2023-2024 Microsoft Corporation and Intel Corporation

# This code is based on Microsoft Corporation's DeepSpeed library and
# the accelerators implementation in this library. It has been modified
# from its original forms to simplify and adapt it for use in
# the Intel® Neural Compressor.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTICE: The design is adapted from:
# https://github.com/microsoft/DeepSpeed/blob/master/accelerator/abstract_accelerator.py.
# TODO: move it into torch/utils


# To keep it simple, only add the APIs we need.

import os
from abc import ABC, abstractmethod
from typing import Any, Callable, List

import torch

from neural_compressor.torch.utils import logger

PRIORITY_CUDA = 100
PRIORITY_CPU = 90


class AcceleratorRegistry:
    registered_accelerators = {}

    @classmethod
    def register_accelerator_impl(cls, name: str, priority: float = 0):
        """Register a new accelerator implementation.

        Usage example:
            @AcceleratorRegistry.register_accelerator_impl(name="cpu", priority=100)
            class CPU_Accelerator:
                ...

        Args:
            name: the accelerator name.
            priority: the priority of the accelerator. A larger number indicates a higher priority.
        """

        def decorator(accelerator_cls):
            cls.registered_accelerators.setdefault(name, {})
            cls.registered_accelerators[name] = (accelerator_cls, priority)
            return accelerator_cls

        return decorator

    @classmethod
    def get_sorted_accelerators(cls) -> List["Auto_Accelerator"]:
        """Get the registered accelerators sorted by priority."""
        accelerator_pairs = cls.registered_accelerators.values()
        sorted_accelerators_pairs = sorted(accelerator_pairs, key=lambda x: x[1], reverse=True)
        sorted_accelerators = [pair[0] for pair in sorted_accelerators_pairs]
        return sorted_accelerators

    @classmethod
    def get_accelerator_cls_by_name(cls, name: str) -> "Auto_Accelerator":
        """Get an accelerator class by name."""
        accelerator_cls, _ = cls.registered_accelerators.get(name, (None, None))
        return accelerator_cls


accelerator_registry = AcceleratorRegistry()


def register_accelerator(name: str, priority: float = 0) -> Callable[..., Any]:
    """Register a new accelerator.

    Usage example:
        @register_accelerator(name="cuda", priority=100)
        class CUDA_Accelerator:
            ...

    Args:
        name: the accelerator name.
        priority: the priority of the accelerator. A larger number indicates a higher priority.
    """

    return accelerator_registry.register_accelerator_impl(name=name, priority=priority)


class Auto_Accelerator(ABC):
    @classmethod
    @abstractmethod
    def is_available(cls) -> bool:
        pass

    @abstractmethod
    def name(self) -> str:
        pass

    @abstractmethod
    def device_name(self, device_indx) -> str:
        pass

    @abstractmethod
    def set_device(self, device_index):
        pass

    @abstractmethod
    def current_device(self):
        pass

    @abstractmethod
    def current_device_name(self):
        pass

    @abstractmethod
    def device(self, device_index=None):
        pass

    @abstractmethod
    def empty_cache(self):
        pass

    @abstractmethod
    def synchronize(self):
        pass


@register_accelerator(name="cpu", priority=PRIORITY_CPU)
class CPU_Accelerator(Auto_Accelerator):
    def __init__(self) -> None:
        self._name = "cpu"

    def name(self) -> str:
        return self._name

    @classmethod
    def is_available(cls) -> bool:
        return True

    def device_name(self, device_indx) -> str:
        return "cpu"

    def set_device(self, device_index):
        pass

    def current_device(self):
        return "cpu"

    def current_device_name(self):
        return "cpu"

    def device(self, device_index=None):
        pass

    def empty_cache(self):
        pass

    def synchronize(self):
        pass


@register_accelerator(name="cuda", priority=PRIORITY_CUDA)
class CUDA_Accelerator(Auto_Accelerator):
    def __init__(self) -> None:
        self._name = "cuda"

    def name(self) -> str:
        return self._name

    @classmethod
    def is_available(cls) -> bool:
        return torch.cuda.is_available()

    def device_name(self, device_indx) -> str:
        if device_indx is None:
            return "cuda"
        return f"cuda:{device_indx}"

    def synchronize(self):
        return torch.cuda.synchronize()

    def set_device(self, device_index):
        return torch.cuda.set_device(device_index)

    def current_device(self):
        return torch.cuda.current_device()

    def current_device_name(self):
        return "cuda:{}".format(torch.cuda.current_device())

    def device(self, device_index=None):
        return torch.cuda.device(device_index)

    def empty_cache(self):
        return torch.cuda.empty_cache()


def auto_detect_accelerator() -> Auto_Accelerator:
    # if runtime_accelerator.accelerator:
    #     return runtime_accelerator.accelerator
    FORCE_DEVICE = os.environ.get("FORCE_DEVICE", None)
    if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None:
        logger.warning("Force use %s accelerator.", FORCE_DEVICE)
        return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)()
    for accelerator_cls in accelerator_registry.get_sorted_accelerators():
        if accelerator_cls.is_available():
            logger.debug("Auto detect accelerator: %s.", accelerator_cls.__name__)
            accelerator = accelerator_cls()
            return accelerator


# Force use of the cpu accelerator even if cuda is available:
#   FORCE_DEVICE="cpu" python ...
# or
#   CUDA_VISIBLE_DEVICES="" python ...
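
This registry is what makes the quantization cross-device: HQQ asks `auto_detect_accelerator()` for the best available backend instead of hard-coding `cuda` or `cpu`. A minimal sketch of the detection flow, assuming the module path given in the header above:

# A minimal sketch; the import path is assumed from the file header above.
import os

import torch

from neural_compressor.torch.algorithms.weight_only.hqq.auto_accelerator import auto_detect_accelerator

# The highest-priority available accelerator wins: cuda (100) over cpu (90).
acc = auto_detect_accelerator()
W = torch.randn(4, 4, device=acc.current_device_name())

# FORCE_DEVICE overrides detection, matching the closing comment above.
os.environ["FORCE_DEVICE"] = "cpu"
assert auto_detect_accelerator().name() == "cpu"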
neural_compressor/torch/algorithms/weight_only/hqq/bitpack.py (new file)

Lines changed: 144 additions & 0 deletions
# Copyright (c) 2023-2024 Mobiusml and Intel Corporation

# This code is based on Mobiusml's HQQ library. It has been modified
# from its original forms to simplify and adapt it for use in
# the Intel® Neural Compressor.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Notice: Copied from https://github.com/mobiusml/hqq
# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2023
#####################################################

import numpy as np
import torch

from .utility import is_divisible

__all__ = ["Packer"]


# Bit packing logic. format: pack/unpack_nBits_target-<uint8 or int32>
class BitPack:
    # 8-bit
    ################################################
    @staticmethod
    def pack_8bit_u8(W_q):
        return W_q.to(torch.uint8)

    @staticmethod
    def unpack_8bit_u8(W_q):
        return W_q

    # 4-bit
    ################################################
    @staticmethod
    def pack_4bit_u8(W_q):  # uint8 > uint8/2
        W_q = W_q.to(torch.uint8)
        _step = int(len(W_q) / 2)
        return (W_q[:_step] << 4) | W_q[_step:]

    # A bit faster than the _cat version
    @staticmethod
    def unpack_4bit_u8(W_q):  # uint8/2 > uint8
        _step = W_q.shape[0]
        tmp = torch.empty([2 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
        tmp[:_step] = (W_q & 0b11110000) >> 4
        tmp[_step:] = W_q & 0b00001111
        return tmp

    # 2-bit
    ################################################
    @staticmethod
    def pack_2bit_u8(W_q):  # uint8 > uint8/4
        W_q = W_q.to(torch.uint8)
        _step = int(len(W_q) / 4)
        return W_q[:_step] << 6 | W_q[_step : 2 * _step] << 4 | W_q[2 * _step : 3 * _step] << 2 | W_q[3 * _step :]

    # A bit faster than the _cat version
    @staticmethod
    def unpack_2bit_u8(W_q):
        _step = W_q.shape[0]
        tmp = torch.empty([4 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
        tmp[:_step] = (W_q & 0b11000000) >> 6
        tmp[_step : 2 * _step] = (W_q & 0b00110000) >> 4
        tmp[2 * _step : 3 * _step] = (W_q & 0b00001100) >> 2
        tmp[3 * _step :] = W_q & 0b00000011
        return tmp

    # 3-bit
    ################################################
    @staticmethod
    def pack_3bit_32(W_q_in):
        W_q = torch.zeros(
            [int(10 * np.ceil(W_q_in.shape[0] / 10.0)), W_q_in.shape[1]], device=W_q_in.device, dtype=torch.int32
        )
        W_q[: len(W_q_in)] = W_q_in
        _step = int(len(W_q) / 10)
        W_q = (
            (W_q[:_step] << 27)
            | (W_q[_step : _step * 2] << 24)
            | (W_q[_step * 2 : _step * 3] << 21)
            | (W_q[_step * 3 : _step * 4] << 18)
            | (W_q[_step * 4 : _step * 5] << 15)
            | (W_q[_step * 5 : _step * 6] << 12)
            | (W_q[_step * 6 : _step * 7] << 9)
            | (W_q[_step * 7 : _step * 8] << 6)
            | (W_q[_step * 8 : _step * 9] << 3)
            | (W_q[_step * 9 :])
        )
        return W_q

    # A bit faster than the _cat version
    @staticmethod
    def unpack_3bit_32(W_q):
        _step = W_q.shape[0]
        tmp = torch.empty([10 * _step, W_q.shape[1]], dtype=torch.uint8, device=W_q.device)
        tmp[:_step] = (W_q & 0b00111000000000000000000000000000) >> 27
        tmp[1 * _step : 2 * _step] = (W_q & 0b00000111000000000000000000000000) >> 24
        tmp[2 * _step : 3 * _step] = (W_q & 0b00000000111000000000000000000000) >> 21
        tmp[3 * _step : 4 * _step] = (W_q & 0b00000000000111000000000000000000) >> 18
        tmp[4 * _step : 5 * _step] = (W_q & 0b00000000000000111000000000000000) >> 15
        tmp[5 * _step : 6 * _step] = (W_q & 0b00000000000000000111000000000000) >> 12
        tmp[6 * _step : 7 * _step] = (W_q & 0b00000000000000000000111000000000) >> 9
        tmp[7 * _step : 8 * _step] = (W_q & 0b00000000000000000000000111000000) >> 6
        tmp[8 * _step : 9 * _step] = (W_q & 0b00000000000000000000000000111000) >> 3
        tmp[9 * _step :] = W_q & 0b00000000000000000000000000000111
        return tmp


class Packer:
    # TODO: Refine the packer
    bit_to_packing = {8: "8bit_u8", 4: "4bit_u8", 3: "3bit_32", 2: "2bit_u8"}

    pack_fn_mapping = {
        "8bit_u8": BitPack.pack_8bit_u8,
        "4bit_u8": BitPack.pack_4bit_u8,
        "3bit_32": BitPack.pack_3bit_32,
        "2bit_u8": BitPack.pack_2bit_u8,
    }

    unpack_fn_mapping = {
        "8bit_u8": BitPack.unpack_8bit_u8,
        "4bit_u8": BitPack.unpack_4bit_u8,
        "3bit_32": BitPack.unpack_3bit_32,
        "2bit_u8": BitPack.unpack_2bit_u8,
    }

    @staticmethod
    def get_pack_fn(nbits: int):
        return Packer.pack_fn_mapping[Packer.bit_to_packing[nbits]]

    @staticmethod
    def get_unpack_fn(nbits: int):
        return Packer.unpack_fn_mapping[Packer.bit_to_packing[nbits]]
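
The packing layout is easiest to verify with a round trip, sketched below (the module path is assumed from the header above). Packing works along dim 0: 4-bit mode fuses two rows per byte, so the row count must be even; 3-bit mode zero-pads rows to a multiple of 10, fits ten values per int32, and the unpacked result must therefore be sliced back to the original length.

# A minimal round-trip sketch; the import path is assumed from the file header above.
import torch

from neural_compressor.torch.algorithms.weight_only.hqq.bitpack import Packer

# 4-bit: rows i and i + n/2 are fused into one byte (high/low nibble).
W_q = torch.randint(0, 16, (8, 4), dtype=torch.uint8)
packed = Packer.get_pack_fn(4)(W_q)  # shape (4, 4)
assert torch.equal(Packer.get_unpack_fn(4)(packed), W_q)

# 3-bit: rows are zero-padded to a multiple of 10, so slice after unpacking.
W_q3 = torch.randint(0, 8, (7, 4), dtype=torch.int32)
packed3 = Packer.get_pack_fn(3)(W_q3)  # shape (1, 4), ten 3-bit values per int32
restored = Packer.get_unpack_fn(3)(packed3)[: len(W_q3)]
assert torch.equal(restored.to(torch.int32), W_q3)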

0 commit comments

Comments
 (0)