Skip to content
14 changes: 14 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6005,5 +6005,19 @@ def func(x):
x_val = make_xval([2, 3])
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_opset_min_version(11, "Pad")
def test_conv_unknown_kernel_channels(self):
x_shape = [2, 10, 3]
x_val = make_xval(x_shape)
kernel_shape = [4, 3, 5]
kernel_val = make_xval(kernel_shape)
pad_val = np.array([[0, 0], [0, 0], [0, 0]], np.int64)
def func(x, kernel, pad):
# Make kernel dimensions unknown
kernel = tf.pad(kernel, pad)
conv = tf.nn.conv1d(x, kernel, stride=[1], padding='VALID')
return tf.identity(conv, name='output')
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: kernel_val, _INPUT2: pad_val})

if __name__ == '__main__':
unittest_main()
14 changes: 9 additions & 5 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,9 @@ def conv_kernel_shape(ctx, node, input_idx, spatial=2):
# Get spatial part.
kernel_shape = kernel_shape[:spatial]

# Set new value and return it.
node.set_attr("kernel_shape", kernel_shape)
# Set attribute value only if all dimensions are known.
if all(d > 1 for d in kernel_shape):
node.set_attr("kernel_shape", kernel_shape)

return kernel_shape

Expand Down Expand Up @@ -379,11 +380,13 @@ def any_version(cls, opset, ctx, node, **kwargs):
data_format = str(node.attr["data_format"].s, encoding="utf8")
shape_dim = -1
if data_format == "NHWC":
shape_dim = ctx.get_shape(node.input[0])[3]
shape_dim = ctx.get_shape(node.input[0])[-1]
elif data_format == "NCHW":
shape_dim = ctx.get_shape(node.input[0])[1]
if shape_dim != -1:
groups = int(shape_dim / ctx.get_shape(node.input[1])[2])
filter_in_channels = ctx.get_shape(node.input[1])[-2]
if filter_in_channels != -1:
groups = shape_dim // filter_in_channels

node.set_attr("group", groups)

Expand Down Expand Up @@ -649,7 +652,8 @@ def version_1(cls, ctx, node, **kwargs):
raise ValueError("input channel must be positive")
k_output_channels = k_input_channels * k_channel_multiplier

node.set_attr("kernel_shape", [k_h, k_w])
if k_h >= 0 and k_w >= 0:
node.set_attr("kernel_shape", [k_h, k_w])
strides = conv_dims_attr(node, "strides")
dilations = conv_dims_attr(node, "dilations")
node.set_attr("group", k_input_channels)
Expand Down