
Commit b312634

1 parent d4acbfc commit b312634

3 files changed: +18 -7 lines changed


paddlenlp/peft/lora/lora_model.py

Lines changed: 2 additions & 2 deletions

@@ -231,12 +231,12 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal

         if self.is_pipelinemodel:
             self.model._single_to_pp_mapping = None
-        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
             )
-        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Pipeline parallism does not support merge_tensor_parallel. Set merge_tensor_parallel to False."

paddlenlp/quantization/quantization_utils.py

Lines changed: 5 additions & 2 deletions

@@ -109,9 +109,12 @@ def convert_to_quantize_state_dict_with_check(state_dict, quantization_linear_li
                 raise ValueError(
                     f"{quant_weight_name} should be {paddle.int8} in state_dict but received dtype {state_dict[quant_weight_name].dtype}."
                 )
-            if state_dict[quant_scale_name].dtype != paddle.float32:
+            if (
+                state_dict[quant_scale_name].dtype != paddle.float16
+                and state_dict[quant_scale_name].dtype != paddle.bfloat16
+            ):
                 raise ValueError(
-                    f"{quant_scale_name} should be {paddle.float32} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
+                    f"{quant_scale_name} should be {paddle.float16} or {paddle.bfloat16} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
                 )
         elif weight_name in state_dict:
             target_weight = state_dict.pop(weight_name).cast(dtype)
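A minimal sketch of the relaxed dtype check, with the surrounding loop and key lookups elided (the helper name is assumed, not the real PaddleNLP function): the quantized weight must still arrive as int8, while its scale is now expected in half precision, float16 or bfloat16, instead of float32.

# Sketch only: assumed helper that mirrors the per-entry dtype checks in the hunk above.
import paddle


def check_quantized_entry(state_dict, quant_weight_name, quant_scale_name):
    # Quantized weights must be stored as int8.
    if state_dict[quant_weight_name].dtype != paddle.int8:
        raise ValueError(
            f"{quant_weight_name} should be {paddle.int8} in state_dict "
            f"but received dtype {state_dict[quant_weight_name].dtype}."
        )
    # Scales are now accepted in half precision rather than float32.
    if state_dict[quant_scale_name].dtype not in (paddle.float16, paddle.bfloat16):
        raise ValueError(
            f"{quant_scale_name} should be {paddle.float16} or {paddle.bfloat16} in state_dict "
            f"but received dtype {state_dict[quant_scale_name].dtype}."
        )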

paddlenlp/transformers/model_utils.py

Lines changed: 11 additions & 3 deletions

@@ -1759,9 +1759,15 @@ def _load_pretrained_model(
                 loaded_keys, quantization_linear_list, config.quantization_config
             )
             if keep_in_fp32_modules is None:
-                keep_in_fp32_modules = ["quant_scale"]
+                keep_in_fp32_modules = (
+                    ["quant_scale"] if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"] else None
+                )
             else:
-                keep_in_fp32_modules += ["quant_scale"]
+                keep_in_fp32_modules = (
+                    keep_in_fp32_modules + ["quant_scale"]
+                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
+                    else keep_in_fp32_modules
+                )

         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
@@ -2173,7 +2179,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         logger.info("Loaded weights file from disk, setting weights to model.")

         # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and dtype == "float16"
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            dtype == "float16" or dtype == "bfloat16"
+        )

         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
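Taken together, these two hunks restrict the fp32 pinning of quant_scale to the nf4/fp4 weight-quantization algorithms and widen the keep-in-fp32 trigger so that it also fires when a model is loaded in bfloat16, not only float16. A condensed sketch of the selection logic (function names are assumed, for illustration only):

# Sketch only, not PaddleNLP API; condenses the two changed branches above.
def extend_keep_in_fp32_modules(keep_in_fp32_modules, weight_quantize_algo):
    """Pin quant_scale to fp32 only for nf4/fp4; otherwise leave the list (or None) untouched."""
    if weight_quantize_algo in ["nf4", "fp4"]:
        return (keep_in_fp32_modules or []) + ["quant_scale"]
    return keep_in_fp32_modules


def use_keep_in_fp32(keep_in_fp32_modules, dtype):
    """The fp32-keep path now applies when loading in float16 or bfloat16."""
    return (keep_in_fp32_modules is not None) and dtype in ("float16", "bfloat16")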
