Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TF frontend][bugfix]Avoid making a new node when already has span info #7789

Merged
merged 9 commits into from
Apr 9, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -3851,11 +3851,11 @@ def _convert_operator(
@staticmethod
def _set_span(sym, node_name):
span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
if isinstance(sym, _expr.Call):
if isinstance(sym, _expr.Call) and sym.span is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add a test case for these two lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd love to, but I don't know how to create unitest for pb parsering flow, any suggestion? thanks

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this require pb parsing or directly creating a tf program would be sufficient?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this require pb parsing or directly creating a tf program would be sufficient?

You are right,where can I find an exmple? I know the basic idea is creatig a tf graph, using from_tensorflow to handle it, then check the output graph ir.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can refer to the existing unit/integration tests for tf frontend. Most of them just create a tf graph.

sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
elif isinstance(sym, _expr.TupleWrapper):
tuple_value = sym.tuple_value
if isinstance(tuple_value, _expr.Call):
if isinstance(tuple_value, _expr.Call) and tuple_value.span is None:
tuple_value = _expr.Call(
tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
)
Expand Down
39 changes: 39 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5451,5 +5451,44 @@ def test_forward_unique_with_counts():
_test_unique_with_counts(20, dtype, is_dyn)


#######################################################################
# check graph ir for nn.moments
# ------------

SEMVER = '#[version = "0.0.5"]\n'


def run_from_tensorflow(graph):
xqdan marked this conversation as resolved.
Show resolved Hide resolved
mod, _ = from_tensorflow(graph.as_graph_def(add_shapes=True))
return mod


def test_moments():
g = tf.Graph()
shape = [4, 176, 8, 8]
dtype = "float32"
with g.as_default():
A = tf.placeholder(shape=shape, dtype=dtype, name="A")
B = tf.placeholder(shape=shape, dtype=dtype, name="B")
mean, variance = tf.nn.moments(A, [1], keep_dims=True)
normalised_input = (A - mean) / tf.sqrt(variance + 0.0005)

mod = run_from_tensorflow(g)
program = """
def @main(%A: Tensor[(4, 176, 8, 8), float32]) {
%527 = mean(%A, axis=[1], keepdims=True) /* moments/mean */;
%528 = subtract(%A, %527) /* sub */;
%529 = subtract(%A, %527);
%530 = multiply(%529, %529) /* moments/SquaredDifference */;
%531 = mean(%530, axis=[1], keepdims=True) /* moments/variance */;
%532 = add(%531, 0.0005f) /* add */;
%533 = sqrt(%532) /* Sqrt */;
divide(%528, %533) /* truediv */
}
"""
mod_golden = tvm.parser.parse(SEMVER + program)
tvm.ir.assert_structural_equal(mod["main"].body, mod_golden["main"].body, map_free_vars=True)


if __name__ == "__main__":
pytest.main([__file__])