diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py index 68bf767557d5..e08359066c2a 100644 --- a/python/tvm/relay/frontend/caffe.py +++ b/python/tvm/relay/frontend/caffe.py @@ -515,21 +515,76 @@ def convert_deconv(self, op): if weight: kh, kw = params["kernel_size"] weight_shape = [-1, conv_params.num_output, kh, kw] - weight_value = np.asarray(weight.data, np.float32) + if not weight.data: + if conv_params.weight_filler: + _filler = conv_params.weight_filler.value + weight_value = np.full(weight.shape.dim, _filler, np.float32) + else: + raise tvm.error.OpAttributeInvalid("At least weight_filler must be given") + else: + weight_value = np.asarray(weight.data, np.float32) weight_value = np.reshape(weight_value, weight_shape) # weight shape is in relay's IOHW format rn, we need it to be OIHW weight_value = np.transpose(weight_value, [1, 0, 2, 3]) else: - raise Exception("No weight value of layer {} in caffemodel".format(op.name)) + raise tvm.error.OpAttributeRequired( + "No weight value of layer {} in caffemodel".format(op.name) + ) weight_expr = self.exp_tab.new_const(weight_value, dtype="float32") in_expr = self.exp_tab.get_expr(inputs[0]) - out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params) + + groups = params["groups"] + channels = params["channels"] + if bias: bias_value = np.asarray(bias.data, np.float32) bias_expr = self.exp_tab.new_const(bias_value, dtype="float32") - out = _op.nn.bias_add(out, bias_expr) + + if groups > channels: + raise tvm.error.OpAttributeInvalid( + "Groups cannot be larger than the number of input channels" + ) + + if groups == channels: + inputs_expr = _op.split(in_expr, groups, axis=1) + # changing split axis to 0, according to PR #9336 + weights_expr = _op.split(weight_expr, groups, axis=0) + # Preventing to create Concat layer with too many tensors(> 16) + q = groups >> 4 + r = groups % 16 + + params["groups"] = 1 + params["channels"] = 1 + out = [] + for lc in range(q): + _outputs = [] + _inputs = [inputs_expr[i] for i in range(lc << 4, (lc << 4) + 16)] + _weights = [weights_expr[i] for i in range(lc << 4, (lc << 4) + 16)] + for (i, w) in zip(_inputs, _weights): + _out = _op.nn.conv2d_transpose(data=i, weight=w, **params) + if bias: + _out = _op.nn.bias_add(_out, bias_expr) + _outputs.append(_out) + out.append(_op.concatenate(_outputs, axis=1)) + if r != 0: + _outputs = [] + _inputs = [inputs_expr[i] for i in range(groups - r, groups)] + _weights = [weights_expr[i] for i in range(groups - r, groups)] + for (i, w) in zip(_inputs, _weights): + _out = _op.nn.conv2d_transpose(data=i, weight=w, **params) + if bias: + _out = _op.nn.bias_add(_out, bias_expr) + _outputs.append(_out) + out.append(_op.concatenate(_outputs, axis=1)) + out = _op.concatenate(out, axis=1) + elif groups == 1: + out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params) + if bias: + out = _op.nn.bias_add(out, bias_expr) + else: + raise tvm.error.OpAttributeInvalid("Unable to handle.") return out def convert_slice(self, op): diff --git a/tests/python/frontend/caffe/test_forward.py b/tests/python/frontend/caffe/test_forward.py index 0027a6b41736..59186e0fd273 100644 --- a/tests/python/frontend/caffe/test_forward.py +++ b/tests/python/frontend/caffe/test_forward.py @@ -35,6 +35,7 @@ from caffe.proto import caffe_pb2 as pb import tvm +import tvm.testing from tvm import relay from tvm.contrib import utils, graph_executor from tvm.contrib.download import download_testdata @@ -451,6 +452,35 @@ def test_forward_Deconvolution(): bias_filler=dict(type="xavier"), ), ) + _test_deconvolution( + data, + convolution_param=dict( + num_output=16, + bias_term=False, + pad=0, + kernel_size=2, + stride=2, + dilation=1, + group=16, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ), + ) + data = np.random.rand(1, 100, 32, 32).astype(np.float32) + _test_deconvolution( + data, + convolution_param=dict( + num_output=100, + bias_term=False, + pad=0, + kernel_size=2, + stride=2, + dilation=1, + group=100, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + ), + ) #######################################################################