-
Notifications
You must be signed in to change notification settings - Fork 451
Fix unsupported ops TF 2.14.0: OnesLike #2270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4105,6 +4105,54 @@ def func(x, y): | |
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True, | ||
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False)) | ||
|
||
@check_tf_max_version("2.13.1") | ||
@check_opset_min_version(9, "ConstantOfShape") | ||
@check_opset_max_version(9, "ConstantOfShape") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since we already have 2 functions for different versions, we don't need to add an opset limitation for these test methods. And we don't need to create different test cases for different versions. You just need to create one test case for tf op OnesLike, and the test framework will help on testing different ops. At least, please remove the check_opset_max_version so we don't limit it running on higher opset version which supports ConstantOfShape. |
||
def test_ones_like_opset9(self): | ||
input_x = np.random.random_sample([3, 16, 16]).astype(np.float32) | ||
input_y = np.array([16, 16, 3]).astype(np.int64) | ||
|
||
def func(x, y): | ||
z = tf.reshape(x, y) | ||
return tf.ones_like(z, name=_TFOUTPUT) | ||
|
||
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y}, as_session=True, | ||
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False)) | ||
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True, | ||
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False)) | ||
|
||
@check_tf_max_version("2.13.1") | ||
@check_opset_min_version(11, "Expand") | ||
def test_ones_like_opset11(self): | ||
input_x = np.random.random_sample([3, 16, 16]).astype(np.float32) | ||
input_y = np.array([16, 16, 3]).astype(np.int64) | ||
|
||
def func(x, y): | ||
z = tf.reshape(x, y) | ||
return tf.ones_like(z, name=_TFOUTPUT) | ||
|
||
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y}, as_session=True, | ||
graph_validator=lambda g: check_op_count(g, "Expand", 1, disabled=False)) | ||
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True, | ||
graph_validator=lambda g: check_op_count(g, "Expand", 1, disabled=False)) | ||
|
||
@check_tf_min_version("2.14.0") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to require 2.14.0 version to run this test case? |
||
@check_opset_min_version(9, "ConstantOfShape") | ||
@skip_tflite("tflite does convert OnesLike correctly") | ||
@skip_tfjs("tfjs does convert OnesLike correctly") | ||
def test_ones_like_tf2_14(self): | ||
input_x = np.random.random_sample([3, 16, 16]).astype(np.float32) | ||
input_y = np.array([16, 16, 3]).astype(np.int64) | ||
|
||
def func(x, y): | ||
z = tf.reshape(x, y) | ||
return tf.ones_like(z, name=_TFOUTPUT) | ||
|
||
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x, _INPUT1: input_y}, as_session=True, | ||
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False)) | ||
self._run_test_case(func, [_OUTPUT], {_INPUT: input_x.astype(np.int32), _INPUT1: input_y}, as_session=True, | ||
graph_validator=lambda g: check_op_count(g, "ConstantOfShape", 1, disabled=False)) | ||
|
||
@check_opset_min_version(9, "is_nan") | ||
def test_isnan(self): | ||
# only compatible with dtype `float32` | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This means this test case won't be run on any versions above 2.13.1. Could you please share why we can't do this? Or we can remove this limitation?