From e6b5d4206506498347f1bdb3d106dfa6dab85d2d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 19 May 2022 00:34:53 +0200 Subject: [PATCH 1/8] Add DWConvTranspose2d() module --- models/common.py | 6 ++++++ models/tf.py | 32 ++++++++++++++++++++++++++------ models/yolo.py | 2 +- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/models/common.py b/models/common.py index 0c028352abac..abee3591a9f8 100644 --- a/models/common.py +++ b/models/common.py @@ -56,6 +56,12 @@ def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act) +class DWConvTranspose2d(nn.ConvTranspose2d): + # Depth-wise transpose convolution class + def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0): # ch_in, ch_out, kernel, stride, padding, padding_out + super().__init__(c1, c2, k, s, p1, p2, groups=math.gcd(c1, c2)) + + class TransformerLayer(nn.Module): # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance) def __init__(self, c, num_heads): diff --git a/models/tf.py b/models/tf.py index 6efc87fdd774..246ebc35e050 100644 --- a/models/tf.py +++ b/models/tf.py @@ -108,6 +108,27 @@ def call(self, inputs): return self.act(self.bn(self.conv(inputs))) +class TFDWConvTranspose2d(keras.layers.Layer): + # Depthwise ConvTranspose2d + def __init__(self, c1, c2, k, s, p1, p2, w=None): + # ch_in, ch_out, weights, kernel, stride, padding, groups + super().__init__() + assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels' + assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1' + self.conv = tf.concat([keras.layers.ConvTranspose2d( + filters=1, + kernel_size=k, + strides=s, + padding='VALID', + output_padding=p2, + use_bias=True, + kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()[:, i:i + 1]), + bias_initializer=keras.initializers.Constant(w.conv.bias.numpy())) for i in range(c2)], 3) + + def call(self, inputs): + return self.conv(inputs)[:, 1:-1, 1:-1] + + class TFFocus(keras.layers.Layer): # Focus wh information into c-space def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True, w=None): @@ -153,14 +174,13 @@ def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None): super().__init__() assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" self.conv = keras.layers.Conv2D( - c2, - k, - s, - 'VALID', + filters=c2, + kernel_size=k, + strides=s, + padding='VALID', use_bias=bias, kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()), - bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None, - ) + bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None) def call(self, inputs): return self.conv(inputs) diff --git a/models/yolo.py b/models/yolo.py index 9695ed7ff186..3659bea6c17d 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -266,7 +266,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, - BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, C3x): + BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x): c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) From b4873bd74aea47999a8d7ee445f07217cb545075 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 May 2022 22:35:41 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/tf.py | 36 +++++++++++++++++++----------------- models/yolo.py | 2 +- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/models/tf.py b/models/tf.py index 246ebc35e050..2356a9716300 100644 --- a/models/tf.py +++ b/models/tf.py @@ -115,15 +115,17 @@ def __init__(self, c1, c2, k, s, p1, p2, w=None): super().__init__() assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels' assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1' - self.conv = tf.concat([keras.layers.ConvTranspose2d( - filters=1, - kernel_size=k, - strides=s, - padding='VALID', - output_padding=p2, - use_bias=True, - kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()[:, i:i + 1]), - bias_initializer=keras.initializers.Constant(w.conv.bias.numpy())) for i in range(c2)], 3) + self.conv = tf.concat([ + keras.layers.ConvTranspose2d(filters=1, + kernel_size=k, + strides=s, + padding='VALID', + output_padding=p2, + use_bias=True, + kernel_initializer=keras.initializers.Constant( + w.conv.weight.permute(2, 3, 1, 0).numpy()[:, i:i + 1]), + bias_initializer=keras.initializers.Constant(w.conv.bias.numpy())) + for i in range(c2)], 3) def call(self, inputs): return self.conv(inputs)[:, 1:-1, 1:-1] @@ -173,14 +175,14 @@ class TFConv2d(keras.layers.Layer): def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None): super().__init__() assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" - self.conv = keras.layers.Conv2D( - filters=c2, - kernel_size=k, - strides=s, - padding='VALID', - use_bias=bias, - kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()), - bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None) + self.conv = keras.layers.Conv2D(filters=c2, + kernel_size=k, + strides=s, + padding='VALID', + use_bias=bias, + kernel_initializer=keras.initializers.Constant( + w.weight.permute(2, 3, 1, 0).numpy()), + bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None) def call(self, inputs): return self.conv(inputs) diff --git a/models/yolo.py b/models/yolo.py index 3659bea6c17d..c7674a57c1d2 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -266,7 +266,7 @@ def parse_model(d, ch): # model_dict, input_channels(3) n = n_ = max(round(n * gd), 1) if n > 1 else n # depth gain if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, - BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x): + BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x): c1, c2 = ch[f], args[0] if c2 != no: # if not output c2 = make_divisible(c2 * gw, 8) From 4cd90a50c0ef306d4dd48913fd13c1793dfd187c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 19 May 2022 00:50:05 +0200 Subject: [PATCH 3/8] Add DWConvTranspose2d() module --- models/tf.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/models/tf.py b/models/tf.py index 246ebc35e050..f4b1b60c9d58 100644 --- a/models/tf.py +++ b/models/tf.py @@ -27,7 +27,8 @@ import torch.nn as nn from tensorflow import keras -from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, Focus, autopad +from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, \ + DWConvTranspose2d, Focus, autopad from models.experimental import MixConv2d, attempt_load from models.yolo import Detect from utils.activations import SiLU @@ -110,23 +111,23 @@ def call(self, inputs): class TFDWConvTranspose2d(keras.layers.Layer): # Depthwise ConvTranspose2d - def __init__(self, c1, c2, k, s, p1, p2, w=None): + def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None): # ch_in, ch_out, weights, kernel, stride, padding, groups super().__init__() assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels' assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1' - self.conv = tf.concat([keras.layers.ConvTranspose2d( + self.conv = [keras.layers.Conv2DTranspose( filters=1, kernel_size=k, strides=s, padding='VALID', output_padding=p2, use_bias=True, - kernel_initializer=keras.initializers.Constant(w.conv.weight.permute(2, 3, 1, 0).numpy()[:, i:i + 1]), - bias_initializer=keras.initializers.Constant(w.conv.bias.numpy())) for i in range(c2)], 3) + kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()[..., i:i + 1]), + bias_initializer=keras.initializers.Constant(w.bias.numpy()[i])) for i in range(c2)] def call(self, inputs): - return self.conv(inputs)[:, 1:-1, 1:-1] + return tf.concat(self.conv(inputs), 3)[:, 1:-1, 1:-1] class TFFocus(keras.layers.Layer): @@ -360,7 +361,8 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3) pass n = max(round(n * gd), 1) if n > 1 else n # depth gain - if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3, C3x]: + if m in [nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv, + BottleneckCSP, C3, C3x]: c1, c2 = ch[f], args[0] c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 From 34b2c0c3d26d7a7d9197e00f004117547c962db5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 May 2022 22:50:54 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/tf.py | 45 ++++++++++++++++++++++++--------------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/models/tf.py b/models/tf.py index f4b1b60c9d58..8722d9273456 100644 --- a/models/tf.py +++ b/models/tf.py @@ -27,8 +27,8 @@ import torch.nn as nn from tensorflow import keras -from models.common import C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, \ - DWConvTranspose2d, Focus, autopad +from models.common import (C3, SPP, SPPF, Bottleneck, BottleneckCSP, C3x, Concat, Conv, CrossConv, DWConv, + DWConvTranspose2d, Focus, autopad) from models.experimental import MixConv2d, attempt_load from models.yolo import Detect from utils.activations import SiLU @@ -116,15 +116,17 @@ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None): super().__init__() assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels' assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1' - self.conv = [keras.layers.Conv2DTranspose( - filters=1, - kernel_size=k, - strides=s, - padding='VALID', - output_padding=p2, - use_bias=True, - kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()[..., i:i + 1]), - bias_initializer=keras.initializers.Constant(w.bias.numpy()[i])) for i in range(c2)] + self.conv = [ + keras.layers.Conv2DTranspose(filters=1, + kernel_size=k, + strides=s, + padding='VALID', + output_padding=p2, + use_bias=True, + kernel_initializer=keras.initializers.Constant( + w.weight.permute(2, 3, 1, 0).numpy()[..., i:i + 1]), + bias_initializer=keras.initializers.Constant(w.bias.numpy()[i])) + for i in range(c2)] def call(self, inputs): return tf.concat(self.conv(inputs), 3)[:, 1:-1, 1:-1] @@ -174,14 +176,14 @@ class TFConv2d(keras.layers.Layer): def __init__(self, c1, c2, k, s=1, g=1, bias=True, w=None): super().__init__() assert g == 1, "TF v2.2 Conv2D does not support 'groups' argument" - self.conv = keras.layers.Conv2D( - filters=c2, - kernel_size=k, - strides=s, - padding='VALID', - use_bias=bias, - kernel_initializer=keras.initializers.Constant(w.weight.permute(2, 3, 1, 0).numpy()), - bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None) + self.conv = keras.layers.Conv2D(filters=c2, + kernel_size=k, + strides=s, + padding='VALID', + use_bias=bias, + kernel_initializer=keras.initializers.Constant( + w.weight.permute(2, 3, 1, 0).numpy()), + bias_initializer=keras.initializers.Constant(w.bias.numpy()) if bias else None) def call(self, inputs): return self.conv(inputs) @@ -361,8 +363,9 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3) pass n = max(round(n * gd), 1) if n > 1 else n # depth gain - if m in [nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv, - BottleneckCSP, C3, C3x]: + if m in [ + nn.Conv2d, Conv, DWConv, DWConvTranspose2d, Bottleneck, SPP, SPPF, MixConv2d, Focus, CrossConv, + BottleneckCSP, C3, C3x]: c1, c2 = ch[f], args[0] c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 From a3be910abb3dd3a78e0e89df56b02db2c91d822d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 20 May 2022 11:57:38 +0200 Subject: [PATCH 5/8] Fix --- models/tf.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/models/tf.py b/models/tf.py index 8722d9273456..11d7bbb3714f 100644 --- a/models/tf.py +++ b/models/tf.py @@ -116,20 +116,19 @@ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None): super().__init__() assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels' assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1' - self.conv = [ - keras.layers.Conv2DTranspose(filters=1, - kernel_size=k, - strides=s, - padding='VALID', - output_padding=p2, - use_bias=True, - kernel_initializer=keras.initializers.Constant( - w.weight.permute(2, 3, 1, 0).numpy()[..., i:i + 1]), - bias_initializer=keras.initializers.Constant(w.bias.numpy()[i])) - for i in range(c2)] + weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy() + self.conv = [keras.layers.Conv2DTranspose(filters=1, + kernel_size=k, + strides=s, + padding='VALID', + output_padding=p2, + use_bias=True, + kernel_initializer=keras.initializers.Constant(weight), + bias_initializer=keras.initializers.Constant(bias[i])) + for i in range(c2)] def call(self, inputs): - return tf.concat(self.conv(inputs), 3)[:, 1:-1, 1:-1] + return tf.concat([x(inputs) for x in self.conv], 3)[:, 1:-1, 1:-1] class TFFocus(keras.layers.Layer): From 7dead7ebb7d53101d841cc8a61f9009e65552569 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 May 2022 09:58:12 +0000 Subject: [PATCH 6/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- models/tf.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/models/tf.py b/models/tf.py index 11d7bbb3714f..4dcc5f17ddc9 100644 --- a/models/tf.py +++ b/models/tf.py @@ -117,15 +117,15 @@ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None): assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels' assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1' weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy() - self.conv = [keras.layers.Conv2DTranspose(filters=1, - kernel_size=k, - strides=s, - padding='VALID', - output_padding=p2, - use_bias=True, - kernel_initializer=keras.initializers.Constant(weight), - bias_initializer=keras.initializers.Constant(bias[i])) - for i in range(c2)] + self.conv = [ + keras.layers.Conv2DTranspose(filters=1, + kernel_size=k, + strides=s, + padding='VALID', + output_padding=p2, + use_bias=True, + kernel_initializer=keras.initializers.Constant(weight), + bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c2)] def call(self, inputs): return tf.concat([x(inputs) for x in self.conv], 3)[:, 1:-1, 1:-1] From 1a3af0bfa251dabcd3a9fd55c7aad4d64f2300b4 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 20 May 2022 16:08:22 +0200 Subject: [PATCH 7/8] Fix --- models/tf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/models/tf.py b/models/tf.py index 4dcc5f17ddc9..a50b47fa1a78 100644 --- a/models/tf.py +++ b/models/tf.py @@ -117,6 +117,7 @@ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None): assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels' assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1' weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy() + self.ch = c1 self.conv = [ keras.layers.Conv2DTranspose(filters=1, kernel_size=k, @@ -124,11 +125,11 @@ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None): padding='VALID', output_padding=p2, use_bias=True, - kernel_initializer=keras.initializers.Constant(weight), - bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c2)] + kernel_initializer=keras.initializers.Constant(weight[..., i:i + 1]), + bias_initializer=keras.initializers.Constant(bias[i])) for i in range(c1)] def call(self, inputs): - return tf.concat([x(inputs) for x in self.conv], 3)[:, 1:-1, 1:-1] + return tf.concat([m(x) for m, x in zip(self.conv, tf.split(inputs, self.c1, 3))], 3)[:, 1:-1, 1:-1] class TFFocus(keras.layers.Layer): From b7cf76a3828e2ea908688ecf27c48b7fa3b26c40 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 20 May 2022 16:10:37 +0200 Subject: [PATCH 8/8] Fix --- models/tf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/tf.py b/models/tf.py index a50b47fa1a78..202a957e3e63 100644 --- a/models/tf.py +++ b/models/tf.py @@ -117,7 +117,7 @@ def __init__(self, c1, c2, k=1, s=1, p1=0, p2=0, w=None): assert c1 == c2, f'TFDWConv() output={c2} must be equal to input={c1} channels' assert k == 4 and p1 == 1, 'TFDWConv() only valid for k=4 and p1=1' weight, bias = w.weight.permute(2, 3, 1, 0).numpy(), w.bias.numpy() - self.ch = c1 + self.c1 = c1 self.conv = [ keras.layers.Conv2DTranspose(filters=1, kernel_size=k,