Commit a6b2236

add comment

1 parent: 3eaebbb

File tree: 2 files changed, +72 -67 lines changed

paddlenlp/quantization/checkpoint_quantization_utils.py

Lines changed: 59 additions & 33 deletions
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@
 import paddle


-# cal adam update ratio
+# cal part adam update ratio
 def cal_ratio(m, v, eps=1e-8):
     return 1 / (np.sqrt(v) + eps)

@@ -29,8 +29,8 @@ def group_wise_quant_dequant(
     quant_bits=4,
     group_size=32,
     quant=True,
-    rank=-1,
-    world_size=1,
+    tp_rank=-1,
+    tp_degree=1,
     use_pd=False,
     symmetry=False,
 ):
@@ -49,10 +49,10 @@ def group_wise_quant_dequant(
             Group size of group-wise quantization.
         quant (`bool`):
             True when quantization, False in dequantization.
-        rank (`int`):
-            Model parallel rank.
-        world_size (`int`):
-            Model parallel world size.
+        tp_rank (`int`):
+            Tensor parallel rank.
+        tp_degree (`int`):
+            Tensor parallel world size.
         use_pd (`bool`):
             Whether to use paddle caculation. If False will use numpy.
         symmetry (`bool`):
@@ -92,21 +92,28 @@ def group_wise_quant_dequant(
         else:
             new_scales = np.repeat(scales, repeats=group_size, axis=0)

-        if rank == -1:
+        if tp_rank == -1:
             dequant_tensor = inputs.astype("float32") * new_scales / bnt
         elif len(new_scales.shape) == 0 or inputs.shape[-1] == new_scales.shape[-1]:
+            # input tensor was row parallel in tp.
             dequant_tensor = (
                 inputs.astype("float32")
                 * new_scales[
-                    rank * new_scales.shape[0] // world_size : (rank + 1) * new_scales.shape[0] // world_size
+                    tp_rank * new_scales.shape[0] // tp_degree : (tp_rank + 1) * new_scales.shape[0] // tp_degree
                 ]
                 / bnt
             )
         else:
+            # input tensor was column parallel in tp.
             dequant_tensor = (
                 inputs.astype("float32")
                 * new_scales[
-                    :, rank * new_scales.shape[-1] // world_size : (rank + 1) * new_scales.shape[-1] // world_size
+                    :,
+                    tp_rank
+                    * new_scales.shape[-1]
+                    // tp_degree : (tp_rank + 1)
+                    * new_scales.shape[-1]
+                    // tp_degree,
                 ]
                 / bnt
             )
@@ -120,22 +127,28 @@ def group_wise_quant_dequant(
             new_scales = np.repeat(scales, repeats=group_size, axis=0)
             new_mins = np.repeat(mins, repeats=group_size, axis=0)

-        if rank == -1:
+        if tp_rank == -1:
             dequant_tensor = (inputs.astype("float32") / qmax * new_scales) + new_mins
         elif len(new_scales.shape) == 0 or inputs.shape[-1] == new_scales.shape[-1]:
+            # input tensor was row parallel in tp.
             dequant_tensor = (
                 inputs.astype("float32")
                 / qmax
-                * new_scales[rank * new_scales.shape[0] // world_size : (rank + 1) * new_scales.shape[0] // world_size]
-            ) + new_mins[rank * new_mins.shape[0] // world_size : (rank + 1) * new_mins.shape[0] // world_size]
+                * new_scales[
+                    tp_rank * new_scales.shape[0] // tp_degree : (tp_rank + 1) * new_scales.shape[0] // tp_degree
+                ]
+            ) + new_mins[tp_rank * new_mins.shape[0] // tp_degree : (tp_rank + 1) * new_mins.shape[0] // tp_degree]
         else:
+            # input tensor was column parallel in tp.
             dequant_tensor = (
                 inputs.astype("float32")
                 / qmax
                 * new_scales[
-                    :, rank * new_scales.shape[-1] // world_size : (rank + 1) * new_scales.shape[-1] // world_size
+                    :, tp_rank * new_scales.shape[-1] // tp_degree : (tp_rank + 1) * new_scales.shape[-1] // tp_degree
                 ]
-            ) + new_mins[:, rank * new_mins.shape[-1] // world_size : (rank + 1) * new_mins.shape[-1] // world_size]
+            ) + new_mins[
+                :, tp_rank * new_mins.shape[-1] // tp_degree : (tp_rank + 1) * new_mins.shape[-1] // tp_degree
+            ]
         return dequant_tensor

@@ -154,28 +167,29 @@ def split_int8(final):

     int4_high = np.where(int4_high > 8, int4_high - 16, int4_high)

-    high_tensor = paddle.Tensor(int4_high, zero_copy=True)
-    low_tensor = paddle.Tensor(int4_low, zero_copy=True)
+    high_tensor = paddle.Tensor(int4_high)
+    low_tensor = paddle.Tensor(int4_low)

     return high_tensor, low_tensor


 # channel-wise min max scales calculation
 def cal_abs_min_max_channel(inputs, quant_axis=1):
+    eps = 1e-8
     reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != quant_axis])
     abs_max_values = np.max(inputs, axis=reduce_axis)
     abs_min_values = np.min(inputs, axis=reduce_axis)
     abs_max_values = np.where(
-        abs_max_values == np.array(0, dtype=inputs.dtype), np.array(1e-8, dtype=inputs.dtype), abs_max_values
+        abs_max_values == np.array(0, dtype=inputs.dtype), np.array(eps, dtype=inputs.dtype), abs_max_values
     )
     abs_min_values = np.where(
-        abs_min_values == np.array(0, dtype=inputs.dtype), np.array(1e-8, dtype=inputs.dtype), abs_min_values
+        abs_min_values == np.array(0, dtype=inputs.dtype), np.array(eps, dtype=inputs.dtype), abs_min_values
     )
     return abs_max_values, abs_min_values


 def asymmetry_qdq_weight(
-    x, quant_bit=8, quant_axis=-1, mins=None, maxs=None, dequant=False, rank=-1, world_size=1, use_pd=False
+    x, quant_bit=8, quant_axis=-1, mins=None, maxs=None, dequant=False, tp_rank=-1, tp_degree=1, use_pd=False
 ):
     """
     channel-wise asymmetry quantization
@@ -192,9 +206,9 @@ def asymmetry_qdq_weight(
             Max scales tensor in asymmetry quantization.
         dequant (`bool`):
             True when dequantization, False in quantization.
-        rank (`int`):
+        tp_rank (`int`):
             Model parallel rank.
-        world_size (`int`):
+        tp_degree (`int`):
             Model parallel world size.
         use_pd (`bool`):
             Whether to use paddle caculation. If False will use numpy.
@@ -213,39 +227,47 @@ def asymmetry_qdq_weight(
         # dequant
         if not use_pd:
             if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 qdq_x = (quant_x / bnt * scales) + mins
             else:
+                # input tensor was column parallel in tp.
                 qdq_x = (
                     quant_x
                     / bnt
-                    * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
-                ) + mins[rank * mins.shape[0] // world_size : (rank + 1) * mins.shape[0] // world_size]
+                    * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree]
+                ) + mins[tp_rank * mins.shape[0] // tp_degree : (tp_rank + 1) * mins.shape[0] // tp_degree]
             return qdq_x.astype(np.float32), scales
         else:
             if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 qdq_x = (quant_x / bnt * scales.unsqueeze(0).expand(quant_x.shape)) + mins
             else:
+                # input tensor was column parallel in tp.
                 qdq_x = (
                     quant_x
                     / bnt
-                    * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
+                    * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree]
                     .unsqueeze(0)
                     .expand(quant_x.shape)
-                ) + mins[rank * mins.shape[0] // world_size : (rank + 1) * mins.shape[0] // world_size]
+                ) + mins[tp_rank * mins.shape[0] // tp_degree : (tp_rank + 1) * mins.shape[0] // tp_degree]
             return qdq_x.astype(paddle.float32), scales


 # channel-wise abs max calculation
 def cal_abs_max_channel(inputs, quant_axis=1):
+    epsilon = 1e-8
     reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != quant_axis])
     abs_max_values = np.max(np.abs(inputs), axis=reduce_axis)
+    # maybe all elements are zero in one group,
+    # so set the scales from those group to an actual number
+    # from divide 0.
     abs_max_values = np.where(
-        abs_max_values == np.array(0, dtype=inputs.dtype), np.array(1e-8, dtype=inputs.dtype), abs_max_values
+        abs_max_values == np.array(0, dtype=inputs.dtype), np.array(epsilon, dtype=inputs.dtype), abs_max_values
     )
     return abs_max_values


-def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, rank=-1, world_size=1, use_pd=False):
+def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, tp_rank=-1, tp_degree=1, use_pd=False):
     """
     channel-wise symmetry quantization
     Args:
@@ -259,9 +281,9 @@ def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, rank=-
             Abs max scales tensor in symmetry quantization.
         dequant (`bool`):
             True when dequantization, False in quantization.
-        rank (`int`):
+        tp_rank (`int`):
             Model parallel rank.
-        world_size (`int`):
+        tp_degree (`int`):
             Model parallel world size.
         use_pd (`bool`):
             Whether to use paddle caculation. If False will use numpy.
@@ -279,23 +301,27 @@ def qdq_weight(x, quant_bit=8, quant_axis=-1, scales=None, dequant=False, rank=-
         # dequant
         if not use_pd:
             if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 qdq_x = quant_x / bnt * scales
             else:
+                # input tensor was column parallel in tp.
                 qdq_x = (
                     quant_x
                     / bnt
-                    * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
+                    * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree]
                 )
             # fp32 , int8, int, fp32 or fp64
             return qdq_x.astype(np.float32), scales
         else:
             if len(scales.shape) == 0 or quant_x.shape[-1] == scales.shape[-1]:
+                # input tensor was row parallel in tp.
                 qdq_x = quant_x / bnt * scales.unsqueeze(0).expand(quant_x.shape)
             else:
+                # input tensor was column parallel in tp.
                 qdq_x = (
                     quant_x
                     / bnt
-                    * scales[rank * scales.shape[0] // world_size : (rank + 1) * scales.shape[0] // world_size]
+                    * scales[tp_rank * scales.shape[0] // tp_degree : (tp_rank + 1) * scales.shape[0] // tp_degree]
                     .unsqueeze(0)
                     .expand(quant_x.shape)
                 )
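
The renamed tp_rank / tp_degree arguments only matter on the dequantization path, where a tensor-parallel shard has to pick out its own slice of the channel-wise scales. Below is a minimal illustrative sketch of that flow; it is not taken from the diff, and the shapes, variable names, and the assumption of numpy inputs with use_pd=False are mine:

import numpy as np

from paddlenlp.quantization.checkpoint_quantization_utils import qdq_weight

# Hypothetical full (unsharded) 2-D weight; the shape is illustrative only.
w = np.random.randn(64, 32).astype("float32")

# Quantize the full tensor: returns the quantized values plus per-channel scales.
quant_w, scales = qdq_weight(w, quant_bit=8)

# Dequantize the shard held by tensor-parallel rank 0 of a 2-way group.
# The shard's last dim no longer matches the scales (the column-parallel case),
# so qdq_weight applies only rank 0's slice of the scales.
shard = quant_w[:, :16]
shard_fp32, _ = qdq_weight(shard, quant_bit=8, scales=scales, dequant=True, tp_rank=0, tp_degree=2)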

paddlenlp/quantization/unified_checkpoint_quantization.py

Lines changed: 13 additions & 34 deletions
@@ -96,11 +96,11 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
             ratio_min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN
             ratio_max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX
             m1_scale_key = quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME + SYMMETRY_QUANT_SCALE
-            m1_codebook = scale_dict[m1_scale_key]
+            m1_scales = scale_dict[m1_scale_key]
             ratio_mins, ratio_maxs = scale_dict[ratio_min_scale_key], scale_dict[ratio_max_scale_key]
             m1_weight = group_wise_quant_dequant(
                 m1_quant,
-                mins=m1_codebook,
+                mins=m1_scales,
                 maxs=None,
                 quant_bits=4,
                 quant=False,
@@ -134,16 +134,11 @@ def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async
     quant = True
     del_key = []
     if quant and state_dict_type == "optimizer_weight":
-        codebook_dict = {}
+        scales_dict = {}
         opt_keys = state_dict.keys()
-        if not async_save:
-            all_bits, quant_bits = paddle.to_tensor(0.0), paddle.to_tensor(0.0)
         for k in opt_keys:
             momentum1 = k.endswith(MOMENT1_KEYNAME)
             momentum2 = k.endswith(MOMENT2_KEYNAME)
-            k_size = state_dict[k].size
-            if not async_save and (momentum1 or momentum2):
-                all_bits += k_size * 4

             quant_weight = None

@@ -153,28 +148,29 @@ def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async
                     # m1: m1_quant_weight, m2: ratio
                     m1_key = k.split("/")[0] + "/" + MOMENT1_KEYNAME
                     ratio = cal_ratio(state_dict[m1_key], state_dict[k])
-                    m1_quant, codebook = qdq_weight(state_dict[m1_key], quant_bit=8)
+                    m1_quant, scales = qdq_weight(state_dict[m1_key], quant_bit=8)
                     quant_weight, mins, maxs = asymmetry_qdq_weight(ratio, quant_bit=8)
                     state_dict[m1_key] = m1_quant
-                    codebook_dict[m1_key + SYMMETRY_QUANT_SCALE] = codebook
-                    codebook_dict[k + ASYMMETRY_QUANT_SCALE_MIN] = mins
-                    codebook_dict[k + ASYMMETRY_QUANT_SCALE_MAX] = maxs
+                    scales_dict[m1_key + SYMMETRY_QUANT_SCALE] = scales
+                    scales_dict[k + ASYMMETRY_QUANT_SCALE_MIN] = mins
+                    scales_dict[k + ASYMMETRY_QUANT_SCALE_MAX] = maxs
                 elif not momentum1:
                     quant_weight = state_dict[k]
             elif ckpt_quant_stage == "O2":
                 # m1: bw-wint4, 1/(sqrt(m2)+eps): bw-wint4
                 if momentum2:
+                    # skip norm-like parameters
                     if len(state_dict[k].shape) < 2:
                         continue
                     # m1: m1_quant_weight, m2: ratio
                     m1_key = k.split("/")[0] + "/" + MOMENT1_KEYNAME
                     ratio = cal_ratio(state_dict[m1_key], state_dict[k])
-                    m1_quant, m1_codebook = group_wise_quant_dequant(state_dict[m1_key], quant_bits=4, symmetry=True)
+                    m1_quant, m1_scales = group_wise_quant_dequant(state_dict[m1_key], quant_bits=4, symmetry=True)
                     quant_weight, r_mins, r_maxs = group_wise_quant_dequant(ratio, quant_bits=4)
                     quant_weight = merge_int4(m1_quant, quant_weight)
-                    codebook_dict[m1_key + SYMMETRY_QUANT_SCALE] = m1_codebook
-                    codebook_dict[k + ASYMMETRY_QUANT_SCALE_MIN] = r_mins
-                    codebook_dict[k + ASYMMETRY_QUANT_SCALE_MAX] = r_maxs
+                    scales_dict[m1_key + SYMMETRY_QUANT_SCALE] = m1_scales
+                    scales_dict[k + ASYMMETRY_QUANT_SCALE_MIN] = r_mins
+                    scales_dict[k + ASYMMETRY_QUANT_SCALE_MAX] = r_maxs
                     del_key.append(m1_key)
                 elif not momentum1:
                     quant_weight = state_dict[k]
@@ -185,23 +181,6 @@ def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async
        for k in del_key:
            state_dict.pop(k, None)

-        state_dict.update(codebook_dict)
-
-        if not async_save:
-            if paddle.distributed.get_world_size() > 1:
-                dist.all_reduce(all_bits)
-                dist.all_reduce(quant_bits)
-
-            model_numel = all_bits / 4
-            all_bits = model_numel * 7.0
-            quant_bits_mw = quant_bits + model_numel * 6.0
-            quant_bits = quant_bits + model_numel * 2.0
-            logger.info(
-                f"all bits: {all_bits.item()}, quant bits: {quant_bits.item()}, quant bits mw: {quant_bits_mw.item()}"
-            )
-            logger.info(f"quant ratio (w/o Master Weight): {(all_bits.item() - quant_bits.item()) / all_bits.item()}")
-            logger.info(
-                f"quant ratio (w/ Master Weight): {(all_bits.item() - quant_bits_mw.item()) / all_bits.item()}"
-            )
+        state_dict.update(scales_dict)

    return state_dict