 )
 from huggingface_hub.utils import EntryNotFoundError
 from paddle import Tensor
-from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel.parallel_layers import (
     PipelineLayer,
     SharedLayerDesc,
@@ -59,8 +58,6 @@
     ASYMMETRY_QUANT_SCALE_MIN,
     CONFIG_NAME,
     LEGACY_CONFIG_NAME,
-    MOMENT1_KEYNAME,
-    MOMENT2_KEYNAME,
     PADDLE_WEIGHTS_INDEX_NAME,
     PADDLE_WEIGHTS_NAME,
     PYTORCH_WEIGHTS_INDEX_NAME,
@@ -74,12 +71,7 @@
 from paddlenlp.utils.log import logger
 
 from ..generation import GenerationConfig, GenerationMixin
-from ..quantization.checkpoint_quantization_utils import (
-    asymmetry_qdq_weight,
-    group_wise_quant_dequant,
-    qdq_weight,
-    split_int8,
-)
+from ..quantization.unified_checkpoint_quantization import dequant_unified_optimizer
 from ..utils import device_guard
 from ..utils.download import resolve_file_path
 from .configuration_utils import PretrainedConfig
@@ -332,100 +324,6 @@ def get_parameter_dtype(parameter: nn.Layer) -> paddle.dtype:
     return last_dtype
 
 
-def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
-    rank, world_size = -1, 1
-    if paddle.distributed.get_world_size() > 1:
-        hcg = fleet.get_hybrid_communicate_group()
-        tp_group = hcg.get_model_parallel_group()
-        rank, world_size = tp_group.rank, tp_group.nranks
-
-    if ckpt_quant_stage == "O1":
-        # set eps
-        eps = 1e-8
-        for quant_key in state_dict.keys():
-            is_moment1 = MOMENT1_KEYNAME in quant_key
-            is_moment2 = MOMENT2_KEYNAME in quant_key
-            if is_moment1:
-                # dequant m1
-                scale_key = quant_key + SYMMETRY_QUANT_SCALE
-                weight = state_dict[quant_key]
-                scales = scale_dict[scale_key]
-                weight, _ = qdq_weight(
-                    weight,
-                    scales=scales,
-                    quant_bit=8,
-                    dequant=True,
-                    rank=rank,
-                    world_size=world_size,
-                    use_pd=True,
-                )
-                state_dict[quant_key] = weight
-            elif is_moment2:
-                # dequant ratio
-                weight = state_dict[quant_key]
-                min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN
-                max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX
-                mins, maxs = scale_dict[min_scale_key], scale_dict[max_scale_key]
-                weight, _ = asymmetry_qdq_weight(
-                    weight,
-                    mins=mins,
-                    maxs=maxs,
-                    quant_bit=8,
-                    dequant=True,
-                    rank=rank,
-                    world_size=world_size,
-                    use_pd=True,
-                )
-                # cal m2
-                weight = paddle.square(1.0 / weight - eps)
-                state_dict[quant_key] = weight
-    elif ckpt_quant_stage == "O2":
-        # set eps
-        eps = 1e-8
-        m1_state_dict = {}
-        for quant_key in state_dict.keys():
-            if state_dict[quant_key].dtype != paddle.int8:
-                logger.info(f"{quant_key} skip.")
-                continue
-            # split int8
-            weight = state_dict[quant_key]
-            m1_quant, ratio_quant = split_int8(weight.numpy())
-            # dequant ratio
-            ratio_min_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MIN
-            ratio_max_scale_key = quant_key + ASYMMETRY_QUANT_SCALE_MAX
-            m1_scale_key = quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME + SYMMETRY_QUANT_SCALE
-            m1_codebook = scale_dict[m1_scale_key]
-            ratio_mins, ratio_maxs = scale_dict[ratio_min_scale_key], scale_dict[ratio_max_scale_key]
-            m1_weight = group_wise_quant_dequant(
-                m1_quant,
-                mins=m1_codebook,
-                maxs=None,
-                quant_bits=4,
-                quant=False,
-                rank=rank,
-                world_size=world_size,
-                use_pd=True,
-                symmetry=True,
-            )
-            ratio_weight = group_wise_quant_dequant(
-                ratio_quant,
-                mins=ratio_mins,
-                maxs=ratio_maxs,
-                quant_bits=4,
-                quant=False,
-                rank=rank,
-                world_size=world_size,
-                use_pd=True,
-            )
-
-            ratio_weight = paddle.square(1.0 / ratio_weight - eps)
-            state_dict[quant_key] = ratio_weight
-            m1_state_dict[quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME] = m1_weight
-        state_dict.update(m1_state_dict)
-
-    return state_dict
-
-
 def _split_keys_evenly(keys: list, n: int) -> list:
     """Split a list into n lists with an equal number of elements.
 
@@ -471,9 +369,18 @@ def _load_part_state_dict(
     scale_dict = {}
     with safe_open(checkpoint_file, framework="np") as f:
         for key in keys:
-            # non merge ckpt loading dont have filter key.
-            if key.endswith(SYMMETRY_QUANT_SCALE) or (fliter_dict_keys is not None and key not in fliter_dict_keys):
+            # 1. non-merge ckpt loading dont have filter key.
+            # 2. merge ckpt will skip quant scale by `fliter_dict_keys`
+            if (
+                key.endswith(SYMMETRY_QUANT_SCALE)
+                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
+                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
+            ):
                 continue
+
+            if fliter_dict_keys is not None and key not in fliter_dict_keys:
+                continue
+
             py_safe_slice_ = f.get_slice(key)
             if key in tensor_parallel_split_mapping:
                 weight = tensor_parallel_split_mapping[key](py_safe_slice_)
@@ -485,7 +392,11 @@ def _load_part_state_dict(
                 weight = weight._copy_to(paddle.framework._current_expected_place(), False)
             part_state_dict[key] = weight
         for key in keys:
-            if key.endswith(SYMMETRY_QUANT_SCALE):
+            if (
+                key.endswith(SYMMETRY_QUANT_SCALE)
+                or key.endswith(ASYMMETRY_QUANT_SCALE_MIN)
+                or key.endswith(ASYMMETRY_QUANT_SCALE_MAX)
+            ):
                 scale = f.get_tensor(key)
                 with device_guard():
                     scale = paddle.Tensor(scale, zero_copy=True)
@@ -504,9 +415,6 @@ def load_state_dict(
504415 """
505416 Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
506417 """
507- quant = False
508- if ckpt_quant_stage != "O0" :
509- quant = "optimizer" in checkpoint_file
510418
511419 if tensor_parallel_split_mapping is None :
512420 tensor_parallel_split_mapping = {}
@@ -562,7 +470,9 @@ def load_state_dict(
                     with device_guard():
                         state_dict[k] = paddle.Tensor(state_dict.pop(k), zero_copy=True)
 
-            if quant:
+            if len(scale_dict) != 0:
+                if ckpt_quant_stage == "O0":
+                    raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
                 state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict)
 
         return state_dict