-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 import paddle


-# cal adam update ratio
+# calculate part of the adam update ratio
 def cal_ratio(m, v, eps=1e-8):
     return 1 / (np.sqrt(v) + eps)

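For context, a minimal sketch of where this partial ratio would be used; lr and the moments m, v are assumed values, not taken from this patch:

    # cal_ratio only returns the denominator factor 1 / (sqrt(v) + eps);
    # an Adam-style update would combine it with the first moment m roughly as:
    ratio = cal_ratio(m, v)      # 1 / (sqrt(v) + eps); m is unused in the body shown
    update = lr * m * ratio      # lr * m / (sqrt(v) + eps)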
@@ -29,8 +29,8 @@ def group_wise_quant_dequant(
     quant_bits=4,
     group_size=32,
     quant=True,
-    rank=-1,
-    world_size=1,
+    tp_rank=-1,
+    tp_degree=1,
     use_pd=False,
     symmetry=False,
 ):
@@ -49,10 +49,10 @@ def group_wise_quant_dequant(
             Group size of group-wise quantization.
         quant (`bool`):
             True when quantization, False in dequantization.
-        rank (`int`):
-            Model parallel rank.
-        world_size (`int`):
-            Model parallel world size.
+        tp_rank (`int`):
+            Tensor parallel rank.
+        tp_degree (`int`):
+            Tensor parallel world size.
         use_pd (`bool`):
             Whether to use paddle calculation. If False, numpy will be used.
         symmetry (`bool`):
@@ -92,21 +92,28 @@ def group_wise_quant_dequant(
             else:
                 new_scales = np.repeat(scales, repeats=group_size, axis=0)

-            if rank == -1:
+            if tp_rank == -1:
                 dequant_tensor = inputs.astype("float32") * new_scales / bnt
             elif len(new_scales.shape) == 0 or inputs.shape[-1] == new_scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 dequant_tensor = (
                     inputs.astype("float32")
                     * new_scales[
-                        rank * new_scales.shape[0] // world_size : (rank + 1) * new_scales.shape[0] // world_size
+                        tp_rank * new_scales.shape[0] // tp_degree : (tp_rank + 1) * new_scales.shape[0] // tp_degree
                     ]
                     / bnt
                 )
             else:
+                # input tensor was column parallel in tp.
                 dequant_tensor = (
                     inputs.astype("float32")
                     * new_scales[
-                        :, rank * new_scales.shape[-1] // world_size : (rank + 1) * new_scales.shape[-1] // world_size
+                        :,
+                        tp_rank
+                        * new_scales.shape[-1]
+                        // tp_degree : (tp_rank + 1)
+                        * new_scales.shape[-1]
+                        // tp_degree,
                     ]
                     / bnt
                 )
@@ -120,22 +127,28 @@ def group_wise_quant_dequant(
                 new_scales = np.repeat(scales, repeats=group_size, axis=0)
                 new_mins = np.repeat(mins, repeats=group_size, axis=0)

-            if rank == -1:
+            if tp_rank == -1:
                 dequant_tensor = (inputs.astype("float32") / qmax * new_scales) + new_mins
             elif len(new_scales.shape) == 0 or inputs.shape[-1] == new_scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 dequant_tensor = (
                     inputs.astype("float32")
                     / qmax
-                    * new_scales[rank * new_scales.shape[0] // world_size : (rank + 1) * new_scales.shape[0] // world_size]
-                ) + new_mins[rank * new_mins.shape[0] // world_size : (rank + 1) * new_mins.shape[0] // world_size]
+                    * new_scales[
+                        tp_rank * new_scales.shape[0] // tp_degree : (tp_rank + 1) * new_scales.shape[0] // tp_degree
+                    ]
+                ) + new_mins[tp_rank * new_mins.shape[0] // tp_degree : (tp_rank + 1) * new_mins.shape[0] // tp_degree]
             else:
+                # input tensor was column parallel in tp.
                 dequant_tensor = (
                     inputs.astype("float32")
                     / qmax
                     * new_scales[
-                        :, rank * new_scales.shape[-1] // world_size : (rank + 1) * new_scales.shape[-1] // world_size
+                        :, tp_rank * new_scales.shape[-1] // tp_degree : (tp_rank + 1) * new_scales.shape[-1] // tp_degree
                     ]
-                ) + new_mins[:, rank * new_mins.shape[-1] // world_size : (rank + 1) * new_mins.shape[-1] // world_size]
+                ) + new_mins[
+                    :, tp_rank * new_mins.shape[-1] // tp_degree : (tp_rank + 1) * new_mins.shape[-1] // tp_degree
+                ]
             return dequant_tensor


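For context, a worked sketch of the scale slicing above; the shapes are assumed for illustration and do not come from this patch:

    # assumed shapes: full weight (128, 256), group_size=32, so the group-wise
    # scales are (4, 256) and new_scales after np.repeat(..., axis=0) is (128, 256).
    # a column-parallel shard on tp_rank=0 with tp_degree=2 is (128, 128), and the
    # matching slice is new_scales[:, 0:128], since
    #   tp_rank * new_scales.shape[-1] // tp_degree       = 0 * 256 // 2 = 0
    #   (tp_rank + 1) * new_scales.shape[-1] // tp_degree = 1 * 256 // 2 = 128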
@@ -154,28 +167,29 @@ def split_int8(final):

     int4_high = np.where(int4_high > 8, int4_high - 16, int4_high)

-    high_tensor = paddle.Tensor(int4_high, zero_copy=True)
-    low_tensor = paddle.Tensor(int4_low, zero_copy=True)
+    high_tensor = paddle.Tensor(int4_high)
+    low_tensor = paddle.Tensor(int4_low)

     return high_tensor, low_tensor


 # channel-wise min max scales calculation
 def cal_abs_min_max_channel(inputs, quant_axis=1):
+    eps = 1e-8
     reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != quant_axis])
     abs_max_values = np.max(inputs, axis=reduce_axis)
     abs_min_values = np.min(inputs, axis=reduce_axis)
     abs_max_values = np.where(
-        abs_max_values == np.array(0, dtype=inputs.dtype), np.array(1e-8, dtype=inputs.dtype), abs_max_values
+        abs_max_values == np.array(0, dtype=inputs.dtype), np.array(eps, dtype=inputs.dtype), abs_max_values
     )
     abs_min_values = np.where(
-        abs_min_values == np.array(0, dtype=inputs.dtype), np.array(1e-8, dtype=inputs.dtype), abs_min_values
+        abs_min_values == np.array(0, dtype=inputs.dtype), np.array(eps, dtype=inputs.dtype), abs_min_values
     )
     return abs_max_values, abs_min_values


 def asymmetry_qdq_weight(
-    x, quant_bit=8, quant_axis=-1, mins=None, maxs=None, dequant=False, rank=-1, world_size=1, use_pd=False
+    x, quant_bit=8, quant_axis=-1, mins=None, maxs=None, dequant=False, tp_rank=-1, tp_degree=1, use_pd=False
 ):
     """
     channel-wise asymmetry quantization
@@ -192,9 +206,9 @@ def asymmetry_qdq_weight(
             Max scales tensor in asymmetry quantization.
         dequant (`bool`):
             True when dequantization, False in quantization.
-        rank (`int`):
+        tp_rank (`int`):
             Model parallel rank.
-        world_size (`int`):
+        tp_degree (`int`):
             Model parallel world size.
         use_pd (`bool`):
             Whether to use paddle calculation. If False, numpy will be used.
@@ -213,39 +227,47 @@ def asymmetry_qdq_weight(
         # dequant
         if not use_pd:
             if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 qdq_x = (quant_x / bnt * scales) + mins
             else:
+                # input tensor was column parallel in tp.
                 qdq_x = (
                     quant_x
                     / bnt
-                    * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
-                ) + mins[rank * mins.shape[0] // world_size : (rank + 1) * mins.shape[0] // world_size]
+                    * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree]
+                ) + mins[tp_rank * mins.shape[0] // tp_degree : (tp_rank + 1) * mins.shape[0] // tp_degree]
             return qdq_x.astype(np.float32), scales
         else:
             if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 qdq_x = (quant_x / bnt * scales.unsqueeze(0).expand(quant_x.shape)) + mins
             else:
+                # input tensor was column parallel in tp.
                 qdq_x = (
                     quant_x
                     / bnt
-                    * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
+                    * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree]
                     .unsqueeze(0)
                     .expand(quant_x.shape)
-                ) + mins[rank * mins.shape[0] // world_size : (rank + 1) * mins.shape[0] // world_size]
+                ) + mins[tp_rank * mins.shape[0] // tp_degree : (tp_rank + 1) * mins.shape[0] // tp_degree]
             return qdq_x.astype(paddle.float32), scales


 # channel-wise abs max calculation
 def cal_abs_max_channel(inputs, quant_axis=1):
+    epsilon = 1e-8
     reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != quant_axis])
     abs_max_values = np.max(np.abs(inputs), axis=reduce_axis)
+    # all elements in a channel may be zero, so replace zero scales
+    # with a small epsilon to avoid dividing by zero.
     abs_max_values = np.where(
-        abs_max_values == np.array(0, dtype=inputs.dtype), np.array(1e-8, dtype=inputs.dtype), abs_max_values
+        abs_max_values == np.array(0, dtype=inputs.dtype), np.array(epsilon, dtype=inputs.dtype), abs_max_values
     )
     return abs_max_values


-def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, rank=-1, world_size=1, use_pd=False):
+def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, tp_rank=-1, tp_degree=1, use_pd=False):
     """
     channel-wise symmetry quantization
     Args:
@@ -259,9 +281,9 @@ def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, rank=-
             Abs max scales tensor in symmetry quantization.
         dequant (`bool`):
             True when dequantization, False in quantization.
-        rank (`int`):
+        tp_rank (`int`):
             Model parallel rank.
-        world_size (`int`):
+        tp_degree (`int`):
             Model parallel world size.
         use_pd (`bool`):
             Whether to use paddle calculation. If False, numpy will be used.
@@ -279,23 +301,27 @@ def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, rank=-
         # dequant
         if not use_pd:
             if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 qdq_x = quant_x / bnt * scales
             else:
+                # input tensor was column parallel in tp.
                 qdq_x = (
                     quant_x
                     / bnt
-                    * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
+                    * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree]
                 )
             # fp32, int8, int, fp32 or fp64
             return qdq_x.astype(np.float32), scales
         else:
             if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 qdq_x = quant_x / bnt * scales.unsqueeze(0).expand(quant_x.shape)
             else:
+                # input tensor was column parallel in tp.
                 qdq_x = (
                     quant_x
                     / bnt
-                    * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
+                    * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree]
                     .unsqueeze(0)
                     .expand(quant_x.shape)
                 )
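For context, a hedged usage sketch of the renamed arguments. The quantization branch of qdq_weight is not shown in this diff, so its return values and the per-channel scale shape are assumed here; the weight shape, quant_axis value, and the 2-way split are illustrative only, and qdq_weight is assumed to be in scope from this module:

    import numpy as np

    w = np.random.randn(64, 128).astype("float32")  # assumed full, unsharded weight

    # assumed: with quant_axis=1 the quant path returns the quantized tensor
    # together with per-output-channel scales of shape (128,)
    q_w, scales = qdq_weight(w, quant_bit=8, quant_axis=1, dequant=False)

    # dequantize the rank-0 shard of a 2-way column-parallel split; the matching
    # scale slice scales[0:64] is taken inside qdq_weight from tp_rank/tp_degree
    shard = q_w[:, : q_w.shape[-1] // 2]
    w_rec, _ = qdq_weight(shard, quant_bit=8, scales=scales, dequant=True, tp_rank=0, tp_degree=2)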