From 03616069ba3d55b72dd5c5948ff6724a58bcd7a4 Mon Sep 17 00:00:00 2001 From: prototype <295484914@qq.com> Date: Sat, 11 May 2019 15:30:55 +0800 Subject: [PATCH 1/4] fix onnx frontend flatten bug --- python/tvm/relay/frontend/onnx.py | 19 ++++++++++++++++- tests/python/frontend/onnx/test_forward.py | 24 ++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index b4d36306c85d..73df00c8c746 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -335,6 +335,23 @@ class Reciprocal(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): return _expr.const(1.0) / inputs[0] + +class Flatten(OnnxOpConverter): + """ Operator converter for Flatten. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + axis = attr.get('axis', 1) + + if axis == 1: + return _op.nn.batch_flatten(inputs[0]) + else: + newshape = [0]*(axis+1) + newshape[axis] = -1; + return _op.reshape(inputs[0], list(newshape)) + + class Reshape(OnnxOpConverter): """ Operator converter for Reshape. """ @@ -850,7 +867,7 @@ def _get_convert_map(opset): # 'InstanceNormalization' # 'LpNormalization' 'Dropout': AttrCvt('dropout', {'ratio': 'rate'}, ignores=['is_test']), - 'Flatten': Renamer('batch_flatten'), + 'Flatten': Flatten.get_converter(opset), 'LRN': LRN.get_converter(opset), # defs/reduction diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 7be6bb611e9a..f867e73e8c08 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -211,6 +211,29 @@ def test_squeeze(): tvm.testing.assert_allclose(out_shape, tvm_out.shape) +def test_flatten(): + + in_shape = (1, 3, 4, 4) + axis = 1 + ref_shape = (1, 48) + + flatten_node = helper.make_node("Flatten", ["in"], ["out"], axis = axis) + + graph = helper.make_graph([flatten_node], + "flatten_test", + inputs = [helper.make_tensor_value_info("in", + TensorProto.FLOAT, list(in_shape))], + outputs = [helper.make_tensor_value_info("out", + TensorProto.FLOAT, list(ref_shape))]) + + model = helper.make_model(graph, producer_name='flatten_test') + + for target, ctx in ctx_list(): + x = np.random.uniform(size=in_shape).astype('int32') + tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32') + + tvm.testing.assert_allclose(ref_shape, tvm_out.shape) + def test_unsqueeze(): in_shape = (3, 3) axis = (0, 3, 4) @@ -1046,6 +1069,7 @@ def test_LogSoftmax(): {'axis': 1}) if __name__ == '__main__': + test_flatten() test_reshape() test_shape() test_power() From 2dc867b1efdbae5167b254e57d290924f38a6721 Mon Sep 17 00:00:00 2001 From: Oldpan <295484914@qq.com> Date: Sat, 11 May 2019 17:35:28 +0800 Subject: [PATCH 2/4] Update onnx.py --- python/tvm/relay/frontend/onnx.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 73df00c8c746..5c730e25a497 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -343,13 +343,14 @@ class Flatten(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get('axis', 1) - if axis == 1: - return _op.nn.batch_flatten(inputs[0]) + out = _op.nn.batch_flatten(inputs[0]) else: newshape = [0]*(axis+1) - newshape[axis] = -1; - return _op.reshape(inputs[0], list(newshape)) + newshape[axis] = -1 + out = _op.reshape(inputs[0], list(newshape)) + + return out class Reshape(OnnxOpConverter): From 910e3a2121236d0c0f1b183374679e2f1d6c1839 Mon Sep 17 00:00:00 2001 From: Oldpan <295484914@qq.com> Date: Sat, 11 May 2019 17:37:17 +0800 Subject: [PATCH 3/4] Update onnx.py --- python/tvm/relay/frontend/onnx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 5c730e25a497..813946a28db4 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -349,7 +349,6 @@ def _impl_v1(cls, inputs, attr, params): newshape = [0]*(axis+1) newshape[axis] = -1 out = _op.reshape(inputs[0], list(newshape)) - return out From 0123e265c867c08382cd068bcd1b1500c25a4a96 Mon Sep 17 00:00:00 2001 From: Oldpan <295484914@qq.com> Date: Mon, 13 May 2019 12:29:22 +0800 Subject: [PATCH 4/4] Update onnx.py --- python/tvm/relay/frontend/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 813946a28db4..eba02e70c865 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -346,7 +346,7 @@ def _impl_v1(cls, inputs, attr, params): if axis == 1: out = _op.nn.batch_flatten(inputs[0]) else: - newshape = [0]*(axis+1) + newshape = [0] * (axis + 1) newshape[axis] = -1 out = _op.reshape(inputs[0], list(newshape)) return out