4040
4141logger = Logger ().get_logger ()
4242
43- from typing import Any , Callable , List , Optional , Tuple , Union
43+ from typing import Any , List , Optional , Tuple , Union
4444
4545from neural_compressor .common .base_config import (
4646 BaseConfig ,
47- ComposableConfig ,
47+ config_registry ,
4848 get_all_config_set_from_config_registry ,
4949 register_config ,
50+ register_supported_configs_for_fwk ,
5051)
5152from neural_compressor .common .base_tuning import ConfigLoader , ConfigSet , SequentialSampler
5253from neural_compressor .common .tuning_param import TuningParam
5354from neural_compressor .common .utils import DEFAULT_WHITE_LIST , OP_NAME_OR_MODULE_TYPE
5455
# Registration priorities: higher value wins when algos are ordered.
PRIORITY_FAKE_ALGO = 100
PRIORITY_FAKE_ALGO_1 = 90

# Algorithm names under which the fake configs are registered.
FAKE_CONFIG_NAME = "fake"
FAKE_CONFIG_NAME_1 = "fake_one"

# Default tunable options for the weight-bits parameter.
DEFAULT_WEIGHT_BITS = [4, 6]

FAKE_FRAMEWORK_NAME = "FAKE_FWK"

# Static (op_name, op_type) pairs returned by the fake get_model_info hooks.
FAKE_MODEL_INFO = [("OP1_NAME", "OP_TYPE1"), ("OP2_NAME", "OP_TYPE1"), ("OP3_NAME", "OP_TYPE2")]
65+
66+
class FakeModel:
    """Minimal stand-in model: an identity callable with a fixed name."""

    def __init__(self) -> None:
        self.name = "fake_model"

    def __call__(self, x) -> Any:
        # Identity forward pass — returns the input unchanged.
        return x

    def __repr__(self) -> str:
        return "FakeModel"
76+
6177
6278@register_config (framework_name = FAKE_FRAMEWORK_NAME , algo_name = FAKE_CONFIG_NAME , priority = PRIORITY_FAKE_ALGO )
6379class FakeAlgoConfig (BaseConfig ):
@@ -102,17 +118,14 @@ def register_supported_configs(cls) -> List:
102118 pass
103119
104120 @staticmethod
105- def get_model_info (model : Any ) -> List [Tuple [str , Callable ]]:
106- pass
121+ def get_model_info (model : Any ) -> List [Tuple [str , Any ]]:
122+ return FAKE_MODEL_INFO
107123
108124 @classmethod
109125 def get_config_set_for_tuning (cls ) -> Union [None , "FakeAlgoConfig" , List ["FakeAlgoConfig" ]]:
110126 return FakeAlgoConfig (weight_bits = DEFAULT_WEIGHT_BITS )
111127
112128
113- FakeAlgoConfig .register_supported_configs ()
114-
115-
116129def get_default_fake_config () -> FakeAlgoConfig :
117130 """Generate the default fake config.
118131
@@ -122,10 +135,64 @@ def get_default_fake_config() -> FakeAlgoConfig:
122135 return FakeAlgoConfig ()
123136
124137
@register_config(framework_name=FAKE_FRAMEWORK_NAME, algo_name=FAKE_CONFIG_NAME_1, priority=PRIORITY_FAKE_ALGO_1)
class FakeAlgoOneConfig(BaseConfig):
    """Config class for the second fake algo (used to exercise mixed-algo tuning)."""

    supported_configs: List = []
    params_list = [
        "weight_dtype",
        "weight_bits",
        # Complex tunable type: each option is itself a list of op-type names.
        TuningParam("target_op_type_list", tunable_type=List[List[str]]),
    ]
    name = FAKE_CONFIG_NAME_1

    def __init__(
        self,
        weight_dtype: str = "int",
        weight_bits: int = 4,
        target_op_type_list: Optional[List[str]] = None,
        white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
    ):
        """Init fake config.

        Args:
            weight_dtype (str): Data type for weights, default is "int".
            weight_bits (int): Number of bits used to represent weights, default is 4.
            target_op_type_list (List[str], optional): Op types the algo targets;
                defaults to ["Conv", "Gemm"] when not given.
            white_list (List, optional): Ops this config applies to, default is
                DEFAULT_WHITE_LIST.
        """
        super().__init__(white_list=white_list)
        self.weight_bits = weight_bits
        self.weight_dtype = weight_dtype
        # None-sentinel instead of a mutable default argument, so instances
        # never share (or mutate) one module-level default list.
        self.target_op_type_list = ["Conv", "Gemm"] if target_op_type_list is None else target_op_type_list
        self._post_init()

    def to_dict(self):
        """Serialize this config via the BaseConfig machinery."""
        return super().to_dict()

    @classmethod
    def from_dict(cls, config_dict):
        """Rebuild a config instance from its dict form."""
        return super().from_dict(config_dict=config_dict)

    @classmethod
    def register_supported_configs(cls) -> List:
        # Intentionally a no-op for the fake framework.
        pass

    @staticmethod
    def get_model_info(model: Any) -> List[Tuple[str, Any]]:
        """Return the canned (op_name, op_type) pairs; *model* is ignored."""
        return FAKE_MODEL_INFO

    @classmethod
    def get_config_set_for_tuning(cls) -> Union[None, "FakeAlgoOneConfig", List["FakeAlgoOneConfig"]]:
        """Return a tunable config whose weight_bits spans the default options."""
        return FakeAlgoOneConfig(weight_bits=DEFAULT_WEIGHT_BITS)
188+
def get_all_config_set() -> Union[BaseConfig, List[BaseConfig]]:
    """Collect the tuning config set of every algo registered for the fake framework."""
    return get_all_config_set_from_config_registry(fwk_name=FAKE_FRAMEWORK_NAME)
127191
128192
193+ register_supported_configs_for_fwk (fwk_name = FAKE_FRAMEWORK_NAME )
194+
195+
129196class TestBaseConfig (unittest .TestCase ):
130197 @classmethod
131198 def setUpClass (self ):
@@ -143,7 +210,7 @@ def test_api(self):
143210 fake_default_config = get_default_fake_config ()
144211 self .assertEqual (fake_default_config .weight_dtype , "int" )
145212 config_set = get_all_config_set ()
146- self .assertEqual (len (config_set ), 1 )
213+ self .assertEqual (len (config_set ), len ( config_registry . get_all_config_cls_by_fwk_name ( FAKE_FRAMEWORK_NAME )) )
147214 self .assertEqual (config_set [0 ].weight_bits , DEFAULT_WEIGHT_BITS )
148215
149216 def test_config_expand_complex_tunable_type (self ):
@@ -154,6 +221,18 @@ def test_config_expand_complex_tunable_type(self):
154221 for i in range (len (configs_list )):
155222 self .assertEqual (configs_list [i ].target_op_type_list , target_op_type_list_options [i ])
156223
224+ def test_mixed_two_algos (self ):
225+ model = FakeModel ()
226+ OP1_NAME = "OP1_NAME"
227+ OP2_NAME = "OP2_NAME"
228+ fake_config = FakeAlgoConfig (weight_bits = 4 , white_list = [OP1_NAME ])
229+ fake1_config = FakeAlgoOneConfig (weight_bits = 2 , white_list = [OP2_NAME ])
230+ mixed_config = fake_config + fake1_config
231+ model_info = mixed_config .get_model_info (model )
232+ config_mapping = mixed_config .to_config_mapping (model_info = model_info )
233+ self .assertIn (OP1_NAME , [op_info [0 ] for op_info in config_mapping ])
234+ self .assertIn (OP2_NAME , [op_info [0 ] for op_info in config_mapping ])
235+
157236
158237class TestConfigSet (unittest .TestCase ):
159238 def setUp (self ):
0 commit comments