Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1902,7 +1902,7 @@ def test_duplicated_duplicated_constant_and_initializer(self):

model_proto = self.make_model(graph, producer_name="onnx-tests")
self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto, op_type="Constant", remaining_op_num=0,
graph_validator=lambda g: self._check_initializer_num(g, 2))
graph_validator=lambda g: self._check_initializer_num(g, 1))

def test_duplicated_node_is_graph_output(self):
node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])
Expand Down
14 changes: 8 additions & 6 deletions tf2onnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,9 +1791,11 @@ def _parse_graph_input(g, graph_proto, const_node_names):
# because for subgraphs, the input orders matter.
for graph_input in graph_proto.input:
name = graph_input.name
shape = shapes[name]
dtype = dtypes[name]
if name not in const_node_names:
g.add_graph_input(name, dtype, shape)
else:
g.add_graph_input_with_default(name, g.get_node_by_name(name), dtype, shape)
const_initializer_node = g.get_node_by_output_in_current_graph(name)
if const_initializer_node is None: # is actual input rather than initializer
shape = shapes[name]
dtype = dtypes[name]
if name not in const_node_names:
g.add_graph_input(name, dtype, shape)
else:
g.add_graph_input_with_default(name, g.get_node_by_name(name), dtype, shape)