Skip to content

Commit da8092a

Browse files
committed
Add unittest for customized quanter
1 parent fd74346 commit da8092a

File tree

3 files changed

+76
-10
lines changed

3 files changed

+76
-10
lines changed

python/paddle/quantization/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def __init__(self):
127127
_element, activation=activation, weight=weight
128128
)
129129
else:
130-
self.add_prefix_config(
130+
self.add_name_config(
131131
layer.full_name(), activation=activation, weight=weight
132132
)
133133

@@ -170,7 +170,7 @@ def __init__(self):
170170
self._prefix2config[layer_name] = config
171171
if isinstance(layer_name, list):
172172
for _element in layer_name:
173-
self.add_prefix_config(
173+
self.add_name_config(
174174
_element, activation=activation, weight=weight
175175
)
176176

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from typing import Iterable, Union
17+
18+
import numpy as np
19+
20+
import paddle
21+
from paddle.nn import Linear
22+
from paddle.quantization.base_quanter import BaseQuanter
23+
from paddle.quantization.factory import quanter
24+
25+
linear_quant_axis = 1
26+
27+
28+
@quanter("CustomizedQuanter")
class CustomizedQuanterLayer(BaseQuanter):
    """Minimal user-defined quanter used to verify that quanters
    registered through the ``quanter`` decorator receive their
    constructor arguments end to end.

    Args:
        layer: The layer being quantized; only its type is inspected.
        bit_length (int): Quantization bit width reported by
            ``bit_length()``. Defaults to 8.
        kwargs1: Arbitrary extra keyword argument, stored verbatim so the
            test can confirm keyword forwarding.
    """

    def __init__(self, layer, bit_length=8, kwargs1=None):
        # Python-3 zero-argument super(); the explicit two-argument form
        # is a Python-2 legacy.
        super().__init__()
        self._layer = layer
        self._bit_length = bit_length
        self._kwargs1 = kwargs1

    def scales(self) -> Union[paddle.Tensor, np.ndarray]:
        # This dummy quanter tracks no scales.
        return None

    def bit_length(self):
        return self._bit_length

    def quant_axis(self) -> Union[int, Iterable]:
        # Linear layers are quantized along the module-level
        # linear_quant_axis; other layer types report no axis.
        return linear_quant_axis if isinstance(self._layer, Linear) else None

    def zero_points(self) -> Union[paddle.Tensor, np.ndarray]:
        # Symmetric dummy quanter: no zero points.
        return None

    def forward(self, input):
        # Identity pass-through; actual quantization math is irrelevant
        # to this registration test.
        return input
50+
51+
52+
class TestCustomizedQuanter(unittest.TestCase):
    """Verify that a decorator-registered quanter forwards its
    constructor arguments to the generated layer instance."""

    def test_details(self):
        target_layer = Linear(5, 5)
        expected_bits = 4
        factory = CustomizedQuanter(  # noqa: F821
            bit_length=expected_bits, kwargs1="test"
        )
        instance = factory._instance(target_layer)
        self.assertEqual(instance.bit_length(), expected_bits)
        self.assertEqual(instance.quant_axis(), linear_quant_axis)
        self.assertEqual(instance._kwargs1, "test")
63+
64+
65+
# Standard unittest entry point: run all TestCase classes in this module.
if __name__ == '__main__':
    unittest.main()

python/paddle/tests/quantization/test_quant.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import paddle.nn.functional as F
1919
from paddle.nn import Conv2D, Linear, ReLU, Sequential
2020
from paddle.quantization import QuantConfig
21+
from paddle.quantization.base_quanter import BaseQuanter
2122
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
22-
from python.paddle.quantization.base_quanter import BaseQuanter
2323

2424

2525
class LeNetDygraph(paddle.nn.Layer):
@@ -86,9 +86,9 @@ def test_add_layer_config(self):
8686
self.q_config._specify(self.model)
8787
self.assert_just_linear_weight_configure(self.model, self.q_config)
8888

89-
def test_add_prefix_config(self):
89+
def test_add_name_config(self):
9090
self.q_config = QuantConfig(activation=None, weight=None)
91-
self.q_config.add_prefix_config(
91+
self.q_config.add_name_config(
9292
[self.model.fc.full_name()], activation=None, weight=self.quanter
9393
)
9494
self.q_config._specify(self.model)
@@ -110,11 +110,11 @@ def test_add_qat_layer_mapping(self):
110110
Sequential not in self.q_config.default_qat_layer_mapping
111111
)
112112

113-
def test_add_custom_leaf(self):
113+
def test_add_customized_leaf(self):
114114
self.q_config = QuantConfig(activation=None, weight=None)
115-
self.q_config.add_custom_leaf(Sequential)
116-
self.assertTrue(Sequential in self.q_config.custom_leaves)
117-
self.assertTrue(self.q_config._is_custom_leaf(self.model.fc))
115+
self.q_config.add_customized_leaf(Sequential)
116+
self.assertTrue(Sequential in self.q_config.customized_leaves)
117+
self.assertTrue(self.q_config._is_customized_leaf(self.model.fc))
118118
self.assertTrue(self.q_config._is_leaf(self.model.fc))
119119
self.assertFalse(self.q_config._is_default_leaf(self.model.fc))
120120
self.assertFalse(self.q_config._is_real_leaf(self.model.fc))
@@ -124,7 +124,7 @@ def test_need_observe(self):
124124
self.q_config.add_layer_config(
125125
[self.model.fc], activation=self.quanter, weight=self.quanter
126126
)
127-
self.q_config.add_custom_leaf(Sequential)
127+
self.q_config.add_customized_leaf(Sequential)
128128
self.q_config._specify(self.model)
129129
self.assertTrue(self.q_config._has_observer_config(self.model.fc))
130130
self.assertTrue(self.q_config._need_observe(self.model.fc))

0 commit comments

Comments
 (0)