-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix a bug in nnvm to relay converter. #2756
Conversation
line 455 checks whether a child is expr.TupleWrapper, so _split() should return TupleWrapper instead of TupleWrapper.tuple_value 455 if isinstance(child, expr.TupleWrapper): 456 children.append(child[i[1]])
@lixiaoquan, I did the change in #2734 and it looks wrong. Thanks for the catch. To pass the CI error, I think we also need a fix to handle TupleWrapper in outputs of the graph. |
Can we add a regression test? These kind of regressions are usually due to lacking tests, and it would be good to guard against this error in the future. |
@kazum I refine it according to your suggestion. But there is another CI error which I can't reproduce locally. Could you check my patch? Thanks. |
nnvm/python/nnvm/to_relay.py
Outdated
@@ -480,6 +480,7 @@ def to_relay(graph, shape_dict, dtype_dict, params): | |||
"nnvm.to_relay: unsupported operator: {0}".format(op_name)) | |||
|
|||
outputs = [relay_map[nid] for nid in output_ids] | |||
outputs = [x if not isinstance(x, expr.TupleWrapper) else x.astuple() for x in outputs] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is not enough. Here is an example we also need to consider.
import nnvm
from nnvm.to_relay import to_relay
x = nnvm.sym.Variable("x")
y = nnvm.sym.split(x, indices_or_sections=2)
z = y[1]
graph = nnvm.graph.create(z)
print(graph.ir())
# Graph(%x) {
# %1 = split(%x, indices_or_sections='2')
# ret %1.1
# }
func, _ = to_relay(graph, {}, {}, {})
print(func)
# fn (%x: ) {
# %0 = split(%x, indices_or_sections=int64(2))
# %0 <= should be %0.1
# }
In the case of TupleWrapper, I think we should return x[index]
instead of x.astuple()
where index
can be taken from json.loads(graph.json())['heads']
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those cases are also handled now
@@ -31,6 +31,23 @@ def check_model(sym, shapes, dtypes, params): | |||
relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values())) | |||
np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy()) | |||
|
|||
|
|||
def test_split_concatenate(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this test should be in tests/python/frontend/nnvm_to_relay/test_forward.py.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Modified
@lixiaoquan I added some comments. I think the current CI error is not related to your change. I'll look into it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added final comments. I'll approve after they are addressed.
nnvm/python/nnvm/to_relay.py
Outdated
@@ -479,7 +480,14 @@ def to_relay(graph, shape_dict, dtype_dict, params): | |||
raise Exception( | |||
"nnvm.to_relay: unsupported operator: {0}".format(op_name)) | |||
|
|||
outputs = [relay_map[nid] for nid in output_ids] | |||
outputs = [] | |||
for i in output_ids: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for i in output_ids: | |
for nid, idx, _ in gidx.output_entries: |
Then, we can remove heads
and output_ids
, and replace i[0]
and i[1]
with nid
and gid
for better readability.
verify_nnvm_to_relay(splited, params, data_shape=shape) | ||
verify_nnvm_to_relay(concatenated, params, data_shape=shape) | ||
|
||
|
||
if __name__ == '__main__': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add test_forward_split_concatenate() here so that the test will be executed when we run the this script directly.
64e1c8b
to
7955a82
Compare
Add a regression test guarding on original bug.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, thanks!
Thanks @kazum @lixiaoquan @jroesch , this is now merged |
line 455 checks whether a child is expr.TupleWrapper, so _split() should
return TupleWrapper instead of TupleWrapper.tuple_value
455 if isinstance(child, expr.TupleWrapper):
456 children.append(child[i[1]])