
Commit 4069f22

fix lora (#7824)
1 parent abb0d3c commit 4069f22

3 files changed, +18 -7 lines changed

paddlenlp/peft/lora/lora_model.py

Lines changed: 2 additions & 2 deletions
@@ -230,12 +230,12 @@ def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = Fal

         if self.is_pipelinemodel:
             self.model._single_to_pp_mapping = None
-        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
             )
-        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degre > 1:
+        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
             merge_tensor_parallel = False
             logger.warning(
                 "Pipeline parallism does not support merge_tensor_parallel. Set merge_tensor_parallel to False."

paddlenlp/quantization/quantization_utils.py

Lines changed: 5 additions & 2 deletions
@@ -109,9 +109,12 @@ def convert_to_quantize_state_dict_with_check(state_dict, quantization_linear_li
                 raise ValueError(
                     f"{quant_weight_name} should be {paddle.int8} in state_dict but received dtype {state_dict[quant_weight_name].dtype}."
                 )
-            if state_dict[quant_scale_name].dtype != paddle.float32:
+            if (
+                state_dict[quant_scale_name].dtype != paddle.float16
+                and state_dict[quant_scale_name].dtype != paddle.bfloat16
+            ):
                 raise ValueError(
-                    f"{quant_scale_name} should be {paddle.float32} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
+                    f"{quant_scale_name} should be {paddle.float16} or {paddle.bfloat16} in state_dict but received dtype {state_dict[quant_scale_name].dtype}."
                 )
         elif weight_name in state_dict:
             target_weight = state_dict.pop(weight_name).cast(dtype)
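The intent of the relaxed check can be shown with a small stand-alone sketch. This is not the library function: `check_quant_scale_dtype` is a hypothetical helper and dtypes are represented as plain strings rather than paddle dtype objects; only the accepted set (float16/bfloat16 instead of float32) comes from the diff above.

# Stand-alone sketch of the relaxed dtype check (hypothetical helper; dtypes as strings).
ALLOWED_QUANT_SCALE_DTYPES = {"float16", "bfloat16"}  # before this commit, only float32 was accepted


def check_quant_scale_dtype(quant_scale_name: str, dtype: str) -> None:
    if dtype not in ALLOWED_QUANT_SCALE_DTYPES:
        raise ValueError(
            f"{quant_scale_name} should be float16 or bfloat16 in state_dict but received dtype {dtype}."
        )


check_quant_scale_dtype("linear_0.quant_scale", "bfloat16")  # accepted
try:
    check_quant_scale_dtype("linear_0.quant_scale", "float32")  # now rejected
except ValueError as err:
    print(err)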

paddlenlp/transformers/model_utils.py

Lines changed: 11 additions & 3 deletions
@@ -1776,9 +1776,15 @@ def _load_pretrained_model(
                 loaded_keys, quantization_linear_list, config.quantization_config
             )
             if keep_in_fp32_modules is None:
-                keep_in_fp32_modules = ["quant_scale"]
+                keep_in_fp32_modules = (
+                    ["quant_scale"] if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"] else None
+                )
             else:
-                keep_in_fp32_modules += ["quant_scale"]
+                keep_in_fp32_modules = (
+                    keep_in_fp32_modules + ["quant_scale"]
+                    if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
+                    else keep_in_fp32_modules
+                )

         missing_keys = list(set(expected_keys) - set(loaded_keys))
         unexpected_keys = list(set(loaded_keys) - set(expected_keys))
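This hunk only changes which modules are forced to fp32: "quant_scale" is now added only for the 4-bit algorithms. A minimal sketch of the new decision, with a hypothetical helper name and a plain string standing in for `config.quantization_config.weight_quantize_algo`, follows.

# Sketch of the adjusted keep_in_fp32_modules logic (hypothetical helper; only the
# "nf4"/"fp4" condition and the "quant_scale" entry come from the diff above).
from typing import List, Optional


def resolve_keep_in_fp32_modules(
    keep_in_fp32_modules: Optional[List[str]], weight_quantize_algo: str
) -> Optional[List[str]]:
    if weight_quantize_algo not in ["nf4", "fp4"]:
        # Non-4-bit algorithms no longer force quant_scale to fp32.
        return keep_in_fp32_modules
    if keep_in_fp32_modules is None:
        return ["quant_scale"]
    return keep_in_fp32_modules + ["quant_scale"]


print(resolve_keep_in_fp32_modules(None, "weight_only_int8"))  # None (example non-4-bit algorithm)
print(resolve_keep_in_fp32_modules(["norm"], "nf4"))           # ['norm', 'quant_scale']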
@@ -2200,7 +2206,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
         logger.info("Loaded weights file from disk, setting weights to model.")

         # Check if `_keep_in_fp32_modules` is not None
-        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and dtype == "float16"
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            dtype == "float16" or dtype == "bfloat16"
+        )

         if is_sharded:
             loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]
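The second hunk only widens the dtype condition so that bfloat16 models also take the fp32-keep path. A short sketch with a hypothetical helper name (the dtype strings come from the diff above):

# Sketch of the widened condition (hypothetical helper name).
from typing import List, Optional


def should_use_keep_in_fp32_modules(keep_in_fp32_modules: Optional[List[str]], dtype: str) -> bool:
    # bfloat16 models now also go through the fp32-keep path, not just float16 ones.
    return keep_in_fp32_modules is not None and dtype in ("float16", "bfloat16")


print(should_use_keep_in_fp32_modules(["quant_scale"], "bfloat16"))  # True (was False before this commit)
print(should_use_keep_in_fp32_modules(["quant_scale"], "float32"))   # False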
