Commit 4947a8c

refactor code
1 parent ffd0823 commit 4947a8c

5 files changed: +124 −115 lines

paddlenlp/trainer/unified_checkpoint/quantization.py renamed to paddlenlp/quantization/unified_checkpoint_quantization.py

Lines changed: 96 additions & 0 deletions
@@ -14,13 +14,15 @@
 import paddle
 import paddle.distributed as dist
+from paddle.distributed import fleet

 from paddlenlp.quantization.checkpoint_quantization_utils import (
     asymmetry_qdq_weight,
     cal_ratio,
     group_wise_quant_dequant,
     merge_int4,
     qdq_weight,
+    split_int8,
 )
 from paddlenlp.utils.env import (
     ASYMMETRY_QUANT_SCALE_MAX,
@@ -32,6 +34,100 @@
 from paddlenlp.utils.log import logger


+def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
+    rank, world_size = -1, 1
+    if paddle.distributed.get_world_size() > 1:
+        hcg = fleet.get_hybrid_communicate_group()
+        tp_group = hcg.get_model_parallel_group()
+        rank, world_size = tp_group.rank, tp_group.nranks
+
+    if ckpt_quant_stage == "O1":
+        # set eps
+        eps = 1e-8
+        for quant_key in state_dict.keys():
+            is_moment1 = MOMENT1_KEYNAME in quant_key
+            is_moment2 = MOMENT2_KEYNAME in quant_key
+            if is_moment1:
+                # dequant m1
+                scale_key = quant_key + SYMMETRY_QUANT_SCALE
+                weight = state_dict[quant_key]
+                scales = scale_dict[scale_key]
+                weight, _ = qdq_weight(
+                    weight,
+                    scales=scales,
+                    quant_bit=8,
+                    dequant=True,
+                    rank=rank,
+                    world_size=world_size,
+                    use_pd=True,
+                )
+                state_dict[quant_key] = weight
+            elif is_moment2:
+                # dequant ratio
+                weight = state_dict[quant_key]
+                min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN
+                max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX
+                mins, maxs = scale_dict[min_scale_key], scale_dict[max_scale_key]
+                weight, _ = asymmetry_qdq_weight(
+                    weight,
+                    mins=mins,
+                    maxs=maxs,
+                    quant_bit=8,
+                    dequant=True,
+                    rank=rank,
+                    world_size=world_size,
+                    use_pd=True,
+                )
+                # cal m2
+                weight = paddle.square(1.0 / weight - eps)
+                state_dict[quant_key] = weight
+    elif ckpt_quant_stage == "O2":
+        # set eps
+        eps = 1e-8
+        m1_state_dict = {}
+        for quant_key in state_dict.keys():
+            if state_dict[quant_key].dtype != paddle.int8:
+                logger.info(f"{quant_key} skip.")
+                continue
+            # split int8
+            weight = state_dict[quant_key]
+            m1_quant, ratio_quant = split_int8(weight.numpy())
+            # dequant ratio
+            ratio_min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN
+            ratio_max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX
+            m1_scale_key = quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME + SYMMETRY_QUANT_SCALE
+            m1_codebook = scale_dict[m1_scale_key]
+            ratio_mins, ratio_maxs = scale_dict[ratio_min_scale_key], scale_dict[ratio_max_scale_key]
+            m1_weight = group_wise_quant_dequant(
+                m1_quant,
+                mins=m1_codebook,
+                maxs=None,
+                quant_bits=4,
+                quant=False,
+                rank=rank,
+                world_size=world_size,
+                use_pd=True,
+                symmetry=True,
+            )
+            ratio_weight = group_wise_quant_dequant(
+                ratio_quant,
+                mins=ratio_mins,
+                maxs=ratio_maxs,
+                quant_bits=4,
+                quant=False,
+                rank=rank,
+                world_size=world_size,
+                use_pd=True,
+            )
+
+            ratio_weight = paddle.square(1.0 / ratio_weight - eps)
+            state_dict[quant_key] = ratio_weight
+            m1_state_dict[quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME] = m1_weight
+        state_dict.update(m1_state_dict)
+
+    return state_dict
+
+
 def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async_save=False):
     quant = False
     if ckpt_quant_stage != "O0":
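
Both branches recover Adam's second moment with `paddle.square(1.0 / weight - eps)`, which implies the checkpoint stores a transformed ratio rather than m2 itself, presumably produced by the imported `cal_ratio` helper (not shown in this diff). A minimal NumPy sketch of that round-trip, assuming the forward transform is `ratio = 1 / (sqrt(m2) + eps)`:

    import numpy as np

    eps = 1e-8
    m2 = np.abs(np.random.randn(4, 4)).astype("float64")  # Adam moment2 is non-negative

    ratio = 1.0 / (np.sqrt(m2) + eps)           # assumed forward transform (cal_ratio, not shown)
    m2_restored = np.square(1.0 / ratio - eps)  # inverse used by dequant_unified_optimizer above

    print(np.abs(m2 - m2_restored).max())       # ~0, up to floating-point error

Squaring at the end also guarantees the reconstructed moment2 stays non-negative.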

paddlenlp/trainer/unified_checkpoint/async_handler.py

Lines changed: 4 additions & 1 deletion
@@ -27,7 +27,10 @@
 if is_safetensors_available():
     from safetensors.numpy import save_file as safe_save_file

-from .quantization import quant_unified_optimizer
+from paddlenlp.quantization.unified_checkpoint_quantization import (
+    quant_unified_optimizer,
+)

 from .shared_memory_utils import (
     _read_state_dict_from_shm,
     _traverse_copy_to_shm,

paddlenlp/transformers/model_utils.py

Lines changed: 20 additions & 110 deletions
@@ -43,7 +43,6 @@
 )
 from huggingface_hub.utils import EntryNotFoundError
 from paddle import Tensor
-from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel.parallel_layers import (
     PipelineLayer,
     SharedLayerDesc,
@@ -59,8 +58,6 @@
     ASYMMETRY_QUANT_SCALE_MIN,
     CONFIG_NAME,
     LEGACY_CONFIG_NAME,
-    MOMENT1_KEYNAME,
-    MOMENT2_KEYNAME,
     PADDLE_WEIGHTS_INDEX_NAME,
     PADDLE_WEIGHTS_NAME,
     PYTORCH_WEIGHTS_INDEX_NAME,
@@ -74,12 +71,7 @@
 from paddlenlp.utils.log import logger

 from ..generation import GenerationConfig, GenerationMixin
-from ..quantization.checkpoint_quantization_utils import (
-    asymmetry_qdq_weight,
-    group_wise_quant_dequant,
-    qdq_weight,
-    split_int8,
-)
+from ..quantization.unified_checkpoint_quantization import dequant_unified_optimizer
 from ..utils import device_guard
 from ..utils.download import resolve_file_path
 from .configuration_utils import PretrainedConfig
@@ -332,100 +324,6 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
     return last_dtype


-def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
-    rank, world_size = -1, 1
-    if paddle.distributed.get_world_size() > 1:
-        hcg = fleet.get_hybrid_communicate_group()
-        tp_group = hcg.get_model_parallel_group()
-        rank, world_size = tp_group.rank, tp_group.nranks
-
-    if ckpt_quant_stage == "O1":
-        # set eps
-        eps = 1e-8
-        for quant_key in state_dict.keys():
-            is_moment1 = MOMENT1_KEYNAME in quant_key
-            is_moment2 = MOMENT2_KEYNAME in quant_key
-            if is_moment1:
-                # dequant m1
-                scale_key = quant_key + SYMMETRY_QUANT_SCALE
-                weight = state_dict[quant_key]
-                scales = scale_dict[scale_key]
-                weight, _ = qdq_weight(
-                    weight,
-                    scales=scales,
-                    quant_bit=8,
-                    dequant=True,
-                    rank=rank,
-                    world_size=world_size,
-                    use_pd=True,
-                )
-                state_dict[quant_key] = weight
-            elif is_moment2:
-                # dequant ratio
-                weight = state_dict[quant_key]
-                min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN
-                max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX
-                mins, maxs = scale_dict[min_scale_key], scale_dict[max_scale_key]
-                weight, _ = asymmetry_qdq_weight(
-                    weight,
-                    mins=mins,
-                    maxs=maxs,
-                    quant_bit=8,
-                    dequant=True,
-                    rank=rank,
-                    world_size=world_size,
-                    use_pd=True,
-                )
-                # cal m2
-                weight = paddle.square(1.0 / weight - eps)
-                state_dict[quant_key] = weight
-    elif ckpt_quant_stage == "O2":
-        # set eps
-        eps = 1e-8
-        m1_state_dict = {}
-        for quant_key in state_dict.keys():
-            if state_dict[quant_key].dtype != paddle.int8:
-                logger.info(f"{quant_key} skip.")
-                continue
-            # split int8
-            weight = state_dict[quant_key]
-            m1_quant, ratio_quant = split_int8(weight.numpy())
-            # dequant ratio
-            ratio_min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN
-            ratio_max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX
-            m1_scale_key = quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME + SYMMETRY_QUANT_SCALE
-            m1_codebook = scale_dict[m1_scale_key]
-            ratio_mins, ratio_maxs = scale_dict[ratio_min_scale_key], scale_dict[ratio_max_scale_key]
-            m1_weight = group_wise_quant_dequant(
-                m1_quant,
-                mins=m1_codebook,
-                maxs=None,
-                quant_bits=4,
-                quant=False,
-                rank=rank,
-                world_size=world_size,
-                use_pd=True,
-                symmetry=True,
-            )
-            ratio_weight = group_wise_quant_dequant(
-                ratio_quant,
-                mins=ratio_mins,
-                maxs=ratio_maxs,
-                quant_bits=4,
-                quant=False,
-                rank=rank,
-                world_size=world_size,
-                use_pd=True,
-            )
-
-            ratio_weight = paddle.square(1.0 / ratio_weight - eps)
-            state_dict[quant_key] = ratio_weight
-            m1_state_dict[quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME] = m1_weight
-        state_dict.update(m1_state_dict)
-
-    return state_dict
-
-
 def _split_keys_evenly(keys: list, n: int) -> list:
     """Split a list into n lists with an equal number of elements.

@@ -471,9 +369,18 @@ def _load_part_state_dict(
     scale_dict = {}
     with safe_open(checkpoint_file, framework="np") as f:
         for key in keys:
-            # non merge ckpt loading dont have filter key.
-            if key.endswith(SYMMETRY_QUANT_SCALE) or (fliter_dict_keys is not None and key not in fliter_dict_keys):
+            # 1. non-merge ckpt loading doesn't have filter keys.
+            # 2. merge ckpt skips quant scales via `fliter_dict_keys`.
+            if (
+                key.endswith(SYMMETRY_QUANT_SCALE)
+                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
+                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
+            ):
                 continue
+
+            if fliter_dict_keys is not None and key not in fliter_dict_keys:
+                continue
+
             py_safe_slice_ = f.get_slice(key)
             if key in tensor_parallel_split_mapping:
                 weight = tensor_parallel_split_mapping[key](py_safe_slice_)
@@ -485,7 +392,11 @@
                 weight = weight._copy_to(paddle.framework._current_expected_place(), False)
             part_state_dict[key] = weight
         for key in keys:
-            if key.endswith(SYMMETRY_QUANT_SCALE):
+            if (
+                key.endswith(SYMMETRY_QUANT_SCALE)
+                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
+                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
+            ):
                 scale = f.get_tensor(key)
                 with device_guard():
                     scale = paddle.Tensor(scale, zero_copy=True)
@@ -504,9 +415,6 @@ def load_state_dict(
     """
     Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
     """
-    quant = False
-    if ckpt_quant_stage != "O0":
-        quant = "optimizer" in checkpoint_file

     if tensor_parallel_split_mapping is None:
         tensor_parallel_split_mapping = {}
@@ -562,7 +470,9 @@
         with device_guard():
             state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)

-    if quant:
+    if len(scale_dict) != 0:
+        if ckpt_quant_stage == "O0":
+            raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
         state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict)

     return state_dict
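
The same three-suffix check now appears twice in `_load_part_state_dict`; a small hypothetical helper (not part of this commit) expresses the predicate once:

    from paddlenlp.utils.env import (
        ASYMMETRY_QUANT_SCALE_MAX,
        ASYMMETRY_QUANT_SCALE_MIN,
        SYMMETRY_QUANT_SCALE,
    )


    def is_quant_scale_key(key: str) -> bool:
        """True for checkpoint keys that hold quantization scales rather than weights."""
        return key.endswith((SYMMETRY_QUANT_SCALE, ASYMMETRY_QUANT_SCALE_MIN, ASYMMETRY_QUANT_SCALE_MAX))

Since `str.endswith` accepts a tuple of suffixes, the combined check stays a single call.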

paddlenlp/utils/env.py

Lines changed: 3 additions & 3 deletions
@@ -117,6 +117,6 @@ def _get_bool_env(env_key: str, default_value: str) -> bool:
 MOMENT2_KEYNAME = "moment2_0"
 BETA1_KEYNAME = "beta1_pow_acc_0"
 BETA2_KEYNAME = "beta2_pow_acc_0"
-SYMMETRY_QUANT_SCALE = "_codebook"
-ASYMMETRY_QUANT_SCALE_MIN = "_min_codebook"
-ASYMMETRY_QUANT_SCALE_MAX = "_max_codebook"
+SYMMETRY_QUANT_SCALE = "@scales"
+ASYMMETRY_QUANT_SCALE_MIN = "@min_scales"
+ASYMMETRY_QUANT_SCALE_MAX = "@max_scales"
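
For illustration, these suffixes compose with optimizer-state key names to form the scale keys that `dequant_unified_optimizer` looks up; the layer keys below are hypothetical:

    SYMMETRY_QUANT_SCALE = "@scales"
    ASYMMETRY_QUANT_SCALE_MIN = "@min_scales"
    ASYMMETRY_QUANT_SCALE_MAX = "@max_scales"

    m1_key = "linear_0.w_0_moment1_0"  # hypothetical Adam moment1 state key
    m2_key = "linear_0.w_0_moment2_0"  # hypothetical Adam moment2 state key

    print(m1_key + SYMMETRY_QUANT_SCALE)       # linear_0.w_0_moment1_0@scales
    print(m2_key + ASYMMETRY_QUANT_SCALE_MIN)  # linear_0.w_0_moment2_0@min_scales
    print(m2_key + ASYMMETRY_QUANT_SCALE_MAX)  # linear_0.w_0_moment2_0@max_scales

Moment1 uses the symmetric suffix and moment2 the asymmetric pair, matching the O1 branch above; the `@` prefix also keeps scale keys visually distinct from ordinary parameter names.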

tests/fixtures/llm/finetune.yaml

Lines changed: 1 addition & 1 deletion
@@ -89,7 +89,7 @@ ckpt_quant:
   do_eval: true
   use_flash_attention: true
   unified_checkpoint: true
-  unified_checkpoint_config: "remove_master_weight"
+  unified_checkpoint_config: "async_save remove_master_weight"
   disable_tqdm: true
   load_best_model_at_end: true
   eval_with_do_generation: false
