diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 1946223a50a4..6dd164c6e35e 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -3851,11 +3851,11 @@ def _convert_operator( @staticmethod def _set_span(sym, node_name): span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) - if isinstance(sym, _expr.Call): + if isinstance(sym, _expr.Call) and sym.span is None: sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) elif isinstance(sym, _expr.TupleWrapper): tuple_value = sym.tuple_value - if isinstance(tuple_value, _expr.Call): + if isinstance(tuple_value, _expr.Call) and tuple_value.span is None: tuple_value = _expr.Call( tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span ) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 26e9476d15c7..8446ef3d590b 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -48,6 +48,7 @@ from tvm import relay import tvm.relay.testing.tf as tf_testing from tvm.runtime.vm import VirtualMachine +from tvm.relay.frontend.tensorflow import from_tensorflow from packaging import version as package_version import tvm.testing @@ -5451,5 +5452,37 @@ def test_forward_unique_with_counts(): _test_unique_with_counts(20, dtype, is_dyn) +####################################################################### +# check graph ir for nn.moments +# ------------ + + +def test_moments(): + g = tf.Graph() + shape = [4, 176, 8, 8] + dtype = "float32" + with g.as_default(): + A = tf.placeholder(shape=shape, dtype=dtype, name="A") + B = tf.placeholder(shape=shape, dtype=dtype, name="B") + mean, variance = tf.nn.moments(A, [1], keep_dims=True) + normalised_input = (A - mean) / tf.sqrt(variance + 0.0005) + + mod, _ = from_tensorflow(g.as_graph_def(add_shapes=True)) + program = """ + def @main(%A: Tensor[(4, 176, 8, 8), float32]) { + %527 = mean(%A, axis=[1], keepdims=True) /* moments/mean */; + %528 = subtract(%A, %527) /* sub */; + %529 = subtract(%A, %527); + %530 = multiply(%529, %529) /* moments/SquaredDifference */; + %531 = mean(%530, axis=[1], keepdims=True) /* moments/variance */; + %532 = add(%531, 0.0005f) /* add */; + %533 = sqrt(%532) /* Sqrt */; + divide(%528, %533) /* truediv */ + } + """ + mod_golden = tvm.parser.parse('#[version = "0.0.5"]\n' + program) + tvm.ir.assert_structural_equal(mod["main"].body, mod_golden["main"].body, map_free_vars=True) + + if __name__ == "__main__": pytest.main([__file__])