Commit 4d59f65

New TensorFlow TFDWConv() module (#7824)
* New TensorFlow `TFDWConv()` module, analogous to the `DWConv()` module: https://github.com/ultralytics/yolov5/blob/8aa2085a7e7ae20a17a7548edefbdb2960f2b29b/models/common.py#L53-L57
* Fix and new activations() function
* Update tf.py
1 parent 1e112ce commit 4d59f65

File tree

1 file changed: +41 -18 lines changed

models/tf.py

Lines changed: 41 additions & 18 deletions
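For context before the diff: the `DWConv()` module referenced in the commit message is YOLOv5's PyTorch depthwise convolution, i.e. an `nn.Conv2d` whose `groups` argument ties each output channel to a single input channel. A minimal, hypothetical sketch of that behaviour (the channel count and kernel size below are illustrative, not taken from the commit):

import torch
import torch.nn as nn

c, k = 8, 3  # hypothetical channel count and kernel size
dw = nn.Conv2d(c, c, k, stride=1, padding=k // 2, groups=c, bias=False)  # depthwise: groups == channels

x = torch.rand(1, c, 16, 16)   # NCHW, as PyTorch expects
print(dw(x).shape)             # torch.Size([1, 8, 16, 16]) -- each channel is filtered independently
print(dw.weight.shape)         # torch.Size([8, 1, 3, 3]) -- one k x k filter per channel

The new `TFDWConv()` class in the diff below mirrors this with `keras.layers.DepthwiseConv2D`, which operates on NHWC tensors instead.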
@@ -70,25 +70,38 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
         # see https://stackoverflow.com/questions/52975843/comparing-conv2d-with-padding-between-tensorflow-and-pytorch
 
         conv = keras.layers.Conv2D(
-            c2,
-            k,
-            s,
-            'SAME' if s == 1 else 'VALID',
-            use_bias=False if hasattr(w, 'bn') else True,
+            filters=c2,
+            kernel_size=k,
+            strides=s,
+            padding='SAME' if s == 1 else 'VALID',
+            use_bias=not hasattr(w, 'bn'),
             kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
             bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
         self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
         self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
+        self.act = activations(w.act) if act else tf.identity
+
+    def call(self, inputs):
+        return self.act(self.bn(self.conv(inputs)))
 
-        # YOLOv5 activations
-        if isinstance(w.act, nn.LeakyReLU):
-            self.act = (lambda x: keras.activations.relu(x, alpha=0.1)) if act else tf.identity
-        elif isinstance(w.act, nn.Hardswish):
-            self.act = (lambda x: x * tf.nn.relu6(x + 3) * 0.166666667) if act else tf.identity
-        elif isinstance(w.act, (nn.SiLU, SiLU)):
-            self.act = (lambda x: keras.activations.swish(x)) if act else tf.identity
-        else:
-            raise Exception(f'no matching TensorFlow activation found for {w.act}')
+
+class TFDWConv(keras.layers.Layer):
+    # Depthwise convolution
+    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
+        # ch_in, ch_out, weights, kernel, stride, padding, groups
+        super().__init__()
+        assert isinstance(k, int), "Convolution with multiple kernels are not allowed."
+
+        conv = keras.layers.DepthwiseConv2D(
+            kernel_size=k,
+            strides=s,
+            padding='SAME' if s == 1 else 'VALID',
+            use_bias=not hasattr(w, 'bn'),
+            kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()),
+            bias_initializer='zeros' if hasattr(w, 'bn') else keras.initializers.Constant(w.conv.bias.numpy()))
+        self.conv = conv if s == 1 else keras.Sequential([TFPad(autopad(k, p)), conv])
+        self.bn = TFBN(w.bn) if hasattr(w, 'bn') else tf.identity
+        self.act = activations(w.act) if act else tf.identity
 
     def call(self, inputs):
         return self.act(self.bn(self.conv(inputs)))
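The switch to keyword arguments in `TFConv` above does not change behaviour; it only makes the `Conv2D` call explicit, and `use_bias=not hasattr(w, 'bn')` is a simplification of the old `False if ... else True` expression. A small standalone parity sketch for the stride-1 'SAME' case (sizes are illustrative, and `pt` below is a hypothetical stand-in for `w.conv`):

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn
from tensorflow import keras

c1, c2, k = 3, 8, 3                                               # hypothetical ch_in, ch_out, kernel size
pt = nn.Conv2d(c1, c2, k, stride=1, padding=k // 2, bias=True)    # stand-in for w.conv

# PyTorch stores conv weights as (c_out, c_in, kH, kW); Keras wants (kH, kW, c_in, c_out),
# which is exactly the permute(2, 3, 1, 0) used in TFConv.
tf_conv = keras.layers.Conv2D(
    filters=c2,
    kernel_size=k,
    strides=1,
    padding='SAME',
    use_bias=True,
    kernel_initializer=keras.initializers.Constant(pt.weight.permute(2, 3, 1, 0).detach().numpy()),
    bias_initializer=keras.initializers.Constant(pt.bias.detach().numpy()))

x = np.random.rand(1, 32, 32, c1).astype(np.float32)              # NHWC for TensorFlow
y_tf = tf_conv(x).numpy()
y_pt = pt(torch.from_numpy(x).permute(0, 3, 1, 2)).permute(0, 2, 3, 1).detach().numpy()
print(np.abs(y_tf - y_pt).max())                                  # expected to be tiny (~1e-6) for s=1 'SAME'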
@@ -103,10 +116,8 @@ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
 
     def call(self, inputs):  # x(b,w,h,c) -> y(b,w/2,h/2,4c)
         # inputs = inputs / 255  # normalize 0-255 to 0-1
-        return self.conv(
-            tf.concat(
-                [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]],
-                3))
+        inputs = [inputs[:, ::2, ::2, :], inputs[:, 1::2, ::2, :], inputs[:, ::2, 1::2, :], inputs[:, 1::2, 1::2, :]]
+        return self.conv(tf.concat(inputs, 3))
 
 
 class TFBottleneck(keras.layers.Layer):
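The Focus refactor above only restructures the call for readability; the four strided slices still pack each 2x2 spatial block into channels before the convolution. A quick shape check with illustrative sizes:

import tensorflow as tf

x = tf.random.uniform((1, 8, 8, 3))                     # x(b, w, h, c); sizes are illustrative
patches = [x[:, ::2, ::2, :], x[:, 1::2, ::2, :], x[:, ::2, 1::2, :], x[:, 1::2, 1::2, :]]
y = tf.concat(patches, 3)
print(y.shape)                                          # (1, 4, 4, 12) -> y(b, w/2, h/2, 4c)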
@@ -439,6 +450,18 @@ def _nms(x, topk_all=100, iou_thres=0.45, conf_thres=0.25): # agnostic NMS
     return padded_boxes, padded_scores, padded_classes, valid_detections
 
 
+def activations(act=nn.SiLU):
+    # Returns TF activation from input PyTorch activation
+    if isinstance(act, nn.LeakyReLU):
+        return lambda x: keras.activations.relu(x, alpha=0.1)
+    elif isinstance(act, nn.Hardswish):
+        return lambda x: x * tf.nn.relu6(x + 3) * 0.166666667
+    elif isinstance(act, (nn.SiLU, SiLU)):
+        return lambda x: keras.activations.swish(x)
+    else:
+        raise Exception(f'no matching TensorFlow activation found for PyTorch activation {act}')
+
+
 def representative_dataset_gen(dataset, ncalib=100):
     # Representative dataset generator for use with converter.representative_dataset, returns a generator of np arrays
     for n, (path, img, im0s, vid_cap, string) in enumerate(dataset):
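The new `activations()` helper centralises the PyTorch-to-TensorFlow activation mapping that previously lived inside `TFConv.__init__`, so `TFConv` and the new `TFDWConv` can share it. A standalone sanity check of the Hardswish branch (it re-creates the returned lambda rather than importing models/tf.py; values are illustrative):

import numpy as np
import tensorflow as tf
import torch
import torch.nn as nn

tf_hardswish = lambda x: x * tf.nn.relu6(x + 3) * 0.166666667    # what activations(nn.Hardswish()) returns
x = np.linspace(-5, 5, 11, dtype=np.float32)
y_tf = tf_hardswish(tf.constant(x)).numpy()
y_pt = nn.Hardswish()(torch.from_numpy(x)).numpy()
print(np.abs(y_tf - y_pt).max())    # tiny residual from using 0.166666667 in place of exactly 1/6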

0 commit comments