Skip to content

Commit 476c57b

Browse files
committed
fix pylint
Signed-off-by: Salvetti, Francesco <[email protected]>
1 parent b30801c commit 476c57b

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2366,8 +2366,9 @@ def version_10(cls, ctx, node, **kwargs):
23662366
@tf_op("Unique", onnx_op="Unique")
23672367
@tf_op("UniqueWithCounts", onnx_op="Unique")
23682368
class Unique:
2369-
int_cast = [TensorProto.BOOL, TensorProto.INT32, TensorProto.INT16, TensorProto.UINT8, TensorProto.UINT16, TensorProto.UINT32, TensorProto.UINT64]
2370-
dtype_map = {k:TensorProto.INT64 for k in int_cast}
2369+
int_cast = [TensorProto.BOOL, TensorProto.INT32, TensorProto.INT16, TensorProto.UINT8,
2370+
TensorProto.UINT16, TensorProto.UINT32, TensorProto.UINT64]
2371+
dtype_map = {k: TensorProto.INT64 for k in int_cast}
23712372
dtype_map[TensorProto.DOUBLE] = TensorProto.FLOAT
23722373

23732374
@classmethod
@@ -2379,22 +2380,23 @@ def version_11(cls, ctx, node, **kwargs):
23792380
inp_dtype = ctx.get_dtype(node.input[0])
23802381

23812382
ctx.remove_node(node_name)
2382-
2383-
# due to ORT missing implementations we need to cast INT inputs to INT64 and FLOAT inputs to FLOAT32
2383+
2384+
# due to ORT missing implementations we need to cast INT inputs to INT64 and FLOAT inputs to FLOAT32
23842385
if inp_dtype in cls.dtype_map:
23852386
inp_cast = ctx.make_node("Cast", [node_inputs[0]], attr={'to': cls.dtype_map[inp_dtype]}).output[0]
23862387
node_inputs[0] = inp_cast
2387-
2388+
23882389
new_node = ctx.make_node("Unique", node_inputs, name=node_name, attr={'sorted': 0},
2389-
outputs=[utils.make_name("y"), utils.make_name("idx_first"), utils.make_name("idx"), utils.make_name("counts")])
2390+
outputs=[utils.make_name("y"), utils.make_name("idx_first"),
2391+
utils.make_name("idx"), utils.make_name("counts")])
23902392
ctx.replace_all_inputs(node_outputs[0], new_node.output[0])
23912393
ctx.replace_all_inputs(node_outputs[1], new_node.output[2])
2392-
if len(node_outputs)==3: # we need counts too (UniqueWithCounts)
2394+
if len(node_outputs) == 3: # we need counts too (UniqueWithCounts)
23932395
ctx.replace_all_inputs(node_outputs[2], new_node.output[3])
23942396
if ctx.get_dtype(new_node.output[0]) != inp_dtype:
2395-
ctx.insert_new_node_on_output("Cast", new_node.output[0], name=utils.make_name(node.name) + "_cast",
2396-
to=inp_dtype)
2397-
2397+
ctx.insert_new_node_on_output("Cast", new_node.output[0], to=inp_dtype,
2398+
name=utils.make_name(node.name) + "_cast")
2399+
23982400
# cast idx and counts if needed
23992401
out_dtype = node.get_attr_value('out_idx')
24002402
if out_dtype != TensorProto.INT64:

0 commit comments

Comments
 (0)