Skip to content

Commit a6909dd

Browse files
lixcliMangodadada
authored and committed
[Inference] update fakequant support (PaddlePaddle#9047)
* 1. add a8w8(fp8) a8w8c8(int8) quant_type support 2. add llama3.1 and qwen2 ptq config 3. update quantization.md * fix load_quant_model bug * fix load quant bug * update ll/README.md * remove useless code * update quant observer config * resolve wrong modify * fix prepare_qconfig * remove unuse files
1 parent c1de4c9 commit a6909dd

File tree

13 files changed

+34
-274
lines changed

13 files changed

+34
-274
lines changed

llm/README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,13 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo
224224

225225
```shell
226226
# PTQ 量化启动命令参考
227-
python run_finetune.py ./config/llama/ptq_argument.json
227+
python run_finetune.py ./config/llama/ptq_argument.json
228228

229229
# GPTQ 量化启动命令参考
230-
python run_finetune.py ./config/llama/ptq_argument.json
230+
python run_finetune.py ./config/llama/ptq_argument.json
231231

232232
# W8A8C8(INT)量化启动命令参考
233-
python run_finetune.py ./config/llama/ptq_c8_argument.json
233+
python run_finetune.py ./config/llama/ptq_c8_argument.json
234234

235235
# W8A8(FP8)量化启动命令参考
236236
python run_finetune.py ./config/llama/fp8_ptq_argument.json

llm/docs/quantization.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# 大模型量化教程
1+
# 大模型量化教程
22

33
## 1.算法介绍
44

@@ -111,8 +111,8 @@ python run_finetune.py ./config/llama/ceval_quant_argument.json
111111
- `use_fp8`: 是否使用 FP8 量化,默认为空字符串。输入`"WA"`(不区分大小写)则将权重和激活的8位量化转换为 FP8量化。
112112
- `fp8_type`: FP8量化类型,长度应与`use_fp8`相同。默认为`["e4m3","e4m3"]`
113113
- `do_ptq`: 是否进行 PTQ 量化,默认为 False。
114-
- `weight_quant_method`: 权重量化方式,现可选 groupwise 或者 abs_max_channel_wise。
115-
- `act_quant_method`: 激活量化方式,现可选 avg 或者 abs_max。
114+
- `weight_quant_method`: 权重量化方式,INT8量化可选 groupwise 或者 abs_max_channel_wise,FP8量化可选 abs_max 或 avg
115+
- `act_quant_method`: 激活量化方式,INT8可选 avg 或者 abs_max,FP8量化可选 abs_max 或 avg
116116
- `cachekv_quant_method`: kvcache 量化方式,现可选 abs_max_headwise, avg_headwise。
117117
- `ptq_step`: PTQ 量化步数,也即模型前向次数,默认为32。
118118
- `shift`: 是否在 PTQ 量化前进行[Shift 策略](https://arxiv.org/abs/2304.09145),默认为 False。使用 Shift 策略需要设`do_ptq`为 True。

llm/experimental/ceval/default/eval.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -36,25 +36,20 @@ def run_eval_one_time(args, evaluator, take):
3636
subject_list = [val_file.replace("_val.csv", "") for val_file in filenames]
3737
accuracy, summary = {}, {}
3838

39-
# run_date = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time()))
4039
output_dir = args.output_dir
4140
save_result_dir = os.path.join(output_dir, f"take{take}")
4241
if not os.path.exists(save_result_dir):
4342
os.makedirs(save_result_dir, exist_ok=True)
4443

4544
all_answers = {}
4645
for index, subject_name in enumerate(subject_list):
47-
# print(
48-
# f"{index/len(subject_list)} Inference starts at {run_date} on {args.model_name_or_path} with subject of {subject_name}!"
49-
# )
5046
val_file_path = os.path.join(val_path, f"{subject_name}_val.csv")
5147
dev_file_path = os.path.join(dev_path, f"{subject_name}_dev.csv")
5248
test_file_path = os.path.join(test_path, f"{subject_name}_test.csv")
5349

5450
val_df = pd.read_csv(val_file_path) if args.do_test is False else pd.read_csv(test_file_path)
5551
dev_df = pd.read_csv(dev_file_path) if args.few_shot else None
5652

57-
# import pdb;pdb.set_trace()
5853
correct_ratio, answers = evaluator.eval_subject(
5954
subject_name,
6055
val_df,

llm/experimental/ceval/default/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

llm/experimental/ceval/default/model_evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

llm/experimental/layers/cache_kv.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -166,13 +166,7 @@ def forward(
166166

167167
def _smooth(self, x, y, use_smooth_x):
168168
# For ShiftSmooth
169-
# smooth_shape = y.shape[2:]
170169
self.dtype = y.dtype
171-
# if not hasattr(self, "smooth_weight"):
172-
# self.smooth_weight = self.create_parameter(
173-
# shape=smooth_shape,
174-
# attr=ParamAttr(initializer=Constant(value=1.)),
175-
# dtype=self.dtype)
176170
smooth_y = y
177171
smooth_y = paddle.divide(smooth_y, self.smooth_weight)
178172

llm/experimental/layers/custom_attention.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def forward(
5858
**kwargs
5959
):
6060
"""forward"""
61-
# import pdb;pdb.set_trace()
6261
if self.enable_fake_quant:
6362
self.collect_kv_quant_policy(q, k, v, **kwargs)
6463
perm = [0, 2, 1, 3] # [1, 2, 0, 3] if self.sequence_parallel else [0, 2, 1, 3]

llm/experimental/observer/abs_max_headwise.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -14,8 +14,6 @@
1414

1515
import numpy as np
1616
import paddle
17-
18-
# from paddleslim.quant.observers.channel_wise import ChannelWiseObserver
1917
from experimental.observer.channel_wise import ChannelWiseObserver
2018
from paddle.quantization.factory import ObserverFactory
2119

llm/experimental/observer/avg_headwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.

llm/experimental/observer/channel_wise.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -14,13 +14,8 @@
1414

1515
from typing import Dict
1616

17-
# import numpy as np
1817
import paddle
19-
20-
# from paddle.quantization.factory import ObserverFactory
2118
from experimental.layers.cache_kv import CacheKVMatMul
22-
23-
# from paddleslim.quant.observers.mse import MSEObserverLayer
2419
from paddleslim.quant.observers.uniform import UniformObserver
2520

2621
CHANNEL_AXIS: Dict[type, int] = {

0 commit comments

Comments
 (0)