Commit 5343009

Authored by yiliu30
Support auto device for TEQ and AWQ (#1634)
Signed-off-by: yiliu30 <[email protected]>
1 parent 7e1fa90 commit 5343009

4 files changed (+177 −14 lines changed)
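
In practice, the change means that on a GPU machine quantize() for AWQ or TEQ can be called without naming a device: the model and the calibration inputs are placed on the auto-detected accelerator before quantization. A rough usage sketch distilled from the new CUDA tests below (the placeholder model and the calibration loop are illustrative, not part of this commit):

    import torch
    from neural_compressor.torch.quantization import get_default_awq_config, quantize

    model = ...  # e.g. a tiny GPT-J, as in test_woq_on_cuda.py below
    example_inputs = torch.ones([1, 10], dtype=torch.long)

    def calib_func(model):
        model(example_inputs.to(model.device))

    # No device argument: with this commit the model and example_inputs are
    # moved to the device reported by get_device() before AWQ runs.
    q_model = quantize(
        model=model, quant_config=get_default_awq_config(), example_inputs=example_inputs, run_fn=calib_func
    )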

neural_compressor/torch/algorithms/weight_only/awq.py

Lines changed: 12 additions & 3 deletions
@@ -15,11 +15,10 @@
 # Copied from neural_compressor/adaptor/torch_utils/awq.py
 
 import copy
-from functools import partial
 
 import torch
 
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import get_device, logger
 
 from .modules import MulLinear
 from .utility import (
@@ -33,6 +32,8 @@
     set_module,
 )
 
+__all__ = ["awq_quantize"]
+
 
 def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}):
     """Get absorbed layer per block.
@@ -122,10 +123,13 @@ def __init__(
         use_full_range=False,
         weight_config={},
     ):
+
         self.example_inputs = example_inputs
+        self.model = model
         if example_inputs is None:
             assert dataloader is not None, "datalaoder or example_inputs is required."
             self.example_inputs = get_example_input(dataloader)
+        self._move_model_and_data_to_device()
         # Step 1: get hidden states and kwargs of first block.
         self.total_block_args, self.total_block_kwargs = get_hidden_states(
             model, dataloader=dataloader, n_samples=n_samples, calib_func=calib_func
@@ -139,7 +143,12 @@ def __init__(
         self.scheme = scheme
         self.use_full_range = use_full_range
         self.weight_config = weight_config
-        self.model = model
+
+    def _move_model_and_data_to_device(self):
+        # Put the model and example_inputs into target device
+        device = get_device()
+        self.model.to(device)
+        self.example_inputs = self.example_inputs.to(device)
 
     def quantize(self, use_auto_scale=True, use_mse_search=True, folding=False, return_int=False):
         """Execute AWQ quantization.

neural_compressor/torch/algorithms/weight_only/teq.py

Lines changed: 11 additions & 3 deletions
@@ -19,11 +19,13 @@
 import torch
 import transformers
 
-from neural_compressor.torch.utils import logger
+from neural_compressor.torch.utils import get_device, logger
 
 from .modules import MulLinear, TEQLinearFakeQuant
 from .utility import get_module, quant_tensor, set_module
 
+__all__ = ["teq_quantize", "TEQuantizer"]
+
 
 class TEQuantizer:
     """Weight-only quantization, Trainable Equivalent Transformation (TEQ): linear wrapper to apply scale to input."""
@@ -38,16 +40,22 @@ def __init__(self, model, weight_config={}, absorb_to_layer={}, folding=True, ex
         self.weight_config = weight_config
         self.folding = folding
         self.example_inputs = example_inputs
-        self.device, self.dtype = self._get_device()
+        self.device = self._get_device()
+        self.dtype = self._get_dtype()
         self.model.eval()
         self.trained_alphas = {}
         self.absorb_to_layer = absorb_to_layer
 
     def _get_device(self):
         """Get the model device
         :return:Model device."""
+        device = get_device()
+        self.model.to(device)
+        return device
+
+    def _get_dtype(self):
         for _, p in self.model.named_parameters():
-            return p.data.device, p.data.dtype
+            return p.data.dtype
 
     def add_tuning_scale(self, sqrt_w_init=False):
         """The main entry of smooth quant

neural_compressor/torch/utils/auto_accelerator.py

Lines changed: 10 additions & 1 deletion
@@ -19,7 +19,6 @@
 
 # NOTICE: The design adapted from:
 # https://github.com/microsoft/DeepSpeed/blob/master/accelerator/abstract_accelerator.py.
-# TODO: move it into torch/utils
 
 
 # To keep it simply, only add the APIs we need.
@@ -204,19 +203,27 @@ def empty_cache(self):
 
 
 def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
+    # Force use the cpu on node has both cpu and gpu: `FORCE_DEVICE=cpu` python main.py ...
+    # The `FORCE_DEVICE` is case insensitive.
     # The environment variable `FORCE_DEVICE` has higher priority than the `device_name`.
     # TODO: refine the docs and logic later
+    # 1. Get the device setting from environment variable `FORCE_DEVICE`.
     FORCE_DEVICE = os.environ.get("FORCE_DEVICE", None)
+    if FORCE_DEVICE:
+        FORCE_DEVICE = FORCE_DEVICE.lower()
+    # 2. If the `FORCE_DEVICE` is set and the accelerator is available, use it.
     if FORCE_DEVICE and accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE) is not None:
         logger.warning("Force use %s accelerator.", FORCE_DEVICE)
         return accelerator_registry.get_accelerator_cls_by_name(FORCE_DEVICE)()
+    # 3. If the `device_name` is set and the accelerator is available, use it.
     if device_name != "auto":
         if accelerator_registry.get_accelerator_cls_by_name(device_name) is not None:
             accelerator_cls = accelerator_registry.get_accelerator_cls_by_name(device_name)
             logger.warning("Selected accelerator %s by device_name.", accelerator_cls.__name__)
             return accelerator_cls()
         else:
             logger.warning("The device name %s is not supported, use auto detect instead.", device_name)
+    # 4. Select the accelerator by priority.
     for accelerator_cls in accelerator_registry.get_sorted_accelerators():
         if accelerator_cls.is_available():
             logger.warning("Auto detect accelerator: %s.", accelerator_cls.__name__)
@@ -227,4 +234,6 @@ def auto_detect_accelerator(device_name="auto") -> Auto_Accelerator:
 # Force use cpu accelerator even if cuda is available.
 # FORCE_DEVICE = "cpu" python ...
 # or
+# FORCE_DEVICE = "CPU" python ...
+# or
 # CUDA_VISIBLE_DEVICES="" python ...
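
The resolution order is: the FORCE_DEVICE environment variable (now lower-cased, so CPU and cpu both work), then an explicit device_name argument, then the registry's priority order. A hedged usage sketch, assuming the function is importable from this file's module path:

    # shell: force the CPU accelerator even if CUDA is available
    #   FORCE_DEVICE=cpu python main.py ...
    #   FORCE_DEVICE=CPU python main.py ...   # case insensitive after this commit

    from neural_compressor.torch.utils.auto_accelerator import auto_detect_accelerator

    accelerator = auto_detect_accelerator()           # env var > device_name > priority order
    accelerator_cpu = auto_detect_accelerator("cpu")  # explicit name, used if registered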

test/3x/torch/quantization/weight_only/test_woq_on_cuda.py

Lines changed: 144 additions & 7 deletions
@@ -7,7 +7,22 @@
 
 from neural_compressor.common.utils import logger
 from neural_compressor.torch.algorithms.weight_only.gptq import move_input_to_device
-from neural_compressor.torch.quantization import GPTQConfig, get_default_rtn_config, quantize
+from neural_compressor.torch.quantization import (
+    AWQConfig,
+    GPTQConfig,
+    get_default_awq_config,
+    get_default_rtn_config,
+    get_default_teq_config,
+    quantize,
+)
+
+
+def get_gpt_j():
+    tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
+        "hf-internal-testing/tiny-random-GPTJForCausalLM",
+        torchscript=True,
+    )
+    return tiny_gptj
 
 
 class GPTQDataloaderPreprocessor:
@@ -213,9 +228,7 @@ def run_fn_for_gptq(model, dataloader_for_calibration, *args):
                 pass
             return
 
-        user_model = transformers.AutoModelForCausalLM.from_pretrained(
-            "hf-internal-testing/tiny-random-GPTJForCausalLM",
-        )
+        user_model = get_gpt_j()
 
         user_model = quantize(
             model=user_model, quant_config=quant_config, run_fn=run_fn_for_gptq, run_args=dataloader_for_calibration
@@ -228,9 +241,7 @@ class TestRTNQuant:
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
     def test_rtn(self):
-        self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained(
-            "hf-internal-testing/tiny-random-GPTJForCausalLM",
-        )
+        self.tiny_gptj = get_gpt_j()
         self.example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long)
         model = self.tiny_gptj
         # record label for comparison
@@ -239,3 +250,129 @@ def test_rtn(self):
         quant_config = get_default_rtn_config()
         q_model = quantize(model, quant_config)
         assert "cuda" in str(q_model.device), f"Expect qmodel device is cuda, got {q_model.device}"
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+class TestAWQOnCuda:
+
+    def test_awq(self):
+        self.lm_input = torch.ones([1, 10], dtype=torch.long)
+        self.gptj = get_gpt_j()
+        example_inputs = torch.ones([1, 10], dtype=torch.long)
+
+        def calib_func(model):
+            for i in range(2):
+                model(self.lm_input.to(model.device))
+
+        quant_config = get_default_awq_config()
+        logger.info("Test quantization with config", quant_config)
+        q_model = quantize(
+            model=self.gptj, quant_config=quant_config, example_inputs=self.lm_input, run_fn=calib_func, inplace=False
+        )
+        out2 = q_model(example_inputs.to(q_model.device))
+        assert "cuda" in str(q_model.device), f"Expect qmodel device is cuda, got {q_model.device}"
+        assert "cuda" in str(out2[0].device), f"Expect out2 device is cuda, got {out2.device}"
+
+
+def generate_random_corpus(nsamples=32):
+    meta_data = []
+    for _ in range(nsamples):
+        inp = torch.ones([1, 512], dtype=torch.long)
+        tar = torch.ones([1, 512], dtype=torch.long)
+        meta_data.append((inp, tar))
+    return meta_data
+
+
+def train(
+    model,
+    train_steps=1000,
+    lr=1e-3,
+    warmup_ratio=0.05,
+    gradient_accumulation_steps=1,
+    logging_steps=10,
+    betas=[0.9, 0.9],
+    weight_decay=0,
+    lr_scheduler_type="linear",
+):
+    """Train function."""
+    trained_alphas_list = [torch.ones([128], requires_grad=True)]
+    optimizer = torch.optim.Adam(trained_alphas_list, lr=lr, weight_decay=weight_decay, betas=betas)
+
+    lr_scheduler = transformers.get_scheduler(  # pylint: disable=E1111
+        name=lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=int(train_steps * warmup_ratio) // gradient_accumulation_steps,
+        num_training_steps=train_steps // gradient_accumulation_steps,
+    )
+
+    logger.info("start training")
+    model.train()
+    global_steps = 0
+    dataloader = generate_random_corpus()
+    while global_steps <= train_steps:
+        for inputs in dataloader:
+            if isinstance(inputs, torch.Tensor):
+                input_id = inputs
+            elif isinstance(inputs, dict):
+                input_id = inputs["input_ids"]
+            else:
+                input_id = inputs[0]
+            output = model(input_id.to(model.device), labels=input_id.to(model.device))
+            loss = output[0] / gradient_accumulation_steps
+            loss.backward()
+            global_steps += 1
+
+            if global_steps % logging_steps == 0:
+                logger.info("steps: {}, loss: {}".format(global_steps, loss.detach().cpu().item()))
+
+            if global_steps % gradient_accumulation_steps == 0:
+                optimizer.step()
+                optimizer.zero_grad()
+                lr_scheduler.step()
+
+            if global_steps >= train_steps:  # pragma: no cover
+                break
+
+    logger.info("finish training")
+    model.eval()
+    return None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU")
+class TestTEQOnCuda:
+
+    def test_teq(self):
+        quant_config = {
+            "teq": {
+                "global": {
+                    "dtype": "fp32",
+                },
+                "local": {
+                    "transformer.h.0.mlp.fc_in": {
+                        "dtype": "int",
+                        "bits": 8,
+                        "group_size": -1,
+                        "use_sym": True,
+                        "folding": True,
+                        "absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
+                    },
+                    "transformer.h.0.mlp.fc_out": {
+                        "dtype": "int",
+                        "bits": 4,
+                        "group_size": 32,
+                        "use_sym": False,
+                        "folding": True,
+                        "absorb_to_layer": {"transformer.h.0.mlp.fc_in": ["transformer.h.0.mlp.fc_out"]},
+                    },
+                },
+            }
+        }
+        example_inputs = torch.ones([1, 512], dtype=torch.long)
+        test_input = torch.ones([1, 512], dtype=torch.long)
+        model = get_gpt_j()
+
+        qdq_model = quantize(model=model, quant_config=quant_config, run_fn=train, example_inputs=example_inputs)
+        assert isinstance(qdq_model, torch.nn.Module), "Expect qdq_model is a torch module"
+        out2 = qdq_model(test_input.to(qdq_model.device))
+        assert "cuda" in str(qdq_model.device), f"Expect qmodel device is cuda, got {qdq_model.device}"
+        assert "cuda" in str(out2[0].device), f"Expect out2 device is cuda, got {out2.device}"
