Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tf2onnx/optimizer/transpose_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _calculate_new_shape(graph, op):
new_shape = [input_shape[p] for p in perm]
return graph.make_const(utils.make_name("new_shape"), np.array(new_shape, dtype=np.int64)).output[0]

# reshape requires tha output shape can only contain one -1, if not some extra op needed.
# reshape requires the output shape can only contain one -1, if not some extra op needed.
input_shape = graph.make_node("Shape", [op.input[0]]).output[0]
indice = graph.make_const(utils.make_name("indice"), np.array(perm, np.int64)).output[0]

Expand Down Expand Up @@ -668,9 +668,12 @@ def _concat_handler(self, trans, node):
return False

def _split_handler(self, trans, node):
# Todo: need handle cases where Slit node has more than 1 outputs.
# Todo: need handle cases where Split node has more than 1 outputs.
if self._handle_node_having_branches(trans, node):
node.set_attr("axis", 1)
perm = trans.get_attr_value("perm")
axis = node.get_attr_value("axis", 0)
new_axis = perm[axis]
node.set_attr("axis", new_axis)
return True
return False

Expand Down