
Commit 2a86aea

Support auto device for GPTQ and RTN (#1622)
Signed-off-by: yiliu30 <[email protected]>
1 parent 071ab31 commit 2a86aea

File tree

10 files changed: +277 -17 lines changed

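Taken together, the change means callers no longer have to pick a device by hand for RTN and GPTQ: get_device("auto") resolves the best available accelerator through the registry, and the FORCE_DEVICE environment variable can still override that choice. A minimal sketch of the intended end-user flow, assuming a CUDA machine and reusing the tiny model from the new test added below:

import transformers

from neural_compressor.torch.quantization import get_default_rtn_config, quantize

model = transformers.AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-GPTJForCausalLM",
)

# No device is passed anywhere: RTN resolves "auto" through get_device()/auto_detect_accelerator(),
# so the model is quantized on CUDA when a GPU is present and on CPU otherwise.
q_model = quantize(model, get_default_rtn_config())
print(q_model.device)  # expected to report a cuda device on a GPU machine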

neural_compressor/torch/algorithms/weight_only/gptq.py

Lines changed: 3 additions & 2 deletions

@@ -28,7 +28,8 @@
 import transformers
 from tqdm import tqdm
 
-from neural_compressor.torch.utils import fetch_module, logger, set_module
+from neural_compressor.torch.utils import fetch_module, get_device, logger, set_module
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 from .modules import WeightOnlyLinear
 
@@ -255,7 +256,7 @@ def __init__(
         self.check_layer_config()
 
         # device
-        self.device = device
+        self.device = get_device(kwargs.pop("device", "auto"))
         if str(self.model.device).startswith("cuda"):
             self.device = self.model.device
         self.is_ready = False
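The GPTQ quantizer now reads its device from an optional device keyword (defaulting to "auto") instead of a required argument, while a model that already sits on CUDA keeps its own device. A standalone sketch of that resolution order, with resolve_gptq_device as a hypothetical helper name, not part of the patch:

from neural_compressor.torch.utils import get_device


def resolve_gptq_device(model, **kwargs):
    # "auto" (the default) maps to the best available accelerator via the registry.
    device = get_device(kwargs.pop("device", "auto"))
    # A model already placed on CUDA wins over the detected device.
    if str(model.device).startswith("cuda"):
        device = model.device
    return device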

neural_compressor/torch/algorithms/weight_only/hqq/core.py

Lines changed: 1 addition & 1 deletion

@@ -24,8 +24,8 @@
 import torch
 
 from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
-from .auto_accelerator import auto_detect_accelerator
 from .bitpack import Packer
 from .config import HQQModuleConfig, QTensorConfig, default_hqq_module_config, hqq_global_option
 from .optimizer import optimize_weights_proximal

neural_compressor/torch/algorithms/weight_only/hqq/optimizer.py

Lines changed: 1 addition & 2 deletions

@@ -20,8 +20,7 @@
 import torch
 
 from neural_compressor.torch.utils import logger
-
-from .auto_accelerator import auto_detect_accelerator
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 
 # Proximal solver || W - dequantize(quantize(W))||_p^p

neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py

Lines changed: 1 addition & 1 deletion

@@ -17,8 +17,8 @@
 import torch
 
 from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
-from .auto_accelerator import auto_detect_accelerator
 from .config import ConfigMappingType, default_hqq_module_config, hqq_global_option
 from .core import HQQLinear
 

neural_compressor/torch/algorithms/weight_only/rtn.py

Lines changed: 9 additions & 2 deletions

@@ -21,7 +21,8 @@
 
 import torch
 
-from neural_compressor.torch.utils import logger, set_module
+from neural_compressor.torch.utils import get_device, logger, set_module
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 from .utility import quant_tensor, search_clip
 
@@ -73,7 +74,12 @@ def rtn_quantize(
     Returns:
         model: fake quantized torch module
     """
-    device = "cpu"
+    device = get_device(kwargs.pop("device", "auto"))
+
+    # Put model on device explicitly
+    # TODO: refine it later, Put module on device one by one instead of the whole model
+    model.to(device)
+
     assert isinstance(model, torch.nn.Module), "only support torch module"
     supported_layers = ["Linear"]
     # initialize global configuration
@@ -94,6 +100,7 @@ def rtn_quantize(
             dtype = weight_config[name].get("dtype", "int")
             if dtype == "fp32":
                 continue
+            logger.debug("Apply RTN on module %s.", name)
             bits = weight_config[name].get("bits", 4)
             group_size = weight_config[name]["group_size"]
             scheme = weight_config[name]["scheme"]
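For RTN, the resolved device is applied to the whole model up front; moving modules one by one is left as a TODO in the diff. A short sketch of the effect, assuming the device keyword is forwarded through **kwargs exactly as shown above:

import torch

from neural_compressor.torch.utils import get_device

device = get_device("auto")    # e.g. "cuda" on a GPU machine, "cpu" otherwise
model = torch.nn.Linear(8, 8)  # stand-in for the model handed to rtn_quantize
model.to(device)               # whole-model move happens before any layer is quantized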
neural_compressor/torch/utils/auto_accelerator.py

Lines changed: 11 additions & 4 deletions

@@ -203,16 +203,23 @@ def empty_cache(self):
         return torch.cuda.empty_cache()
 
 
-def auto_detect_accelerator() -> Auto_Accelerator:
-    # if runtime_accelerator.accelerator:
-    #     return runtime_accelerator.accelerator
+def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
+    # The environment variable `FORCE_DEVICE` has higher priority than the `device_name`.
+    # TODO: refine the docs and logic later
     FORCE_DEVICE = os.environ.get("FORCE_DEVICE", None)
     if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None:
         logger.warning("Force use %s accelerator.", FORCE_DEVICE)
         return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)()
+    if device_name != "auto":
+        if accelerator_registry.get_accelerator_cls_by_name(device_name) is not None:
+            accelerator_cls = accelerator_registry.get_accelerator_cls_by_name(device_name)
+            logger.warning("Selected accelerator %s by device_name.", accelerator_cls.__name__)
+            return accelerator_cls()
+        else:
+            logger.warning("The device name %s is not supported, use auto detect instead.", device_name)
     for accelerator_cls in accelerator_registry.get_sorted_accelerators():
         if accelerator_cls.is_available():
-            logger.debug("Auto detect accelerator: %s.", accelerator_cls.__name__)
+            logger.warning("Auto detect accelerator: %s.", accelerator_cls.__name__)
             accelerator = accelerator_cls()
             return accelerator
 
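The selection order in auto_detect_accelerator is now: the FORCE_DEVICE environment variable first, then an explicit device_name, then auto detection over the registered accelerators. A hedged sketch of that priority, assuming "cpu" and "cuda" are the names under which the accelerators are registered:

import os

from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

# 1) FORCE_DEVICE beats an explicit device_name.
os.environ["FORCE_DEVICE"] = "cpu"
acc = auto_detect_accelerator("cuda")  # still returns the CPU accelerator
del os.environ["FORCE_DEVICE"]

# 2) A registered device_name is honored; an unknown name falls back to auto detection.
acc = auto_detect_accelerator("cuda")  # CUDA accelerator, if one is registered
acc = auto_detect_accelerator("auto")  # highest-priority accelerator whose is_available() is True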
neural_compressor/torch/utils/environ.py

Lines changed: 8 additions & 0 deletions

@@ -61,3 +61,11 @@ def get_torch_version():
         assert False, "Got an unknown version of torch: {}".format(e)
     version = Version(torch_version)
     return version
+
+
+def get_device(device_name="auto"):
+    from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
+
+    runtime_accelerator = auto_detect_accelerator(device_name)
+    device = runtime_accelerator.name()
+    return device
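get_device is a thin wrapper that turns the detected accelerator into a plain device string, which is what the GPTQ and RTN changes above consume. Expected behavior, hedged because the exact strings come from each accelerator's name():

from neural_compressor.torch.utils import get_device

print(get_device("auto"))  # e.g. "cuda" on a GPU machine, "cpu" otherwise
print(get_device("cpu"))   # an explicit name is respected unless FORCE_DEVICE overrides it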

test/3x/torch/quantization/weight_only/hqq/test_hqq_cuda.py

Lines changed: 1 addition & 1 deletion

@@ -4,10 +4,10 @@
 import torch
 from transformers import AutoModelForCausalLM
 
-from neural_compressor.torch.algorithms.weight_only.hqq.auto_accelerator import auto_detect_accelerator
 from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option
 from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear
 from neural_compressor.torch.algorithms.weight_only.hqq.utility import see_cuda_memory_usage
+from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator
 
 
 def _common_cuda_test(nbits=4, group_size=64, quant_zero=True, quant_scale=False, scale_quant_group_size=128):

test/3x/torch/quantization/weight_only/hqq/test_auto_accelerator.py renamed to test/3x/torch/quantization/weight_only/test_auto_accelerator.py

Lines changed: 1 addition & 4 deletions

@@ -3,10 +3,7 @@
 import pytest
 import torch
 
-from neural_compressor.torch.algorithms.weight_only.hqq.auto_accelerator import (
-    accelerator_registry,
-    auto_detect_accelerator,
-)
+from neural_compressor.torch.utils.auto_accelerator import accelerator_registry, auto_detect_accelerator
 
 
 class Test_CPU_Accelerator:
Lines changed: 241 additions & 0 deletions (new file)

@@ -0,0 +1,241 @@
import random

import pytest
import torch
import transformers
from tqdm import tqdm

from neural_compressor.common.utils import logger
from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device
from neural_compressor.torch.quantization import GPTQConfig, get_default_rtn_config, quantize


class GPTQDataloaderPreprocessor:
    def __init__(self, dataloader_original, use_max_length=False, max_seq_length=2048, nsamples=128):
        self.dataloader_original = dataloader_original
        self.use_max_length = use_max_length
        self.max_seq_length = max_seq_length
        self.nsamples = nsamples
        self.dataloader = []
        self.is_ready = False

    def get_prepared_dataloader(self):
        if not self.is_ready:
            self.prepare_dataloader()
        return self.dataloader

    def prepare_dataloader(self):
        if self.use_max_length:
            # (Recommend) only take sequence whose length exceeds self.max_seq_length,
            # which preserves calibration's tokens are all valid
            # This is GPTQ official dataloader implementation
            self.obtain_first_n_samples_fulllength()
        else:
            # general selection, no padding, not GPTQ original implementation.
            self.obtain_first_n_samples()
        self.is_ready = True

    def obtain_first_n_samples(self, seed=0):
        """Get first nsample data as the real calibration dataset."""
        self.dataloader.clear()
        random.seed(seed)
        for batch in self.dataloader_original:
            # process data, depends on its data type.
            if len(self.dataloader) == self.nsamples:
                logger.info(f"Successfully collect {self.nsamples} calibration samples.")
                break
            # list, tuple
            if isinstance(batch, list) or isinstance(batch, tuple):
                if batch[0].shape[-1] > self.max_seq_length:
                    i = random.randint(0, batch[0].shape[-1] - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    batch_final = []
                    for item in batch:
                        if isinstance(item, torch.Tensor) and item.shape.__len__() == 2:
                            batch_final.append(item[:, i:j])
                        else:
                            batch_final.append(item)
                else:
                    batch_final = batch[:]
            # dict
            elif isinstance(batch, dict):
                try:
                    length = batch["input_ids"].shape[-1]
                except:
                    logger.warning("Please make sure your dict'like data contains key of 'input_ids'.")
                    continue
                batch_final = {}
                if length > self.max_seq_length:
                    i = random.randint(0, length - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    # may have to slice every sequence related data
                    for key in batch.keys():
                        if isinstance(batch[key], torch.Tensor):
                            batch_final[key] = batch[key][:, i:j]  # slice on sequence length dim
                        else:
                            batch_final[key] = batch[key]
                else:
                    batch_final = batch
            # tensor
            else:
                if batch.shape[-1] > self.max_seq_length:
                    i = random.randint(0, batch.shape[-1] - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    batch_final = batch[:, i:j]
                else:
                    batch_final = batch
            self.dataloader.append(batch_final)

        if len(self.dataloader) < self.nsamples:
            logger.warning(f"Try to use {self.nsamples} data, but entire dataset size is {len(self.dataloader)}.")

    def obtain_first_n_samples_fulllength(self, seed=0):
        self.dataloader.clear()
        random.seed(seed)
        unified_length = self.max_seq_length
        for batch in self.dataloader_original:
            if len(self.dataloader) == self.nsamples:
                logger.info(f"Successfully collect {self.nsamples} calibration samples.")
                break
            # list & tuple, gpt-j-6b mlperf, etc.
            if isinstance(batch, list) or isinstance(batch, tuple):
                if batch[0].shape[-1] == unified_length:
                    batch_final = batch[:]
                elif batch[0].shape[-1] > unified_length:
                    i = random.randint(0, batch[0].shape[-1] - unified_length - 1)
                    j = i + unified_length
                    batch_final = []
                    for item in batch:
                        if isinstance(item, torch.Tensor) and item.shape.__len__() == 2:
                            batch_final.append(item[:, i:j])
                        else:
                            batch_final.append(item)
                else:
                    # not match max length, not include in target dataset
                    continue
            # dict
            elif isinstance(batch, dict):
                try:
                    length = batch["input_ids"].shape[-1]
                except:
                    logger.warning("Please make sure your dict'like data contains key of 'input_ids'.")
                    continue
                batch_final = {}
                if length == self.max_seq_length:
                    batch_final = batch
                elif length > self.max_seq_length:
                    i = random.randint(0, length - self.max_seq_length - 1)
                    j = i + self.max_seq_length
                    # may have to slice every sequence related data
                    for key in batch.keys():
                        if isinstance(batch[key], torch.Tensor):
                            batch_final[key] = batch[key][:, i:j]  # slice on sequence length dim with same position
                        else:
                            batch_final[key] = batch[key]
                else:
                    # not match max length, not include in target dataset
                    continue
            # tensor
            else:
                if batch.shape[-1] == unified_length:
                    batch_final = batch
                elif batch.shape[-1] > unified_length:
                    i = random.randint(0, batch.shape[-1] - unified_length - 1)
                    j = i + unified_length
                    batch_final = batch[:, i:j]
                else:
                    # not match max length, not include in target dataset
                    continue
            self.dataloader.append(batch_final)
        if len(self.dataloader) < self.nsamples:  # pragma: no cover
            logger.warning(
                f"Trying to allocate {self.nsamples} data with fixed length {unified_length}, \
but only {len(self.dataloader)} samples are found. Please use smaller 'self.max_seq_length' value."
            )


class TestGPTQ:
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
    def test_GPTQ_fixed_length_quant(self):
        class GPTQLLMDataLoader:
            def __init__(self):
                self.batch_size = 1

            def __iter__(self):
                for i in range(10):
                    yield torch.ones([1, 512], dtype=torch.long)

        class GPTQLLMDataLoaderList:
            def __init__(self):
                self.batch_size = 1

            def __iter__(self):
                for i in range(10):
                    yield (torch.ones([1, 512], dtype=torch.long), torch.ones([1, 512], dtype=torch.long))

        class GPTQLLMDataLoaderDict:
            def __init__(self):
                self.batch_size = 1

            def __iter__(self):
                for i in range(10):
                    yield {
                        "input_ids": torch.ones([1, 512], dtype=torch.long),
                        "attention_mask": torch.ones([1, 512], dtype=torch.long),
                    }

        dataloader_list = GPTQLLMDataLoaderList()
        dataloader_dict = GPTQLLMDataLoaderDict()

        quant_config = GPTQConfig()
        quant_config.set_local("lm_head", GPTQConfig(dtype="fp32"))

        gptq_use_max_length = False
        gptq_max_seq_length = 2048
        dataloaderPreprocessor = GPTQDataloaderPreprocessor(
            dataloader_original=dataloader_list,
            use_max_length=gptq_use_max_length,
            max_seq_length=gptq_max_seq_length,
        )
        dataloader_for_calibration = dataloaderPreprocessor.get_prepared_dataloader()

        def run_fn_for_gptq(model, dataloader_for_calibration, *args):
            for batch in tqdm(dataloader_for_calibration):
                batch = move_input_to_device(batch, device=model.device)
                try:
                    if isinstance(batch, tuple) or isinstance(batch, list):
                        model(batch[0])
                    elif isinstance(batch, dict):
                        model(**batch)
                    else:
                        model(batch)
                except ValueError:
                    pass
            return

        user_model = transformers.AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-GPTJForCausalLM",
        )

        user_model = quantize(
            model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration
        )
        model_device = str(user_model.device)
        assert "cuda" in model_device, f"Model device is {model_device}"


class TestRTNQuant:

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
    def test_rtn(self):
        self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
            "hf-internal-testing/tiny-random-GPTJForCausalLM",
        )
        self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
        model = self.tiny_gptj
        # record label for comparison
        self.label = model(self.example_inputs.to(model.device))[0]
        # test_default_config
        quant_config = get_default_rtn_config()
        q_model = quantize(model, quant_config)
        assert "cuda" in str(q_model.device), f"Expect qmodel device is cuda, got {q_model.device}"
