diff --git a/models/tf.py b/models/tf.py index c638ff0084bc..6efc87fdd774 100644 --- a/models/tf.py +++ b/models/tf.py @@ -91,9 +91,10 @@ class TFDWConv(keras.layers.Layer): def __init__(self, c1, c2, k=1, s=1, p=None, act=True, w=None): # ch_in, ch_out, weights, kernel, stride, padding, groups super().__init__() - assert c1 == c2, f'TFDWConv() input={c1} must equal output={c2} channels' + assert c2 % c1 == 0, f'TFDWConv() output={c2} must be a multiple of input={c1} channels' conv = keras.layers.DepthwiseConv2D( kernel_size=k, + depth_multiplier=c2 // c1, strides=s, padding='SAME' if s == 1 else 'VALID', use_bias=not hasattr(w, 'bn'),