@@ -80,7 +80,7 @@ def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False):
     # get scale and update tensor
     scale = tensor.abs().max(1)[0] * quantile / max(allow_data)
     scale.unsqueeze_(dim=-1)
-    tensor = tensor / scale
+    tensor.div_(scale)
     mid_data = [(allow_data[i] + allow_data[i + 1]) / 2 for i in range(len(allow_data) - 1)]
     q_tensor = torch.zeros_like(tensor)
     for i in range(len(allow_data)):
@@ -91,9 +91,10 @@ def quantize_4bit(tensor, quantile=1.0, data_type="nf4", return_int=False):
             q_tensor += torch.where(tensor > mid_data[i - 1], data, 0)
         else:
             q_tensor += torch.where((mid_data[i - 1] < tensor) & (tensor <= mid_data[i]), data, 0)
+    tensor.copy_(q_tensor)
     if return_int:
-        return q_tensor.type(torch.int8), scale.type(torch.float), None
-    return q_tensor * scale
+        return tensor.type(torch.int8), scale.type(torch.float), None
+    return tensor.mul_(scale)


 def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False):
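
Note on the hunk above: `quantize_4bit` now writes the quantized-dequantized values back into its input via `div_`/`copy_`/`mul_` and returns the same storage. A minimal caller-side sketch of that behaviour; the import path, shapes, and values below are illustrative assumptions, not part of the patch:

```python
import torch

from neural_compressor.adaptor.torch_utils.weight_only import quantize_4bit  # assumed import path

w = torch.randn(8, 32)
w_ref = w.clone()                       # keep the original fp32 values; w will be overwritten
qdq = quantize_4bit(w, data_type="nf4")
assert qdq.data_ptr() == w.data_ptr()   # in-place: the result shares storage with the input
mse = (w_ref - qdq).pow(2).mean()       # quantization error measured against the saved copy
```
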
@@ -122,10 +123,14 @@ def qdq_weight_asym(weight, num_bits=4, quantile=1.0, return_int=False):
     zp = torch.round(-wmin / scale)
     scale.unsqueeze_(dim=-1)
     zp.unsqueeze_(dim=-1)
-    q = torch.clamp(torch.round(weight / scale) + zp, 0, maxq)
+    weight.div_(scale)
+    weight.round_()
+    weight.add_(zp)
+    weight.clamp_(0, maxq)
     if return_int:
-        return q.type(torch.uint8), scale.type(torch.float), zp.type(torch.uint8)
-    return scale * (q - zp)
+        return weight.type(torch.uint8), scale.type(torch.float), zp.type(torch.uint8)
+    weight.sub_(zp)
+    return weight.mul_(scale)


 def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_range=False):
@@ -167,10 +172,12 @@ def qdq_weight_sym(weight, num_bits=4, quantile=1.0, return_int=False, full_rang
     else:
         scale = wmax / maxq
     scale.unsqueeze_(dim=-1)
-    q = torch.clamp(torch.round(weight / scale), minq, maxq)
+    weight.div_(scale)
+    weight.round_()
+    weight.clamp_(minq, maxq)
     if return_int:
-        return q.type(torch.int8), scale.type(torch.float), None
-    return scale * q
+        return weight.type(torch.int8), scale.type(torch.float), None
+    return weight.mul_(scale)


 def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, data_type="int", return_int=False, full_range=False):
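
The asym/sym hunks above replace a single composed expression with a sequence of in-place ops on `weight`. A small self-check that the two formulations agree numerically; the scale/zero-point setup below is simplified and the tensors are made up for illustration:

```python
import torch

w = torch.randn(4, 16)
maxq = 2**4 - 1                                      # 4-bit asymmetric range [0, 15]
wmin = w.min(1, keepdim=True).values.clamp(max=0)
wmax = w.max(1, keepdim=True).values.clamp(min=0)
scale = ((wmax - wmin) / maxq).clamp(min=1e-5)
zp = torch.round(-wmin / scale)

q_old = torch.clamp(torch.round(w / scale) + zp, 0, maxq)  # previous out-of-place path
q_new = w.clone()                                          # clone stands in for the caller's tensor
q_new.div_(scale)                                          # new in-place path, mirroring the patch
q_new.round_()
q_new.add_(zp)
q_new.clamp_(0, maxq)
assert torch.equal(q_old, q_new)
```
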
@@ -200,7 +207,7 @@ def qdq_weight_actor(weight, num_bits, scheme, quantile=1.0, data_type="int", re
 def quant_weight(
     weight, num_bits=4, group_size=-1, scheme="asym", quantile=1.0, data_type="int", return_int=False, full_range=False
 ):
-    """Quant and dequant tensor with group size.
+    """Quant and dequant tensor with group size. It is an in-place op.

     Args:
         weight: input weight
@@ -248,7 +255,7 @@ def quant_weight(
                 zp = zp.reshape(orig_shape[0], -1)
             return weight, scale, zp
         else:
-            weight = qdq_weight_actor(
+            qdq_weight_actor(
                 weight, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
             return weight.reshape(orig_shape)
@@ -285,7 +292,6 @@ def quant_weight(
                 return_int=True,
                 full_range=full_range,
             )
-            weight = torch.cat([weight1, weight2], dim=1)
             scale = torch.cat([scale1, scale2], dim=1)
             if zp2 is not None:
                 zp = torch.cat([zp1, zp2], dim=1)
@@ -296,7 +302,6 @@ def quant_weight(
             weight2 = qdq_weight_actor(
                 weight2, num_bits, scheme=scheme, data_type=data_type, quantile=quantile, full_range=full_range
             )
-            weight = torch.cat([weight1, weight2], dim=1)
             return weight


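
Dropping the two `torch.cat([weight1, weight2], dim=1)` lines preserves the result only if `weight1` and `weight2` remain views into `weight`'s storage, so that the in-place `qdq_weight_actor` calls update the parent tensor directly. A toy illustration of that view assumption; the split index and the `mul_` stand-in are made up and are not the function's actual indexing:

```python
import torch

weight = torch.randn(4, 10)
split_index = 8                       # hypothetical split point
weight1 = weight[:, :split_index]     # basic slicing returns a view
weight2 = weight[:, split_index:]

weight1.mul_(0.5)                     # stand-in for an in-place qdq_weight_actor call
weight2.mul_(0.5)
assert torch.equal(weight[:, :split_index], weight1)  # the parent tensor saw the update
```
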
@@ -314,7 +319,7 @@ def search_clip(m, num_bits=4, group_size=32, scheme="asym", data_type="int", en
     Returns:
         best_clip_ratio (float): best percentile of clip
     """
-    org_weight = m.weight.data
+    org_weight = m.weight.data.clone()
     logger.info("Searching the best clip range with RTN algorithm")
     best_error = float("inf")
     best_clip_ratio = None
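
The added `.clone()` matters because `quant_weight` now mutates its argument: without a snapshot, `org_weight` would alias the already-quantized tensor and the error search would compare it with itself. A sketch of the pattern, assuming `quant_weight` is importable from this file and using made-up shapes:

```python
import torch

from neural_compressor.adaptor.torch_utils.weight_only import quant_weight  # assumed import path

m = torch.nn.Linear(64, 8)
org_weight = m.weight.data.clone()    # snapshot before the in-place quantize-dequantize
cur_weight = quant_weight(m.weight.data, num_bits=4, group_size=32, scheme="asym", quantile=0.9)
loss = (org_weight - cur_weight).pow(2).mean().item()   # error vs. the preserved original
```
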
@@ -397,82 +402,84 @@ def rtn_quantize(
     scale_dtype = kwargs.get("scale_dtype", torch.float32)
     device = kwargs.get("device", "cpu")
     use_optimum_format = kwargs.get("use_optimum_format", True)
-    for name, m in model.named_modules():
-        if m.__class__.__name__ not in supported_layers:
-            continue
-        orig_dtype = next(m.parameters()).dtype
-        if orig_dtype != torch.float:
-            m = m.float()
-        if name in weight_config:  # pragma: no cover
-            num_bits = weight_config[name]["bits"]
-            group_size = weight_config[name]["group_size"]
-            scheme = weight_config[name]["scheme"]
-            quantile = weight_config[name].get("quantile", 1.0)
-        logger.debug(f"RTN quantized module:{name, m}")
-        log_msg = (
-            f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
-            + f"scheme={scheme}, quantile={quantile}"
-        )
-        if data_type != "int":
-            log_msg += f", dtype={data_type}"
-        elif scheme == "sym":  # nf4/fp4 is always [-7,7]
-            log_msg += f", enable_full_range={enable_full_range}"
-        logger.debug(log_msg)
-        if num_bits <= 0:
-            logger.info(f"Skip {name}")
-            continue
-        weight = m.weight.T if group_dim == 0 else m.weight
-        if enable_mse_search:
-            quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
-        if return_int:
-            from .model_wrapper import WeightOnlyLinear
-
-            int_weight, scale, zp = quant_weight(
-                weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
-                data_type=data_type,
-                return_int=True,
-                full_range=enable_full_range,
+    with torch.no_grad():
+        for name, m in model.named_modules():
+            if m.__class__.__name__ not in supported_layers:
+                continue
+            orig_dtype = next(m.parameters()).dtype
+            if orig_dtype != torch.float:
+                m = m.float()
+            if name in weight_config:  # pragma: no cover
+                num_bits = weight_config[name]["bits"]
+                group_size = weight_config[name]["group_size"]
+                scheme = weight_config[name]["scheme"]
+                quantile = weight_config[name].get("quantile", 1.0)
+            logger.debug(f"RTN quantized module:{name, m}")
+            log_msg = (
+                f"RTN quantization config: num_bits={num_bits}, group_size={group_size}, "
+                + f"scheme={scheme}, quantile={quantile}"
             )
-            int_weight = int_weight.T if group_dim == 0 else int_weight
-            scale = scale.T if group_dim == 0 else scale
-            zp = zp.T if group_dim == 0 and zp is not None else zp
-            new_module = WeightOnlyLinear(
-                m.in_features,
-                m.out_features,
-                num_bits,
-                group_size,
-                dtype=data_type,
-                zp=zp is not None,
-                bias=m.bias is not None,
-                compression_dtype=compression_dtype,
-                compression_dim=compression_dim,
-                scale_dtype=scale_dtype,
-                device=device,
-                use_optimum_format=use_optimum_format,
-            )
-            new_module.pack(int_weight, scale, zp, m.bias)
-            if name == "":
-                return new_module
+            if data_type != "int":
+                log_msg += f", dtype={data_type}"
+            elif scheme == "sym":  # nf4/fp4 is always [-7,7]
+                log_msg += f", enable_full_range={enable_full_range}"
+            logger.debug(log_msg)
+            if num_bits <= 0:
+                logger.info(f"Skip {name}")
+                continue
+            weight = m.weight.T if group_dim == 0 else m.weight
+            if enable_mse_search:
+                quantile = search_clip(m, num_bits, group_size, scheme, data_type, enable_full_range)
+            if return_int:
+                from .model_wrapper import WeightOnlyLinear
+
+                _, scale, zp = quant_weight(
+                    weight,
+                    num_bits,
+                    group_size,
+                    scheme,
+                    quantile,
+                    data_type=data_type,
+                    return_int=True,
+                    full_range=enable_full_range,
+                )
+                if group_dim == 0:
+                    weight.transpose_(0, 1)
+                scale = scale.T if group_dim == 0 else scale
+                zp = zp.T if group_dim == 0 and zp is not None else zp
+                new_module = WeightOnlyLinear(
+                    m.in_features,
+                    m.out_features,
+                    num_bits,
+                    group_size,
+                    dtype=data_type,
+                    zp=zp is not None,
+                    bias=m.bias is not None,
+                    compression_dtype=compression_dtype,
+                    compression_dim=compression_dim,
+                    scale_dtype=scale_dtype,
+                    device=device,
+                    use_optimum_format=use_optimum_format,
+                )
+                new_module.pack(weight, scale, zp, m.bias)
+                if name == "":
+                    return new_module
+                else:
+                    set_module(model, name, new_module)
             else:
-                set_module(model, name, new_module)
-        else:
-            q_weight = quant_weight(
-                weight,
-                num_bits,
-                group_size,
-                scheme,
-                quantile,
-                data_type=data_type,
-                full_range=enable_full_range,
-            )
-            q_weight = q_weight.T if group_dim == 0 else q_weight
-            m.weight.data.copy_(q_weight)
-        if orig_dtype != torch.float:
-            m = m.to(orig_dtype)
+                quant_weight(
+                    weight,
+                    num_bits,
+                    group_size,
+                    scheme,
+                    quantile,
+                    data_type=data_type,
+                    full_range=enable_full_range,
+                )
+                if group_dim == 0:
+                    weight.transpose_(0, 1)
+            if orig_dtype != torch.float:
+                m = m.to(orig_dtype)
     return model


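
With the loop wrapped in `torch.no_grad()` and `quant_weight` operating in place, `rtn_quantize` now rewrites each supported layer's weight directly instead of materialising a dequantized copy and calling `m.weight.data.copy_(q_weight)`. A usage sketch; the import path, the default keyword arguments, and plain `nn.Linear` being listed in `supported_layers` are all assumptions:

```python
import torch

from neural_compressor.adaptor.torch_utils.weight_only import rtn_quantize  # assumed import path

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))
w_ptr = model[0].weight.data_ptr()

model = rtn_quantize(model)                  # fake-quantize the Linear weights in place (defaults assumed)
assert model[0].weight.data_ptr() == w_ptr   # same parameter storage, now holding the qdq values
```
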