Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ class QuantizationTransformPass(object):
the quantized ops's inputs.
"""
_supported_quantizable_op_type = [
'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul'
'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul',
'matmul_v2'
]

def __init__(self,
Expand Down Expand Up @@ -520,6 +521,16 @@ def _transform_backward(graph, op):
dequant_var_node = dequantized_vars[var_node.name()]
graph.update_input_link(var_node, dequant_var_node, op)

def _has_weight(op):
    """Return True if any input of *op* is a persistable variable.

    Persistable inputs are the op's weights/parameters; ops without a
    weight input are skipped by the quantization transform.

    Args:
        op: an IR graph op node exposing ``inputs``, ``name()`` and
            ``input_arg_names()``.
    Returns:
        bool: whether at least one real input argument of ``op`` appears
        in the enclosing scope's ``persistable_vars`` set.
    """
    # Only consider var nodes that are actual arguments of the op,
    # then short-circuit on the first persistable (weight) input.
    return any(
        var_node.name() in persistable_vars
        for var_node in op.inputs
        if var_node.name() in op.input_arg_names())

if not self._is_test:
self._create_global_step(graph)
ops = graph.all_op_nodes()
Expand All @@ -535,11 +546,11 @@ def _transform_backward(graph, op):
# The loop for transforming the forward graph:
for op in ops:
if op.name() in self._quantizable_ops:
if not self._is_skip_quant(graph, op):
if not self._is_skip_quant(graph, op) and _has_weight(op):
_transform_forward(graph, op)
# The loop for renaming the inputs of backward op.
for op in ops:
if op.name() in self._quantizable_grad_ops:
if op.name() in self._quantizable_grad_ops and _has_weight(op):
_transform_backward(graph, op)
graph.resolve_hazard()
return graph
Expand Down