@@ -255,7 +255,7 @@ def should_merge_last_two_dim(self) -> bool:
255255 """check whether to merge the last two dims"""
256256 return self .action == "merge_last_two_dim"
257257
def run(self, state_dict: dict[str, ndarray], name: str) -> ndarray:
    """Apply this mapping's custom action (transpose, merge_last_two_dim,
    split) to the tensor stored under ``name``.

    The tensor is popped from ``state_dict``. For the ``split`` action with
    ``index < 2`` it is re-inserted so that later mappings referencing the
    same source key can split it again.

    Args:
        state_dict (dict[str, ndarray]): state dict holding the source tensors.
        name (str): key of the tensor to transform; popped from ``state_dict``.

    Returns:
        ndarray: the final tensor
    """
    tensor = state_dict.pop(name)
    if self.action == "transpose":
        return transpose(tensor, [1, 0])
    if self.action == "merge_last_two_dim":
        shape = tensor.shape
        assert len(shape) == 3
        return np.reshape(tensor, [shape[0], -1])
    if self.action == "split":
        assert self.index is not None, "when action is `split`, index field is required."
        # keep the tensor around for the remaining q/k slices; index 2 (v)
        # is the last consumer, so the key may be dropped then
        if self.index < 2:
            state_dict[name] = tensor
        # qkv is stored in same tensor, so it should be split into 3 arr
        tensors = np.split(tensor, 3, axis=-1)
        return tensors[self.index]
    return tensor
274283
275284 def matched (self , text : str ) -> bool :
@@ -490,6 +499,9 @@ class LogitComparer:
490499 config_fields_to_be_removed : List [str ] = ["transformers_version" ]
491500 architectures : Dict [str , Type [PretrainedModel ]] = {}
492501
def __init__(self, input_dir: str) -> None:
    """Create a comparer rooted at ``input_dir``.

    Args:
        input_dir (str): directory containing the pretrained model files
            whose logits are to be compared.
    """
    self.input_dir = input_dir
493505 def get_paddle_pytorch_model_classes (self ) -> Tuple [object , object ]:
494506 """return the [PaddleModelClass, PytorchModelClass] to
495507 1. generate paddle model automatically
@@ -574,13 +586,15 @@ def compare_model_state_dicts(
574586 for name_mapping in name_mappings :
575587 model_state_saver .add (name_mapping .target_name , "pytorch_key" , name_mapping .source_name )
576588
577- paddle_numpy = paddle_state_dict .pop (name_mapping .target_name )
578- model_state_saver .add (name_mapping .target_name , "paddle" , paddle_numpy )
579- model_state_saver .add (name_mapping .target_name , "paddle-shape" , str (paddle_numpy .shape ))
589+ if name_mapping .target_name in paddle_state_dict :
590+ paddle_numpy = paddle_state_dict .pop (name_mapping .target_name )
591+ model_state_saver .add (name_mapping .target_name , "paddle" , paddle_numpy )
592+ model_state_saver .add (name_mapping .target_name , "paddle-shape" , str (paddle_numpy .shape ))
580593
581- pytorch_numpy = pytorch_state_dict .pop (name_mapping .source_name )
582- model_state_saver .add (name_mapping .target_name , "pytorch" , pytorch_numpy )
583- model_state_saver .add (name_mapping .target_name , "pytorch-shape" , str (pytorch_numpy .shape ))
594+ if name_mapping .source_name in pytorch_state_dict :
595+ pytorch_numpy = pytorch_state_dict .pop (name_mapping .source_name )
596+ model_state_saver .add (name_mapping .target_name , "pytorch" , pytorch_numpy )
597+ model_state_saver .add (name_mapping .target_name , "pytorch-shape" , str (pytorch_numpy .shape ))
584598
585599 model_state_saver .summary ()
586600
@@ -594,8 +608,7 @@ def compare_logits(self) -> bool:
594608 paddle_model = PaddleModel .from_pretrained (self .input_dir )
595609
596610 # 0. init the name_mapping & tensor_info_saver & logit_hooker
597- num_layers = self .get_num_layer (list (paddle_model .state_dict ().keys ()))
598- name_mappings = self .get_name_mapping (num_layers , paddle_model .config ["architectures" ])
611+ name_mappings = self .get_name_mapping (paddle_model .config )
599612 tensor_info_saver = TensorInfoSaver ()
600613
601614 logit_hooker = LogitHooker (name_mappings , tensor_info_saver )
@@ -707,8 +720,9 @@ def convert(cls, weight_file: str, config: PretrainedConfig, cache_dir: str) ->
707720 logger .warning (f"key<{ name_mapping .source_name } > not in the pytorch weight file." )
708721 continue
709722
710- state_dict [name_mapping .target_name ] = name_mapping .run (state_dict .pop (name_mapping .source_name ))
711- all_layer_names .remove (name_mapping .source_name )
723+ state_dict [name_mapping .target_name ] = name_mapping .run (state_dict , name_mapping .source_name )
724+ if name_mapping .source_name in all_layer_names :
725+ all_layer_names .remove (name_mapping .source_name )
712726
713727 if all_layer_names :
714728 logger .warning (f"there are { len (all_layer_names )} tensors not initialized:" )
0 commit comments