@@ -570,20 +570,21 @@ def quantize_inputs(self, node, indices=None, initializer_use_weight_qType=True,
570570 if self .add_qdq_pair_to_weight and self .mode == "qdq" :
571571 weight = self ._get_quantized_weight (initializer , dtype , scheme )
572572 self ._update_weight (weight )
573+ node .input [idx ] = weight .name
573574 q_weight_name = weight .name + "_quantized"
574575 zp_name = weight .name + "_zero_point"
575576 scale_name = weight .name + "_scale"
576577 qlinear_node = make_quant_node (
577- tensor_name + "_QuantizeLinear" ,
578+ weight . name + "_QuantizeLinear" ,
578579 [tensor_name , scale_name , zp_name ],
579- [tensor_name + "_quantized" ],
580+ [weight . name + "_quantized" ],
580581 )
581582 dequant_node = make_dquant_node (
582- tensor_name + "_DequantizeLinear" ,
583- [tensor_name + "_quantized" , scale_name , zp_name ],
584- [tensor_name + "_dequantized" ],
583+ weight . name + "_DequantizeLinear" ,
584+ [weight . name + "_quantized" , scale_name , zp_name ],
585+ [weight . name + "_dequantized" ],
585586 )
586- self .replace_input .append ([node , tensor_name , dequant_node .output [0 ]])
587+ self .replace_input .append ([node , weight . name , dequant_node .output [0 ]])
587588 self .new_nodes .extend ([qlinear_node , dequant_node ])
588589 quantized_value = QuantizedValue (
589590 weight .name , q_weight_name , scale_name , zp_name , QuantizedValueType .Initializer , None , dtype
@@ -593,17 +594,18 @@ def quantize_inputs(self, node, indices=None, initializer_use_weight_qType=True,
593594 else :
594595 weight = self ._get_quantized_weight (initializer , dtype , scheme )
595596 self ._update_weight (weight )
597+ node .input [idx ] = weight .name
596598 q_weight_name = weight .name + "_quantized"
597599 zp_name = weight .name + "_zero_point"
598600 scale_name = weight .name + "_scale"
599601
600602 inputs = [q_weight_name , scale_name , zp_name ]
601603 output_name = tensor_name + "_DequantizeLinear"
602604 dequant_node = onnx .helper .make_node (
603- "DequantizeLinear" , inputs , [tensor_name + "_dequantized" ], tensor_name + "_DequantizeLinear"
605+ "DequantizeLinear" , inputs , [weight . name + "_dequantized" ], weight . name + "_DequantizeLinear"
604606 )
605607 self .new_nodes .append (dequant_node )
606- self .replace_input .append ([node , tensor_name , dequant_node .output [0 ]])
608+ self .replace_input .append ([node , weight . name , dequant_node .output [0 ]])
607609 quantized_value = QuantizedValue (
608610 weight .name , q_weight_name , scale_name , zp_name , QuantizedValueType .Initializer , None , dtype
609611 )
@@ -721,7 +723,8 @@ def quantize_bias_tensor(self, node):
721723 if len (beta_attribute ):
722724 beta = onnx .helper .get_attribute_value (beta_attribute [0 ])
723725 _ , quant_value = self .quantize_bias (bias_name , input_name , weight_name , beta )
724- self .model .remove_initializer (find_by_name (bias_name , self .model .initializer ()))
726+ if self .model .get_initializer_share_num (bias_name ) == 1 :
727+ self .model .remove_initializer (find_by_name (bias_name , self .model .initializer ()))
725728 inputs = [quant_value .q_name , quant_value .scale_name , quant_value .zp_name ]
726729 axis = None
727730 if find_by_name (weight_name + "_DequantizeLinear" , self .new_nodes ):
@@ -855,79 +858,96 @@ def quantize_weights_per_channel(self, node, indices, weight_qType, scheme, axis
855858 self .quantize_inputs (node , indices )
856859 return
857860
858- for idx , weight_name in enumerate (node .input ):
861+ for idx , inp in enumerate (node .input ):
859862 if idx not in indices :
860863 continue
861864
862865 if self .add_qdq_pair_to_weight and self .mode == "qdq" :
863- q_name , zp_name , scale_name = self .quantize_weight_per_channel (weight_name , weight_qType , scheme , axis )
866+ q_name , zp_name , scale_name = self .quantize_weight_per_channel (inp , weight_qType , scheme , axis )
867+ weight_name = (
868+ ("_" ).join ([inp , str (weight_qType )]) if self .model .get_initializer_share_num (inp ) > 1 else inp
869+ )
864870 qlinear_node = make_quant_node (
865- weight_name + "_QuantizeLinear" , [weight_name , scale_name , zp_name ], [weight_name + "_quantized" ]
871+ weight_name + "_QuantizeLinear" ,
872+ [inp , scale_name , zp_name ],
873+ [q_name ],
874+ axis ,
866875 )
867876 dequant_node = make_dquant_node (
868877 weight_name + "_DequantizeLinear" ,
869- [weight_name + "_quantized" , scale_name , zp_name ],
878+ [q_name , scale_name , zp_name ],
870879 [weight_name + "_dequantized" ],
871880 axis ,
872881 )
882+ node .input [idx ] = weight_name
873883 self .replace_input .append ([node , weight_name , dequant_node .output [0 ]])
874884 self .new_nodes .extend ([qlinear_node , dequant_node ])
875885 else :
876- q_name , zp_name , scale_name = self .quantize_weight_per_channel (weight_name , weight_qType , scheme , axis )
877- inputs = [q_name , scale_name , zp_name ]
886+ q_name , zp_name , scale_name = self .quantize_weight_per_channel (inp , weight_qType , scheme , axis )
887+ weight_name = (
888+ ("_" ).join ([inp , str (weight_qType )]) if self .model .get_initializer_share_num (inp ) > 1 else inp
889+ )
878890 dequant_node = make_dquant_node (
879891 weight_name + "_DequantizeLinear" ,
880892 [q_name , scale_name , zp_name ],
881893 [weight_name + "_dequantized" ],
882894 axis ,
883895 )
884896 self .new_nodes .append (dequant_node )
897+ node .input [idx ] = weight_name
885898
886899 # Replace weight_name with output of DequantizeLinear
887900 self .replace_input .append ([node , weight_name , dequant_node .output [0 ]])
888901
def quantize_weight_per_channel(self, weight_name, weight_qType, scheme, channel_axis):
    """Quantize an initializer tensor per-channel.

    Args:
        weight_name: name of the weight initializer in the model graph.
        weight_qType: target quantized element type (used both for quantization
            and to disambiguate shared initializers).
        scheme: quantization scheme, forwarded to ``quantize_data_per_channel``.
        channel_axis: axis along which per-channel scale/zero-point are computed.

    Returns:
        Tuple ``(quantized_name, zero_point_name, scale_name)`` of the tensor
        names produced for the quantized weight.

    Raises:
        ValueError: if ``weight_name`` is not an initializer of the model.
    """
    # When the initializer is shared by more than one node, derive a
    # per-dtype name ("<name>_<qType>") so the same weight can be quantized
    # independently to different types without name collisions.
    if self.model.get_initializer_share_num(weight_name) > 1:
        name = "_".join([weight_name, str(weight_qType)])
    else:
        name = weight_name

    # Already quantized under this (name, dtype) combination: reuse the
    # previously generated tensors instead of re-quantizing.
    if name in self.quantized_value_map:
        return (name + "_quantized", name + "_zero_point", name + "_scale")

    initializer = find_by_name(weight_name, self.model.initializer())
    if initializer is None:
        # BUG FIX: the placeholder was never substituted — "{}" and
        # weight_name were passed as two separate ValueError arguments.
        raise ValueError("{} is not an initializer".format(weight_name))

    # Load the raw weight data; external tensor data is resolved relative to
    # the model's directory when a model path is available.
    weights = (
        self.tensor_proto_to_array(initializer, os.path.dirname(self.model.model_path))
        if self.model.model_path is not None
        else self.tensor_proto_to_array(initializer)
    )
    rmin, rmax, zero_point, scale, quantized_weights = quantize_data_per_channel(
        weights, channel_axis, _get_qrange_for_qType(weight_qType, self.reduce_range), weight_qType, scheme
    )

    weight = QuantizedInitializer(
        name,
        initializer,
        rmin,
        rmax,
        zero_point,
        scale,
        weights,
        quantized_weights.flatten().tolist(),
        channel_axis,
        weight_qType,
    )

    self._update_weight(weight)
    quantized_value = QuantizedValue(
        weight.name,
        weight.name + "_quantized",
        weight.name + "_scale",
        weight.name + "_zero_point",
        QuantizedValueType.Initializer,
        None,
        weight_qType,
    )
    # Record under the (possibly dtype-suffixed) name so the early-return
    # cache above finds it on the next call.
    self.quantized_value_map[weight.name] = quantized_value

    return (weight.name + "_quantized", weight.name + "_zero_point", weight.name + "_scale")
931951
932952 def _update_weight (self , weight ):
933953 """Update weight.
@@ -1018,8 +1038,13 @@ def _get_quantization_params(self, param_name):
10181038
10191039 def _get_quantized_weight (self , initializer , qType , scheme ):
10201040 """Get quantized weight."""
1021- if initializer .name in self .quantized_value_map :
1022- return self .quantized_value_map [initializer .name ]
1041+ name = (
1042+ ("_" ).join ([initializer .name , str (qType )])
1043+ if self .model .get_initializer_share_num (initializer .name ) > 1
1044+ else initializer .name
1045+ )
1046+ if name in self .quantized_value_map :
1047+ return self .quantized_value_map [name ]
10231048 weights_data = (
10241049 self .tensor_proto_to_array (initializer , os .path .dirname (self .model .model_path ))
10251050 if self .model .model_path is not None
@@ -1029,7 +1054,7 @@ def _get_quantized_weight(self, initializer, qType, scheme):
10291054 weights_data .flatten ().tolist (), _get_qrange_for_qType (qType , self .reduce_range ), qType , scheme
10301055 )
10311056 weight = QuantizedInitializer (
1032- initializer . name ,
1057+ name ,
10331058 initializer ,
10341059 [rmin ],
10351060 [rmax ],
0 commit comments