@@ -118,12 +118,16 @@ def from_dict(cls, config_dict, str2operator=None):
118118 Returns:
119119 The constructed config.
120120 """
121- config = cls (** config_dict .get (GLOBAL , {}))
122- operator_config = config_dict .get (LOCAL , {})
123- if operator_config :
124- for op_name , op_config in operator_config .items ():
125- config .set_local (op_name , cls (** op_config ))
126- return config
121+ if GLOBAL not in config_dict and LOCAL not in config_dict :
122+ config = cls (** config_dict )
123+ return config
124+ else :
125+ config = cls (** config_dict .get (GLOBAL , {}))
126+ operator_config = config_dict .get (LOCAL , {})
127+ if operator_config :
128+ for op_name , op_config in operator_config .items ():
129+ config .set_local (op_name , cls (** op_config ))
130+ return config
127131
128132 @classmethod
129133 def to_diff_dict (cls , instance ) -> Dict [str , Any ]:
@@ -201,11 +205,11 @@ def to_config_mapping(
201205 global_config = config .global_config
202206 op_type_config_dict , op_name_config_dict = config ._get_op_name_op_type_config ()
203207 for op_name , op_type in model_info :
204- config_mapping . setdefault (op_type , OrderedDict ())[ op_name ] = global_config
208+ config_mapping [ (op_type , op_name ) ] = global_config
205209 if op_type in op_type_config_dict :
206- config_mapping [op_type ][ op_name ] = op_name_config_dict [op_type ]
210+ config_mapping [( op_type , op_name ) ] = op_name_config_dict [op_type ]
207211 if op_name in op_name_config_dict :
208- config_mapping [op_type ][ op_name ] = op_name_config_dict [op_name ]
212+ config_mapping [( op_type , op_name ) ] = op_name_config_dict [op_name ]
209213 return config_mapping
210214
211215 @staticmethod
@@ -234,9 +238,15 @@ def to_dict(self, params_list=[], operator2str=None):
234238 return result
235239
236240 @classmethod
237- def from_dict (cls , config_dict , str2operator = None ):
238- # TODO(Yi)
239- pass
241+ def from_dict (cls , config_dict : OrderedDict [str , Dict ], config_registry : Dict [str , BaseConfig ]):
242+ assert len (config_dict ) >= 1 , "The config dict must include at least one configuration."
243+ num_configs = len (config_dict )
244+ name , value = next (iter (config_dict .items ()))
245+ config = config_registry [name ].from_dict (value )
246+ for _ in range (num_configs - 1 ):
247+ name , value = next (iter (config_dict .items ()))
248+ config += config_registry [name ].from_dict (value )
249+ return config
240250
241251 def to_json_string (self , use_diff : bool = False ) -> str :
242252 return json .dumps (self .to_dict (), indent = 2 ) + "\n "
0 commit comments