@@ -107,21 +107,25 @@ def hessian_memory_requirements(model: torch.nn.Module) -> int:
     :return: number of bytes required to reserve for GPTQ on a single layer
     """
     transformer_layers = get_layers(get_no_split_params(model), model)
-    single_layer = transformer_layers[list(transformer_layers.keys())[0]]
-    total_hessian_elems = 0
-    max_column_size = 0
-    for _, module in single_layer.named_modules():
-        if isinstance(module, Linear):
-            for param in module.parameters():
-                column_size = param.shape[1]
-                total_hessian_elems += column_size * column_size
-                if column_size > max_column_size:
-                    # max extra memory for inverse calculation
-                    max_column_size = column_size
-
+    total_hessian_elems = {}
+    max_column_size = {}
+    for no_split_name, no_split_layer in transformer_layers.items():
+        total_hessian_elems[no_split_name] = 0
+        max_column_size[no_split_name] = 0
+        for name, module in no_split_layer.named_modules():
+            if isinstance(module, Linear):
+                for param in module.parameters():
+                    column_size = param.shape[1]
+                    total_hessian_elems[no_split_name] += column_size * column_size
+                    if column_size > max_column_size[no_split_name]:
+                        # max extra memory for inverse calculation
+                        max_column_size[no_split_name] = column_size
+
+    max_total_hessian_elems = max(total_hessian_elems.values())
+    overall_max_column_size = max(max_column_size.values())
     bytes_per_weight = 32 // 8  # hessians are float32
-    inverse_reserved = max_column_size * max_column_size
-    return (total_hessian_elems + inverse_reserved) * bytes_per_weight
+    inverse_reserved = overall_max_column_size * overall_max_column_size
+    return (max_total_hessian_elems + inverse_reserved) * bytes_per_weight


 def quantization_memory_requirement(model: torch.nn.Module) -> int:
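A minimal sketch of the arithmetic behind the updated estimate, using a hypothetical no-split layer whose Linear modules have input (column) sizes 4096 and 11008; the actual function repeats this per no-split layer and keeps the worst case:

```python
# Hypothetical column sizes for the Linear modules of one no-split layer.
column_sizes = [4096, 11008]

# One float32 hessian of shape (columns, columns) per Linear module.
total_hessian_elems = sum(c * c for c in column_sizes)
# Extra scratch space reserved for inverting the largest hessian.
inverse_reserved = max(column_sizes) ** 2

bytes_per_weight = 32 // 8  # hessians are float32
reserved_bytes = (total_hessian_elems + inverse_reserved) * bytes_per_weight
print(f"~{reserved_bytes / 1024**3:.2f} GiB reserved")  # ~0.97 GiB for this toy layer
```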