Skip to content

Commit f759913

Browse files
glenn-jochertdhooghe
authored andcommitted
Add TFDWConv() depth_multiplier (ultralytics#7858)
Enabled grouped non c1 == c2 convolutions in TF YOLOv5 models.
1 parent dd44ef6 commit f759913

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

models/tf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,10 @@ class TFDWConv(keras.layers.Layer):
9191
def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None):
9292
# ch_in, ch_out, weights, kernel, stride, padding, groups
9393
super().__init__()
94-
assert c1 == c2, f'TFDWConv() input={c1} must equal output={c2} channels'
94+
assert c2 % c1 == 0, f'TFDWConv() output={c2} must be a multiple of input={c1} channels'
9595
conv = keras.layers.DepthwiseConv2D(
9696
kernel_size=k,
97+
depth_multiplier=c2 // c1,
9798
strides=s,
9899
padding='SAME' if s == 1 else 'VALID',
99100
use_bias=not hasattr(w, 'bn'),

0 commit comments

Comments
 (0)