Commit deb1ed5

Quantize weight with in-place mode in weight-only quantization (#1511)
Signed-off-by: Cheng, Penghui <[email protected]>
1 parent 5b2a887 commit deb1ed5
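What "in-place mode" buys here, as a minimal standalone sketch (not code from this commit): PyTorch's underscore-suffixed ops overwrite the input tensor's storage instead of allocating a new result, so quantize-dequantize of a weight no longer needs a second full-size buffer.

import torch

# Out-of-place vs. in-place scaling of a weight matrix (illustrative only).
w = torch.randn(4, 8)
scale = w.abs().max(dim=1, keepdim=True)[0] / 7  # per-row scale for a 4-bit grid

w_scaled = w / scale   # out-of-place: allocates a new tensor the size of w
w.div_(scale)          # in-place: reuses w's own storage
print(torch.allclose(w, w_scaled))  # True: same values, one less allocation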

3 files changed: +98 / -91 lines changed

neural_compressor/adaptor/torch_utils/awq.py
Lines changed: 2 additions & 2 deletions

@@ -244,7 +244,7 @@ def search_scale(self, block, block_name, module_list, input_values):
             x_max = _get_act_scale(input_val)
             absorbed_modules = {_m: fetch_module(block, _m) for _m in module_name_list}
             # Step 4: collect origin output for MSE and state_dict for recover.
-            org_stat = {_m: module.state_dict() for _m, module in absorbed_modules.items()}
+            org_stat = {_m: copy.deepcopy(module.state_dict()) for _m, module in absorbed_modules.items()}
             if len(module_tuple) > 1:
                 # use block inference for multi-modules
                 org_out = self.block_inference(block)
@@ -364,7 +364,7 @@ def search_clip(self, block_name, module_list, input_values):
             # Step 2: update module name
             module = fetch_module(self.model, module_name)
             # Step 3: collect origin output for MSE and state_dict for recover.
-            org_stat = module.state_dict()
+            org_stat = copy.deepcopy(module.state_dict())
             org_out = self.module_inference(module, input_val)
             # Step 4: set different clip range for weight and compare the MSE loss.
             logger.info("Searching the best clip range with AWQ algorithm")
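Both AWQ hunks swap a plain state_dict() snapshot for copy.deepcopy(...). The reason, shown with a throwaway layer rather than the repo's search loop: state_dict() hands back references to the live parameter tensors, so once the weights are quantized in place, an un-copied snapshot can no longer recover the original values.

import copy
import torch

m = torch.nn.Linear(4, 4, bias=False)
shallow = m.state_dict()                # references the same storage as m.weight
backup = copy.deepcopy(m.state_dict())  # independent copy of the values

m.weight.data.mul_(0)                   # stand-in for an in-place quantization
print(shallow["weight"].abs().sum())    # tensor(0.): the "snapshot" changed too
print(backup["weight"].abs().sum())     # non-zero: safe to restore from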

neural_compressor/adaptor/torch_utils/teq.py
Lines changed: 1 addition & 1 deletion

@@ -294,7 +294,7 @@ def quantize(self):
             group_size = self.weight_config[n]["group_size"]
             scheme = self.weight_config[n]["scheme"]
             if isinstance(m, torch.nn.Linear):  # pragma: no cover
-                m.weight.data.copy_(quant_weight(m.weight, num_bits=num_bits, group_size=group_size, scheme=scheme))
+                quant_weight(m.weight.data, num_bits=num_bits, group_size=group_size, scheme=scheme)

     def save(self, save_scale_file="", save_state_dict_file=""):
         """
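The TEQ change drops the copy-back round trip because quant_weight now mutates its argument. A self-contained sketch of the two call styles, using a hypothetical fake_quant_ helper in place of the library function:

import torch

def fake_quant_(w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
    """Toy in-place quantize-dequantize: mutates w and returns it."""
    maxq = 2 ** (num_bits - 1) - 1
    scale = w.abs().max() / maxq
    return w.div_(scale).round_().clamp_(-maxq, maxq).mul_(scale)

m = torch.nn.Linear(8, 8)
# Old style: quantize a copy, then write the result back into the parameter.
m.weight.data.copy_(fake_quant_(m.weight.data.clone()))
# New style: hand the parameter storage to the in-place op directly.
fake_quant_(m.weight.data)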

neural_compressor/adaptor/torch_utils/weight_only.py
Lines changed: 95 additions & 88 deletions
@@ -80,7 +80,7 @@ def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False):
     # get scale and update tensor
     scale = tensor.abs().max(1)[0] * quantile / max(allow_data)
     scale.unsqueeze_(dim=-1)
-    tensor = tensor / scale
+    tensor.div_(scale)
     mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)]
     q_tensor = torch.zeros_like(tensor)
     for i in range(len(allow_data)):
@@ -91,9 +91,10 @@ def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False):
             q_tensor += torch.where(tensor > mid_data[i - 1], data, 0)
         else:
             q_tensor += torch.where((mid_data[i - 1] < tensor) & (tensor <= mid_data[i]), data, 0)
+    tensor.copy_(q_tensor)
     if return_int:
-        return q_tensor.type(torch.int8), scale.type(torch.float), None
-    return q_tensor * scale
+        return tensor.type(torch.int8), scale.type(torch.float), None
+    return tensor.mul_(scale)


 def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False):
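For context, the bucketing that quantize_4bit keeps after the change, as a toy sketch with a made-up 4-entry codebook standing in for the real NF4 grid: each scaled value is mapped to the nearest allowed float by comparing against bucket midpoints, and the result is written back into the original storage with copy_().

import torch

allow_data = [-1.0, -0.3, 0.4, 1.0]  # illustrative codebook, not the real NF4 values
mid = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)]

t = torch.tensor([[-0.9, -0.1, 0.2, 0.8]])
q = torch.zeros_like(t)
for i, val in enumerate(allow_data):
    if i == 0:
        q += torch.where(t <= mid[0], val, 0.0)
    elif i == len(allow_data) - 1:
        q += torch.where(t > mid[-1], val, 0.0)
    else:
        q += torch.where((mid[i - 1] < t) & (t <= mid[i]), val, 0.0)
t.copy_(q)  # quantized values now live in t's original storage
print(t)    # tensor([[-1.0000, -0.3000,  0.4000,  1.0000]])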
@@ -122,10 +123,14 @@ def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False):
     zp = torch.round(-wmin / scale)
     scale.unsqueeze_(dim=-1)
     zp.unsqueeze_(dim=-1)
-    q = torch.clamp(torch.round(weight / scale) + zp, 0, maxq)
+    weight.div_(scale)
+    weight.round_()
+    weight.add_(zp)
+    weight.clamp_(0, maxq)
     if return_int:
-        return q.type(torch.uint8), scale.type(torch.float), zp.type(torch.uint8)
-    return scale * (q - zp)
+        return weight.type(torch.uint8), scale.type(torch.float), zp.type(torch.uint8)
+    weight.sub_(zp)
+    return weight.mul_(scale)


 def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False):
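A quick numerical check (not from the repo) that the chained in-place ops compute the same fake-quantized result as the old out-of-place expression in qdq_weight_asym, just without the temporaries for weight / scale and q - zp:

import torch

w = torch.tensor([[0.1, -0.4, 0.7, -0.2]])
maxq = 2**4 - 1
wmin, wmax = w.min(1)[0], w.max(1)[0]
scale = (wmax - wmin) / maxq
zp = torch.round(-wmin / scale)
scale, zp = scale.unsqueeze(-1), zp.unsqueeze(-1)

ref = scale * (torch.clamp(torch.round(w / scale) + zp, 0, maxq) - zp)  # old path
w.div_(scale).round_().add_(zp).clamp_(0, maxq).sub_(zp).mul_(scale)    # new path
print(torch.allclose(w, ref))  # True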
@@ -167,10 +172,12 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False):
     else:
         scale = wmax / maxq
     scale.unsqueeze_(dim=-1)
-    q = torch.clamp(torch.round(weight / scale), minq, maxq)
+    weight.div_(scale)
+    weight.round_()
+    weight.clamp_(minq, maxq)
     if return_int:
-        return q.type(torch.int8), scale.type(torch.float), None
-    return scale * q
+        return weight.type(torch.int8), scale.type(torch.float), None
+    return weight.mul_(scale)


 def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, data_type="int", return_int=False, full_range=False):
@@ -200,7 +207,7 @@ def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, data_type="int", return_int=False, full_range=False):
 def quant_weight(
     weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0, data_type="int", return_int=False, full_range=False
 ):
-    """Quant and dequant tensor with group size.
+    """Quant and dequant tensor with group size. It is an in-place op.

     Args:
         weight: input weight
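The updated docstring relies on a PyTorch convention worth spelling out: in-place ops return the tensor they just mutated, so quant_weight can write its result into the caller's storage and still hand back a usable tensor to existing call sites. A two-line check:

import torch

t = torch.ones(2, 3)
out = t.mul_(0.5)
print(out is t)  # True: same object, no new allocation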
@@ -248,7 +255,7 @@ def quant_weight(
             zp = zp.reshape(orig_shape[0], -1)
             return weight, scale, zp
         else:
-            weight = qdq_weight_actor(
+            qdq_weight_actor(
                 weight, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
             return weight.reshape(orig_shape)
@@ -285,7 +292,6 @@ def quant_weight(
                 return_int=True,
                 full_range=full_range,
             )
-            weight = torch.cat([weight1, weight2], dim=1)
             scale = torch.cat([scale1, scale2], dim=1)
             if zp2 is not None:
                 zp = torch.cat([zp1, zp2], dim=1)
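This hunk and the next one drop the torch.cat that used to re-assemble the two column blocks. The behaviour they lean on, shown standalone (the repo's own path also reshapes the grouped block, which this sketch does not cover): basic column slices are views of the parent's storage, so in-place updates to each block are already visible in the full weight tensor.

import torch

w = torch.arange(12.0).reshape(2, 6)
w1, w2 = w[:, :4], w[:, 4:]  # views, not copies
w1.mul_(10.0)                # stand-in for quantizing the grouped block in place
w2.mul_(100.0)               # stand-in for the leftover columns
print(w)                     # the full tensor already reflects both updates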
@@ -296,7 +302,6 @@
             weight2 = qdq_weight_actor(
                 weight2, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
-            weight = torch.cat([weight1, weight2], dim=1)
             return weight
@@ -314,7 +319,7 @@ def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", enable_full_range=False):
     Returns:
         best_clip_ratio (float): best percentile of clip
     """
-    org_weight = m.weight.data
+    org_weight = m.weight.data.clone()
     logger.info("Searching the best clip range with RTN algorithm")
     best_error = float("inf")
     best_clip_ratio = None
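Why search_clip now snapshots with .clone(): m.weight.data is only another handle on the parameter's storage, so after an in-place trial quantization an un-cloned "original" would have drifted along with the trial. A small sketch:

import torch

m = torch.nn.Linear(4, 4)
org_ref = m.weight.data           # old snapshot: shares storage with the parameter
org_copy = m.weight.data.clone()  # new snapshot: independent buffer

m.weight.data.round_()            # stand-in for one clip-ratio trial
print(torch.equal(org_ref, m.weight.data))   # True: the reference followed the mutation
print(torch.equal(org_copy, m.weight.data))  # False (almost surely): the clone kept the originals
m.weight.data.copy_(org_copy)     # restore before evaluating the next clip ratio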
@@ -397,82 +402,84 @@ def rtn_quantize(
     scale_dtype = kwargs.get("scale_dtype", torch.float32)
     device = kwargs.get("device", "cpu")
     use_optimum_format = kwargs.get("use_optimum_format", True)
-    for name, m in model.named_modules():
-        if m.__class__.__name__ not in supported_layers:
-            continue
-        orig_dtype = next(m.parameters()).dtype
-        if orig_dtype != torch.float:
-            m = m.float()
-        if name in weight_config:  # pragma: no cover
-            num_bits = weight_config[name]["bits"]
-            group_size = weight_config[name]["group_size"]
-            scheme = weight_config[name]["scheme"]
-            quantile = weight_config[name].get("quantile", 1.0)
-        logger.debug(f"RTN quantized module:{name, m}")
-        log_msg = (
-            f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
-            + f"scheme={scheme}, quantile={quantile}"
-        )
-        if data_type != "int":
-            log_msg += f", dtype={data_type}"
-        elif scheme == "sym":  # nf4/fp4 is always [-7,7]
-            log_msg += f", enable_full_range={enable_full_range}"
-        logger.debug(log_msg)
-        if num_bits <= 0:
-            logger.info(f"Skip {name}")
-            continue
-        weight = m.weight.T if group_dim == 0 else m.weight
-        if enable_mse_search:
-            quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
-        if return_int:
-            from .model_wrapper import WeightOnlyLinear
-
-            int_weight, scale, zp = quant_weight(
-                weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
-                data_type=data_type,
-                return_int=True,
-                full_range=enable_full_range,
+    with torch.no_grad():
+        for name, m in model.named_modules():
+            if m.__class__.__name__ not in supported_layers:
+                continue
+            orig_dtype = next(m.parameters()).dtype
+            if orig_dtype != torch.float:
+                m = m.float()
+            if name in weight_config:  # pragma: no cover
+                num_bits = weight_config[name]["bits"]
+                group_size = weight_config[name]["group_size"]
+                scheme = weight_config[name]["scheme"]
+                quantile = weight_config[name].get("quantile", 1.0)
+            logger.debug(f"RTN quantized module:{name, m}")
+            log_msg = (
+                f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
+                + f"scheme={scheme}, quantile={quantile}"
             )
-            int_weight = int_weight.T if group_dim == 0 else int_weight
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
-            new_module = WeightOnlyLinear(
-                m.in_features,
-                m.out_features,
-                num_bits,
-                group_size,
-                dtype=data_type,
-                zp=zp is not None,
-                bias=m.bias is not None,
-                compression_dtype=compression_dtype,
-                compression_dim=compression_dim,
-                scale_dtype=scale_dtype,
-                device=device,
-                use_optimum_format=use_optimum_format,
-            )
-            new_module.pack(int_weight, scale, zp, m.bias)
-            if name == "":
-                return new_module
+            if data_type != "int":
+                log_msg += f", dtype={data_type}"
+            elif scheme == "sym":  # nf4/fp4 is always [-7,7]
+                log_msg += f", enable_full_range={enable_full_range}"
+            logger.debug(log_msg)
+            if num_bits <= 0:
+                logger.info(f"Skip {name}")
+                continue
+            weight = m.weight.T if group_dim == 0 else m.weight
+            if enable_mse_search:
+                quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
+            if return_int:
+                from .model_wrapper import WeightOnlyLinear
+
+                _, scale, zp = quant_weight(
+                    weight,
+                    num_bits,
+                    group_size,
+                    scheme,
+                    quantile,
+                    data_type=data_type,
+                    return_int=True,
+                    full_range=enable_full_range,
+                )
+                if group_dim == 0:
+                    weight.transpose_(0, 1)
+                scale = scale.T if group_dim == 0 else scale
+                zp = zp.T if group_dim == 0 and zp is not None else zp
+                new_module = WeightOnlyLinear(
+                    m.in_features,
+                    m.out_features,
+                    num_bits,
+                    group_size,
+                    dtype=data_type,
+                    zp=zp is not None,
+                    bias=m.bias is not None,
+                    compression_dtype=compression_dtype,
+                    compression_dim=compression_dim,
+                    scale_dtype=scale_dtype,
+                    device=device,
+                    use_optimum_format=use_optimum_format,
+                )
+                new_module.pack(weight, scale, zp, m.bias)
+                if name == "":
+                    return new_module
+                else:
+                    set_module(model, name, new_module)
             else:
-                set_module(model, name, new_module)
-        else:
-            q_weight = quant_weight(
-                weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
-                data_type=data_type,
-                full_range=enable_full_range,
-            )
-            q_weight = q_weight.T if group_dim == 0 else q_weight
-            m.weight.data.copy_(q_weight)
-        if orig_dtype != torch.float:
-            m = m.to(orig_dtype)
+                quant_weight(
+                    weight,
+                    num_bits,
+                    group_size,
+                    scheme,
+                    quantile,
+                    data_type=data_type,
+                    full_range=enable_full_range,
+                )
+                if group_dim == 0:
+                    weight.transpose_(0, 1)
+            if orig_dtype != torch.float:
+                m = m.to(orig_dtype)
     return model
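Two things happen in this hunk besides re-indenting: the loop is wrapped in torch.no_grad(), and the .T copies are replaced by weight.transpose_(0, 1) so the layout change is also done in place. The no_grad guard matters because autograd rejects in-place edits of leaf tensors that require grad; a minimal illustration with a throwaway layer (not the repo's loop):

import torch

lin = torch.nn.Linear(4, 4)
try:
    lin.weight.div_(2.0)      # RuntimeError: leaf Variable that requires grad ...
except RuntimeError as err:
    print("blocked:", err)

with torch.no_grad():
    lin.weight.div_(2.0)      # allowed: treated as a plain buffer update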