@@ -2366,8 +2366,9 @@ def version_10(cls, ctx, node, **kwargs):
2366
2366
@tf_op ("Unique" , onnx_op = "Unique" )
2367
2367
@tf_op ("UniqueWithCounts" , onnx_op = "Unique" )
2368
2368
class Unique :
2369
- int_cast = [TensorProto .BOOL , TensorProto .INT32 , TensorProto .INT16 , TensorProto .UINT8 , TensorProto .UINT16 , TensorProto .UINT32 , TensorProto .UINT64 ]
2370
- dtype_map = {k :TensorProto .INT64 for k in int_cast }
2369
+ int_cast = [TensorProto .BOOL , TensorProto .INT32 , TensorProto .INT16 , TensorProto .UINT8 ,
2370
+ TensorProto .UINT16 , TensorProto .UINT32 , TensorProto .UINT64 ]
2371
+ dtype_map = {k : TensorProto .INT64 for k in int_cast }
2371
2372
dtype_map [TensorProto .DOUBLE ] = TensorProto .FLOAT
2372
2373
2373
2374
@classmethod
@@ -2379,22 +2380,23 @@ def version_11(cls, ctx, node, **kwargs):
2379
2380
inp_dtype = ctx .get_dtype (node .input [0 ])
2380
2381
2381
2382
ctx .remove_node (node_name )
2382
-
2383
- # due to ORT missing implementations we need to cast INT inputs to INT64 and FLOAT inputs to FLOAT32
2383
+
2384
+ # due to ORT missing implementations we need to cast INT inputs to INT64 and FLOAT inputs to FLOAT32
2384
2385
if inp_dtype in cls .dtype_map :
2385
2386
inp_cast = ctx .make_node ("Cast" , [node_inputs [0 ]], attr = {'to' : cls .dtype_map [inp_dtype ]}).output [0 ]
2386
2387
node_inputs [0 ] = inp_cast
2387
-
2388
+
2388
2389
new_node = ctx .make_node ("Unique" , node_inputs , name = node_name , attr = {'sorted' : 0 },
2389
- outputs = [utils .make_name ("y" ), utils .make_name ("idx_first" ), utils .make_name ("idx" ), utils .make_name ("counts" )])
2390
+ outputs = [utils .make_name ("y" ), utils .make_name ("idx_first" ),
2391
+ utils .make_name ("idx" ), utils .make_name ("counts" )])
2390
2392
ctx .replace_all_inputs (node_outputs [0 ], new_node .output [0 ])
2391
2393
ctx .replace_all_inputs (node_outputs [1 ], new_node .output [2 ])
2392
- if len (node_outputs )== 3 : # we need counts too (UniqueWithCounts)
2394
+ if len (node_outputs ) == 3 : # we need counts too (UniqueWithCounts)
2393
2395
ctx .replace_all_inputs (node_outputs [2 ], new_node .output [3 ])
2394
2396
if ctx .get_dtype (new_node .output [0 ]) != inp_dtype :
2395
- ctx .insert_new_node_on_output ("Cast" , new_node .output [0 ], name = utils . make_name ( node . name ) + "_cast" ,
2396
- to = inp_dtype )
2397
-
2397
+ ctx .insert_new_node_on_output ("Cast" , new_node .output [0 ], to = inp_dtype ,
2398
+ name = utils . make_name ( node . name ) + "_cast" )
2399
+
2398
2400
# cast idx and counts if needed
2399
2401
out_dtype = node .get_attr_value ('out_idx' )
2400
2402
if out_dtype != TensorProto .INT64 :
0 commit comments