
Commit bf4a22d

feat: For QDQ models, disable constant folding and ensure the conv weights are transposed without using an explicit op (#1856)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 22db2a9 commit bf4a22d
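
For context, the kind of graph this commit targets is a convolution whose weights pass through TF's quantize-and-dequantize emulation. A minimal, illustrative sketch (shapes and values are assumptions, not taken from this commit):

    import tensorflow as tf

    # Fake-quantized conv weights: the Q/DQ around `w` must survive grappler's
    # constant folding so tf2onnx can later rewrite it to ONNX Q/DQ nodes.
    w = tf.random.normal([3, 3, 16, 32])                 # [fH, fW, C_in, C_out]
    w_qdq = tf.quantization.quantize_and_dequantize_v2(  # emits QuantizeAndDequantizeV4
        w, input_min=-1.0, input_max=1.0, range_given=True, narrow_range=True)
    x = tf.random.normal([1, 8, 8, 16])                  # NHWC activations
    y = tf.nn.conv2d(x, w_qdq, strides=1, padding="SAME")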

File tree: 5 files changed, +38 −12 lines changed

tf2onnx/convert.py

Lines changed: 2 additions & 2 deletions
@@ -138,7 +138,7 @@ def default_custom_op_handler(ctx, node, name, args):


 def _convert_common(frozen_graph, name="unknown", large_model=False, output_path=None,
-                    output_frozen_graph=None, custom_ops=None, custom_op_handlers=None, **kwargs):
+                    output_frozen_graph=None, custom_ops=None, custom_op_handlers=None, optimizers=None, **kwargs):
     """Common processing for conversion."""

     model_proto = None
@@ -165,7 +165,7 @@ def _convert_common(frozen_graph, name="unknown", large_model=False, output_path
         if constants.ENV_TF2ONNX_CATCH_ERRORS in os.environ:
             catch_errors = constants.ENV_TF2ONNX_CATCH_ERRORS.upper() == "TRUE"
         else:
             catch_errors = not large_model
-        onnx_graph = optimizer.optimize_graph(g, catch_errors)
+        onnx_graph = optimizer.optimize_graph(g, catch_errors, optimizers=optimizers)
         model_proto = onnx_graph.make_model("converted from {}".format(name),
                                             external_tensor_storage=external_tensor_storage)
     if output_path:
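
A hypothetical caller-side sketch of the new keyword (argument values here are assumptions for illustration; passing optimizers=None keeps optimizer.optimize_graph's built-in sequence):

    # Hypothetical call: restricting or reordering the optimizer passes is now
    # possible by forwarding `optimizers` through the common conversion path.
    model_proto, external_tensor_storage = _convert_common(
        frozen_graph,
        name="my_model",            # assumed example name
        output_path="model.onnx",   # assumed example path
        optimizers=None,            # None -> default optimizer set
    )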

tf2onnx/onnx_opset/nn.py

Lines changed: 19 additions & 6 deletions
@@ -139,12 +139,25 @@ def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,

         # If kernel is a constant, transpose that one if we are the only consumer.
         need_transpose = True
-        if kernel_node.is_const() and len(ctx.find_output_consumers(kernel_name)) == 1:
-            val = kernel_node.get_tensor_value(as_list=False)
-            val = np.transpose(val, permutation)
-
-            kernel_node.set_tensor_value(val)
-            need_transpose = False
+        if (kernel_node.is_const() or kernel_node.op.op_type == "DequantizeLinear") \
+                and len(ctx.find_output_consumers(kernel_name)) == 1:
+            if kernel_node.op.op_type == 'DequantizeLinear':
+                # Assuming the model was trained in NHWC in TF,
+                # the weights would be in [fH, fW, C_in, C_out].
+                # orig_conv_weights -> Q -> DQ -> new_conv_weights -> conv
+                weights_node = kernel_node.inputs[0].inputs[0]
+                val = weights_node.get_tensor_value(as_list=False)
+                val = np.transpose(val, permutation)
+                weights_node.set_tensor_value(val)
+                need_transpose = False
+                # Change the quantization axis for Q and DQ node accordingly
+                kernel_node.set_attr("axis", 0)  # DQ node
+                kernel_node.inputs[0].set_attr("axis", 0)  # Q node
+            else:
+                val = kernel_node.get_tensor_value(as_list=False)
+                val = np.transpose(val, permutation)
+                kernel_node.set_tensor_value(val)
+                need_transpose = False

         if need_transpose:
             transpose = ctx.insert_new_node_on_input(node, "Transpose", kernel_name)
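
For reference, the relayout this branch performs on the raw weight tensor can be sketched in plain NumPy (shapes here are illustrative; the actual permutation is supplied by the caller of conv_convert_inputs):

    import numpy as np

    # TF Conv2D kernels are stored as [fH, fW, C_in, C_out]; ONNX Conv expects
    # [C_out, C_in, fH, fW], so the usual permutation is (3, 2, 0, 1).
    tf_kernel = np.random.rand(3, 3, 16, 32).astype(np.float32)
    onnx_kernel = np.transpose(tf_kernel, (3, 2, 0, 1))
    assert onnx_kernel.shape == (32, 16, 3, 3)
    # C_out moves from axis 3 to axis 0, which is why the per-channel
    # quantization "axis" on the Q and DQ nodes is reset to 0 above.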

tf2onnx/rewriter/quantization_ops_rewriter.py

Lines changed: 10 additions & 3 deletions
@@ -2,7 +2,7 @@


 """
-tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3 op
+tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3|QuantizeAndDequantizeV4 op
 """

 import numpy as np
@@ -25,7 +25,8 @@ def create_qdq_nodes(g, match_results):
         # Get the attributes of qdq node
         narrow_range = qdq_node.attr['narrow_range'].i
         signed_input = qdq_node.attr['signed_input'].i
-        range_given = qdq_node.get_attr_value("range_given", qdq_node.type != "QuantizeAndDequantizeV2")
+        range_given = qdq_node.get_attr_value("range_given", qdq_node.type != "QuantizeAndDequantizeV2" and
+                                              qdq_node.type != "QuantizeAndDequantizeV4")

         min_quantized, max_quantized = [-127, 127]
         if not narrow_range and signed_input:
@@ -147,9 +148,15 @@ def rewrite_quantize_and_dequantize(g, ops):
             OpTypePattern(None),
             OpTypePattern(None),
         ])
+    pattern_for_qdq_v4 = \
+        OpTypePattern('QuantizeAndDequantizeV4', name='output', inputs=[
+            OpTypePattern("*"),
+            OpTypePattern(None),
+            OpTypePattern(None),
+        ])

     # Match all the patterns for QDQ ops
-    patterns = [pattern_for_qdq_v3, pattern_for_qdq_v2]
+    patterns = [pattern_for_qdq_v2, pattern_for_qdq_v3, pattern_for_qdq_v4]
     match_results = []
     for pattern in patterns:
         matcher = GraphMatcher(pattern)
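
The range defaulting above feeds into the rewriter's scale computation; a simplified per-tensor version follows (an approximation under the assumptions shown, not the rewriter's exact code):

    import numpy as np

    def qdq_scale(min_val, max_val, narrow_range, signed_input):
        # Mirrors the min_quantized/max_quantized logic above: signed int8,
        # [-127, 127] by default, widened to -128 when narrow_range is off.
        min_quantized, max_quantized = -127.0, 127.0
        if not narrow_range and signed_input:
            min_quantized = -128.0
        # Map the larger magnitude of the float range onto the int8 bound.
        scale = max(abs(min_val) / -min_quantized, abs(max_val) / max_quantized)
        return np.float32(scale), np.int8(0)  # zero point 0 (symmetric)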

tf2onnx/tf_loader.py

Lines changed: 5 additions & 0 deletions
@@ -686,6 +686,11 @@ def tf_optimize_grappler(input_names, output_names, graph_def, fold_constant=None):
         # 'pruning', 'constfold', 'arithmetic', 'dependency', 'function',
         'constfold', 'function'
     ]
+
+    if LooseVersion(tf.__version__) >= "2.5":
+        # This flag disables folding QDQ nodes around constants in the network (e.g. around conv/FC weights)
+        rewrite_options.experimental_disable_folding_quantization_emulation = True
+
     meta_graph = tf.compat.v1.train.export_meta_graph(graph_def=graph_def)
     fetch_collection = meta_graph_pb2.CollectionDef()
     for t in input_names + output_names:
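
For context, a minimal sketch of the protobuf plumbing around this flag (standard TensorFlow proto APIs; a simplification of the surrounding tf2onnx code):

    import tensorflow as tf
    from distutils.version import LooseVersion
    from tensorflow.core.protobuf import config_pb2

    config = config_pb2.ConfigProto()
    rewrite_options = config.graph_options.rewrite_options
    rewrite_options.optimizers.extend(['constfold', 'function'])
    if LooseVersion(tf.__version__) >= "2.5":
        # Keep quantization-emulation (QDQ) nodes out of grappler's constant
        # folding; the flag only exists in TF >= 2.5.
        rewrite_options.experimental_disable_folding_quantization_emulation = True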

tf2onnx/tf_utils.py

Lines changed: 2 additions & 1 deletion
@@ -219,7 +219,8 @@ def is_huge_shape(x):
                 outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
                 progress = True
             can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault', 'Switch', 'Merge',
-                                         'NextIteration', 'Exit']
+                                         'NextIteration', 'Exit', 'QuantizeAndDequantizeV2', 'QuantizeAndDequantizeV3',
+                                         'QuantizeAndDequantizeV4']
             can_fold = can_fold and not node.type.startswith('Random')
             can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
             # We can only fold nodes with a single output
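
The guard reduces to a deny-list check; a standalone sketch (names assumed for illustration):

    # Op types tf2onnx's constant-folding pass must never evaluate, now
    # including TF's quantize-and-dequantize emulation ops: folding them would
    # bake quantization error into the weights and erase the Q/DQ pattern the
    # rewriter needs to find later.
    UNFOLDABLE_OPS = {
        'Enter', 'Placeholder', 'PlaceholderWithDefault', 'Switch', 'Merge',
        'NextIteration', 'Exit',
        'QuantizeAndDequantizeV2', 'QuantizeAndDequantizeV3', 'QuantizeAndDequantizeV4',
    }

    def can_fold_op(op_type):
        return op_type not in UNFOLDABLE_OPS and not op_type.startswith('Random')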
