Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4105,6 +4105,54 @@ def func(x, y):
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True,
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False))

@check_tf_max_version("2.13.1")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means this test case won't be run on any versions above 2.13.1. Could you please share why we can't do this? Or we can remove this limitation?

@check_opset_min_version(9, "ConstantOfShape")
@check_opset_max_version(9, "ConstantOfShape")
Copy link
Collaborator

@fatcat-z fatcat-z Nov 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we already have 2 functions for different versions, we don't need to add an opset limitation for these test methods. And we don't need to create different test cases for different versions.

You just need to create one test case for tf op OnesLike, and the test framework will help on testing different ops. At least, please remove the check_opset_max_version so we don't limit it running on higher opset version which supports ConstantOfShape.

def test_ones_like_opset9(self):
input_x = np.random.random_sample([3, 16, 16]).astype(np.float32)
input_y = np.array([16, 16, 3]).astype(np.int64)

def func(x, y):
z = tf.reshape(x, y)
return tf.ones_like(z, name=_TFOUTPUT)

self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y}, as_session=True,
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False))
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True,
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False))

@check_tf_max_version("2.13.1")
@check_opset_min_version(11, "Expand")
def test_ones_like_opset11(self):
input_x = np.random.random_sample([3, 16, 16]).astype(np.float32)
input_y = np.array([16, 16, 3]).astype(np.int64)

def func(x, y):
z = tf.reshape(x, y)
return tf.ones_like(z, name=_TFOUTPUT)

self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y}, as_session=True,
graph_validator=lambda g: check_op_count(g, "Expand", 1, disabled=False))
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True,
graph_validator=lambda g: check_op_count(g, "Expand", 1, disabled=False))

@check_tf_min_version("2.14.0")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to require 2.14.0 version to run this test case?

@check_opset_min_version(9, "ConstantOfShape")
@skip_tflite("tflite does convert OnesLike correctly")
@skip_tfjs("tfjs does convert OnesLike correctly")
def test_ones_like_tf2_14(self):
input_x = np.random.random_sample([3, 16, 16]).astype(np.float32)
input_y = np.array([16, 16, 3]).astype(np.int64)

def func(x, y):
z = tf.reshape(x, y)
return tf.ones_like(z, name=_TFOUTPUT)

self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y}, as_session=True,
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False))
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True,
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False))

@check_opset_min_version(9, "is_nan")
def test_isnan(self):
# only compatible with dtype `float32`
Expand Down
55 changes: 37 additions & 18 deletions tf2onnx/onnx_opset/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,31 +227,50 @@ def version_7(cls, ctx, node, **kwargs):
ctx.remove_input(node, node.input[1], 1)


def _const_like_version_1(ctx, node, value):
shapes = node.output_shapes
dtypes = node.output_dtypes
ctx.remove_node(node.name)
casted_input = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
const_value = ctx.make_const(utils.make_name("value"), np.array(value).astype(np.int64))
mul_node = ctx.make_node('Mul', inputs=[casted_input.output[0], const_value.output[0]])
ctx.make_node("Cast", inputs=[mul_node.output[0]],
attr={'to': dtypes[0]},
name=node.name, outputs=node.output,
shapes=shapes, dtypes=dtypes)


def _const_like_version_9(ctx, node, value):
dtypes = node.output_dtypes
ctx.remove_node(node.name)
shape = ctx.make_node("Shape", node.input).output[0]
value_tensor = helper.make_tensor("value", dtypes[0], [1], vals=[value])
ctx.make_node("ConstantOfShape", inputs=[shape],
attr={'value': value_tensor},
name=node.name, outputs=node.output,
dtypes=dtypes)


@tf_op("ZerosLike")
class ZerosLike:
@classmethod
def version_1(cls, ctx, node, **kwargs):
shapes = node.output_shapes
dtypes = node.output_dtypes
ctx.remove_node(node.name)
casted_input = ctx.make_node("Cast", node.input, attr={'to': onnx_pb.TensorProto.INT64})
const_zero = ctx.make_const(utils.make_name("zero"), np.array(0).astype(np.int64))
mul_node = ctx.make_node('Mul', inputs=[casted_input.output[0], const_zero.output[0]])
ctx.make_node("Cast", inputs=[mul_node.output[0]],
attr={'to': dtypes[0]},
name=node.name, outputs=node.output,
shapes=shapes, dtypes=dtypes)
_const_like_version_1(ctx, node, 0)

@classmethod
def version_9(cls, ctx, node, **kwargs):
dtypes = node.output_dtypes
ctx.remove_node(node.name)
shape = ctx.make_node("Shape", node.input).output[0]
zero_tensor = helper.make_tensor("value", dtypes[0], [1], vals=[0])
ctx.make_node("ConstantOfShape", inputs=[shape],
attr={'value': zero_tensor},
name=node.name, outputs=node.output,
dtypes=dtypes)
_const_like_version_9(ctx, node, 0)


@tf_op("OnesLike")
class OnesLike:
@classmethod
def version_1(cls, ctx, node, **kwargs):
_const_like_version_1(ctx, node, 1)

@classmethod
def version_9(cls, ctx, node, **kwargs):
_const_like_version_9(ctx, node, 1)


@tf_op(["IteratorV2", "FIFOQueueV2"])
Expand Down