Skip to content

Commit f4be40d

Browse files
glenn-jocher and pre-commit-ci[bot]
authored and committed
Add DWConvTranspose2d() module (ultralytics#7881)
* Add DWConvTranspose2d() module * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add DWConvTranspose2d() module * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix * Fix Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3d28a0f commit f4be40d

File tree

3 files changed

+43
-12
lines changed

3 files changed

+43
-12
lines changed

models/common.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,12 @@ def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride
5656
super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
5757

5858

59+
class DWConvTranspose2d(nn.ConvTranspose2d):
    # Depth-wise transpose convolution: nn.ConvTranspose2d with groups fixed to gcd(c1, c2)
    def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0):  # ch_in, ch_out, kernel, stride, padding, padding_out
        # The greatest common divisor is the largest group count that divides both channel counts
        n_groups = math.gcd(c1, c2)
        super().__init__(in_channels=c1,
                         out_channels=c2,
                         kernel_size=k,
                         stride=s,
                         padding=p1,
                         output_padding=p2,
                         groups=n_groups)
63+
64+
5965
class TransformerLayer(nn.Module):
6066
# Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
6167
def __init__(self, c, num_heads):

models/tf.py

Lines changed: 36 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,8 @@
2727
import torch.nn as nn
2828
from tensorflow import keras
2929

30-
from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, Focus, autopad
30+
from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv,
31+
DWConvTranspose2d, Focus, autopad)
3132
from models.experimental import MixConv2d, attempt_load
3233
from models.yolo import Detect
3334
from utils.activations import SiLU
@@ -108,6 +109,29 @@ def call(self, inputs):
108109
return self.act(self.bn(self.conv(inputs)))
109110

110111

112+
class TFDWConvTranspose2d(keras.layers.Layer):
    # Depthwise ConvTranspose2d, assembled from one single-filter Conv2DTranspose layer per input channel
    def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None):
        # ch_in, ch_out, kernel, stride, padding, padding_out, weights
        super().__init__()
        # Messages name this class so a failed check points at the right layer
        assert c1 == c2, f'TFDWConvTranspose2d() output={c2} must be equal to input={c1} channels'
        assert k == 4 and p1 == 1, 'TFDWConvTranspose2d() only valid for k=4 and p1=1'
        # Reorder weight axes (0, 1, 2, 3) -> (2, 3, 1, 0) before handing them to Keras as constants
        weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy()
        self.c1 = c1
        # One layer per channel; layer i is initialized from slice i of the last weight axis and bias[i]
        self.conv = [
            keras.layers.Conv2DTranspose(filters=1,
                                         kernel_size=k,
                                         strides=s,
                                         padding='VALID',
                                         output_padding=p2,
                                         use_bias=True,
                                         kernel_initializer=keras.initializers.Constant(weight[..., i:i + 1]),
                                         bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c1)]

    def call(self, inputs):
        # Split axis 3 into c1 parts, run each part through its own layer and concatenate back on axis 3.
        # The final slice drops one element from each end of axes 1 and 2 -- presumably to emulate
        # padding p1=1, since the layers themselves use 'VALID' padding (confirm against the PyTorch output).
        return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1]
133+
134+
111135
class TFFocus(keras.layers.Layer):
112136
# Focus wh information into c-space
113137
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None):
@@ -152,15 +176,14 @@ class TFConv2d(keras.layers.Layer):
152176
def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None):
153177
super().__init__()
154178
assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument"
155-
self.conv = keras.layers.Conv2D(
156-
c2,
157-
k,
158-
s,
159-
'VALID',
160-
use_bias=bias,
161-
kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()),
162-
bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None,
163-
)
179+
self.conv = keras.layers.Conv2D(filters=c2,
180+
kernel_size=k,
181+
strides=s,
182+
padding='VALID',
183+
use_bias=bias,
184+
kernel_initializer=keras.initializers.Constant(
185+
w.weight.permute(2, 3, 1, 0).numpy()),
186+
bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None)
164187

165188
def call(self, inputs):
166189
return self.conv(inputs)
@@ -340,7 +363,9 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
340363
pass
341364

342365
n = max(round(n * gd), 1) if n > 1 else n # depth gain
343-
if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3x]:
366+
if m in [
367+
nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv,
368+
BottleneckCSP, C3, C3x]:
344369
c1, c2 = ch[f], args[0]
345370
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
346371

models/yolo.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -266,7 +266,7 @@ def parse_model(d, ch): # model_dict, input_channels(3)
266266

267267
n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain
268268
if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
269-
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, C3x):
269+
BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x):
270270
c1, c2 = ch[f], args[0]
271271
if c2 != no: # if not output
272272
c2 = make_divisible(c2 * gw, 8)

0 commit comments

Comments
 (0)