 # since the model classes inherit torch.nn.Module.
 import math

+import numpy as np
 import torch
 from packaging.version import Version
 from torch.autograd import Function
@@ -325,11 +326,89 @@ def __init__(
         else:
             self.g_idx = None

+    def pack_tensor_with_numpy(self, raw_tensor):
+        raw_array = raw_tensor.cpu().numpy()
+        target_len = np.ceil(raw_array.shape[1] / self.n_pack).astype(int)
+        target_dtype = torch.tensor(0, dtype=self.compression_dtype).numpy().dtype
+        packed_array = np.zeros((raw_array.shape[0], target_len), dtype=target_dtype)
+        mask = np.uint8(2**self.bits - 1)
+        for j in range(packed_array.shape[1]):
+            start = self.n_pack * j
+            end = self.n_pack * (j + 1)
+            tmp = raw_array[:, start:end].astype(target_dtype)
+            tmp &= mask
+            for e in range(tmp.shape[1]):
+                tmp[:, e] = np.left_shift(tmp[:, e], self.bits * e)
+                packed_array[:, j] |= tmp[:, e]
+        packed_tensor = torch.from_numpy(packed_array).to(device=raw_tensor.device)
+        return packed_tensor
+
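This helper is plain fixed-width bit packing: each group of n_pack low-bit values is masked and OR-ed into one word of compression_dtype, with field e occupying bit positions [bits*e, bits*(e+1)). A minimal standalone sketch of the same arithmetic, using made-up 4-bit values packed into a single uint8 word:

import numpy as np

# Two 4-bit values -> one byte; field e lands at bit positions [4e, 4e+4).
bits = 4
mask = np.uint8(2**bits - 1)
vals = np.array([5, 12], dtype=np.uint8) & mask
word = np.uint8(0)
for e, v in enumerate(vals):
    word |= np.uint8(v << (bits * e))
print(hex(int(word)))  # 0xc5: 12 in the high nibble, 5 in the low nibble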
+    def unpack_tensor_with_numpy(self, packed_tensor):
+        packed_array = packed_tensor.cpu().numpy()
+        target_dtype = np.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else np.uint8
+        target_len = packed_array.shape[1] * self.n_pack
+        unpacked_array = np.zeros((packed_array.shape[0], target_len), dtype=target_dtype)
+        mask = np.uint8(2**self.bits - 1)
+        for j in range(packed_array.shape[1]):
+            for e in range(self.n_pack):
+                index = j * self.n_pack + e
+                tmp = packed_array[:, j]
+                tmp = np.left_shift(tmp, self.compress_bits - self.bits * (e + 1))
+                tmp = np.right_shift(tmp, self.compress_bits - self.bits)
+                if target_dtype == np.uint8:
+                    tmp &= mask
+                unpacked_array[:, index] = tmp.astype(target_dtype)
+        unpacked_tensor = torch.from_numpy(unpacked_array).to(device=packed_tensor.device)
+        return unpacked_tensor
+
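Unpacking reverses the packing with a shift pair: the word is shifted left until the field's top bit reaches the sign-bit position, then shifted right by compress_bits - bits. With a signed working dtype the right shift is arithmetic, so the field comes back sign-extended; the mask is applied only when the target is unsigned. A small illustration, reusing the 0xC5 byte from the packing sketch above:

import numpy as np

bits, compress_bits = 4, 8
word = np.int8(-59)  # bit pattern 0xC5 from the packing sketch
for e in range(2):
    tmp = np.left_shift(word, compress_bits - bits * (e + 1))
    tmp = np.right_shift(tmp, compress_bits - bits)
    print(int(tmp))  # 5, then -4 (the nibble 12 read back as a signed 4-bit value)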
+    def pack_tensor_with_torch(self, raw_tensor):
+        target_len = math.ceil(raw_tensor.shape[1] / self.n_pack)
+        packed_tensor = torch.zeros(raw_tensor.shape[0], target_len, dtype=self.compression_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            start = self.n_pack * j
+            end = self.n_pack * (j + 1)
+            tmp = raw_tensor[:, start:end].type(self.compression_dtype)
+            tmp &= mask
+            for e in range(tmp.shape[1]):
+                tmp[:, e] = tmp[:, e] << (self.bits * e)
+                packed_tensor[:, j] |= tmp[:, e]
+        return packed_tensor
+
+    def unpack_tensor_with_torch(self, packed_tensor):
+        target_dtype = torch.int8 if not hasattr(self, "qzeros") or "int" not in self.dtype else torch.uint8
+        target_len = packed_tensor.shape[1] * self.n_pack
+        unpacked_tensor = torch.zeros(packed_tensor.shape[0], target_len, dtype=target_dtype).to(self.device)
+        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)
+        for j in range(packed_tensor.shape[1]):
+            for e in range(self.n_pack):
+                index = j * self.n_pack + e
+                tmp = packed_tensor[:, j]
+                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
+                tmp = tmp >> (self.compress_bits - self.bits)
+                if target_dtype == torch.uint8:
+                    tmp &= mask  # remove sign bit
+                unpacked_tensor[:, index].copy_(tmp.type(target_dtype))
+        return unpacked_tensor
+
+    def pack_tensor(self, raw_tensor):
+        if "cuda" in self.device:
+            return self.pack_tensor_with_torch(raw_tensor)
+        else:
+            return self.pack_tensor_with_numpy(raw_tensor)
+
+    def unpack_tensor(self, packed_tensor):
+        if "cuda" in self.device:
+            return self.unpack_tensor_with_torch(packed_tensor)
+        else:
+            return self.unpack_tensor_with_numpy(packed_tensor)
+
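A standalone round-trip check of the scheme these dispatchers select (not the class itself; shapes, dtypes, and values are made up for illustration): pack 4-bit unsigned values into int32 words with the shift-and-OR loop, then recover them with the shift-pair-and-mask loop.

import math
import torch

bits, compress_bits = 4, 32
n_pack = compress_bits // bits
raw = torch.randint(0, 2**bits, (2, 16), dtype=torch.int32)
packed = torch.zeros(2, math.ceil(raw.shape[1] / n_pack), dtype=torch.int32)
for j in range(packed.shape[1]):
    for e in range(n_pack):
        # mask the field, shift it to its slot, OR it into the word
        packed[:, j] |= (raw[:, j * n_pack + e] & (2**bits - 1)) << (bits * e)
unpacked = torch.zeros_like(raw)
for j in range(packed.shape[1]):
    for e in range(n_pack):
        # left shift to the sign-bit position, arithmetic right shift back, mask
        tmp = packed[:, j] << (compress_bits - bits * (e + 1))
        unpacked[:, j * n_pack + e] = (tmp >> (compress_bits - bits)) & (2**bits - 1)
assert torch.equal(raw, unpacked)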
     def pack(self, int_weight, scale, zp, bias, g_idx=None):
         if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()
         int_weight = int_weight.to(self.device)
         if self.use_optimum_format and zp is None:
             # to avoid overflow
@@ -350,118 +429,73 @@ def pack(self, int_weight, scale, zp, bias, g_idx=None):
         assert scale.shape == self.scales.shape, "Scale shape is mismatched."
         self.scales = scale.type(self.float_type).to(self.device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            int_weight = int_weight.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
+            int_weight = int_weight.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
         origin_shape = int_weight.shape
         target_shape = self.qweight.shape
         assert origin_shape[0] == target_shape[0], "output channels mismatch, please check."
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(self.device)

         # pack weight
-        for j in range(target_shape[1]):
-            start = self.n_pack * j
-            end = self.n_pack * (j + 1)
-            tmp = int_weight[:, start:end].type(self.compression_dtype)
-            for e in range(tmp.shape[1]):
-                tmp[:, e] &= mask
-                tmp[:, e] = tmp[:, e] << (self.bits * e)
-                self.qweight[:, j] |= tmp[:, e]
+        self.qweight.copy_(self.pack_tensor(int_weight))
         if not self.use_optimum_format and self.compression_dim == 0:
-            self.qweight = self.qweight.t_().contiguous()
+            self.qweight = self.qweight.T.contiguous()

         if zp is not None:
             zp = zp.to(self.device)
             if self.use_optimum_format:
                 zp -= 1
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                self.qzeros = self.qzeros.t_().contiguous()
+                zp = zp.T.contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
             assert hasattr(self, "qzeros"), "zp is not set when initializing."
-            target_shape = self.qzeros.shape
-            for j in range(target_shape[1]):
-                start = self.n_pack * j
-                end = self.n_pack * (j + 1)
-                tmp = zp[:, start:end].type(self.compression_dtype)
-                for e in range(tmp.shape[1]):
-                    tmp[:, e] &= mask
-                    tmp[:, e] = tmp[:, e] << (self.bits * e)
-                    self.qzeros[:, j] |= tmp[:, e]
+            self.qzeros.copy_(self.pack_tensor(zp))
             if self.use_optimum_format or self.compression_dim == 0:
-                self.qzeros = self.qzeros.t_().contiguous()
+                self.qzeros = self.qzeros.T.contiguous()
         if self.use_optimum_format:
-            self.scales = self.scales.t_().contiguous()
-            self.qweight = self.qweight.t_().contiguous()
-            self.qzeros = self.qzeros.t_().contiguous()
+            self.scales = self.scales.T.contiguous()
+            self.qweight = self.qweight.T.contiguous()
+            self.qzeros = self.qzeros.T.contiguous()

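In the optimum (GPTQ) layout, tensors are stored transposed and the zero point is stored as zp - 1; recover() below undoes this with zp += 1 plus a wraparound fix. A small check of that offset bookkeeping in isolation, assuming 4-bit quantization with made-up values:

import torch

bits = 4
zp = torch.tensor([0, 1, 8, 15], dtype=torch.int32)
stored = (zp - 1) & (2**bits - 1)  # zp == 0 wraps to 2**bits - 1 in storage
recovered = stored + 1
recovered = torch.where(recovered > (2**bits - 1), 0, recovered)
assert torch.equal(recovered, zp)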
     def recover(self):
         logger.debug(f"Recovering {self} weight")
-        scales = self.scales.t_().contiguous() if self.use_optimum_format else self.scales
-        qweight = self.qweight.t_().contiguous() if self.use_optimum_format else self.qweight
+        scales = self.scales.T.contiguous() if self.use_optimum_format else self.scales
+        qweight = self.qweight.T.contiguous() if self.use_optimum_format else self.qweight

         device = scales.device
         fp32_weight = torch.zeros(self.out_features, self.in_features, dtype=self.float_type).to(device)
         if self.g_idx is None:
             # used for recovering fp32_weight
             self.g_idx = torch.tensor([i // self.groupsize for i in range(self.in_features)], dtype=torch.int32)
-        mask = torch.tensor(2**self.bits - 1, dtype=self.compression_dtype).to(device)
-        if hasattr(self, "qzeros"):
-            weight_dtype = torch.uint8
-        else:
-            weight_dtype = torch.int8
         # unpack weight
-        weight = torch.zeros(self.out_features, self.in_features, dtype=weight_dtype).to(device)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
-            qweight = qweight.t_().contiguous()
-        origin_shape = weight.shape
-        target_shape = qweight.shape
-        for j in range(target_shape[1]):
-            for e in range(self.n_pack):
-                index = j * self.n_pack + e
-                if index >= origin_shape[1]:
-                    continue
-                tmp = qweight[:, j]
-                tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                tmp = tmp >> self.compress_bits - self.bits
-                if weight_dtype == torch.uint8:
-                    tmp &= mask  # remove sign bit
-                weight[:, index] = tmp.type(weight_dtype)
+            qweight = qweight.T.contiguous()
+        weight = self.unpack_tensor(qweight)
         if not self.use_optimum_format and self.compression_dim == 0:
-            weight = weight.t_().contiguous()
+            weight = weight.T.contiguous()
+        weight = weight[: self.out_features, : self.in_features]  # avoid oversize
         if "int" not in self.dtype:
             new_weight = torch.zeros(self.out_features, self.in_features).to(device)
             for k, v in self.int2float_mapping.items():
                 new_weight += torch.where(weight == k, v, 0)
             weight = new_weight
         # unpack zero_point
         if hasattr(self, "qzeros"):
-            zp_dtype = self.compression_dtype  # to avoid overflow when weight-zp
-            zp = torch.zeros(scales.shape, dtype=zp_dtype).to(device)
-            qzeros = self.qzeros.t_().contiguous() if self.use_optimum_format else self.qzeros
+            qzeros = self.qzeros.T.contiguous() if self.use_optimum_format else self.qzeros
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
-                qzeros = qzeros.t_().contiguous()
-            origin_shape = zp.shape
-            target_shape = qzeros.shape
-            for j in range(target_shape[1]):
-                for e in range(self.n_pack):
-                    index = j * self.n_pack + e
-                    if index >= origin_shape[1]:
-                        continue
-                    tmp = qzeros[:, j]
-                    tmp = tmp << (self.compress_bits - self.bits * (e + 1))
-                    tmp = tmp >> self.compress_bits - self.bits
-                    tmp &= mask
-                    zp[:, index] = tmp.type(zp_dtype)
+                qzeros = qzeros.T.contiguous()
+            zp = self.unpack_tensor(qzeros)
             if self.use_optimum_format or self.compression_dim == 0:
-                zp = zp.t_().contiguous()
+                zp = zp.T.contiguous()
+            zp = zp[: scales.shape[0], : scales.shape[1]]  # avoid oversize
             if self.use_optimum_format:
                 # zp -= 1 may cause zp == -1, after recover it becomes 2**self.bits - 1
                 zp += 1
                 zp = torch.where(zp > (2**self.bits - 1), 0, zp)
             # recover fp32 weight with int_weight, scale, and zero_point
             for idx in range(self.in_features):
-                fp32_weight[:, idx] = (weight[:, idx] - zp[:, self.g_idx[idx]]) * scales[:, self.g_idx[idx]]
+                fp32_weight[:, idx] = (torch.subtract(weight[:, idx], zp[:, self.g_idx[idx]]).to(torch.int8)) * scales[
+                    :, self.g_idx[idx]
+                ]
         else:
             # recover fp32 weight with int_weight, scale
             for idx in range(self.in_features):
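The new .to(torch.int8) after torch.subtract matters because weight and zp can both be uint8: the subtraction wraps modulo 256, and the cast reinterprets the wrapped value as signed, recovering the true difference before it is scaled. A minimal standalone check of that wraparound, with made-up values:

import torch

q = torch.tensor([3, 12], dtype=torch.uint8)
zp = torch.tensor([8, 8], dtype=torch.uint8)
scale = torch.tensor([0.5, 0.5])
# 3 - 8 wraps to 251 in uint8; casting to int8 reinterprets it as -5
fp32 = torch.subtract(q, zp).to(torch.int8) * scale
print(fp32)  # tensor([-2.5000,  2.0000])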