-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Tensorflow] Allow an op as loop var. #3056
Conversation
70c1f3f
to
dd13c83
Compare
Look good to me. @jroesch can you too confirm on the IR part ? |
Thanks for the contribution. But I am not sure if we really want to do this because it looks to me that ANF should be able to avoid the mentioned problem. @jroesch please confirm. |
@zhiics I think this is fine, we shouldn't need to make any changes to bind though, the original code was just generating a single loop variable for each sub-graph used as a loop variable this time. We can later ANF the program too. |
src/relay/ir/expr_functor.cc
Outdated
@@ -383,11 +383,21 @@ class ExprBinder : public ExprMutator { | |||
} | |||
} | |||
|
|||
Expr VisitExpr_(const CallNode* op) final { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We shouldn't need these changes we don't really want to do arbitrary expression to expression replacement.
else: | ||
var_type = var.type_annotation | ||
|
||
v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type) | ||
loop_vars.append(v) | ||
bind_map[var] = v |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think instead of calling bind which has specific semantics we should just add code like this:
class RewriteSubgraph(ExprMutator):
def __init__(self, rewrite_map):
self.rewrite_map = rewrite_map
def visit(self, expr):
if expr in self.rewrite_map:
return self.rewrite_map[expr]
else:
return super().visit(expr)
def rewrite_subgraph(expr, rewrites):
return RewriteSubgraph(rewrites).visit(expr)
then replace the call with bind to this, this will handle the general case and not just calls.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. Do you think it should be a common function or just be kept in TF frontend?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is reasonable to just put it in the TF pass for now. We could also put in the frontend/common file where utilities are kept.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've modified it according to request.
Let me know when its ready for review agian. |
b03c6c2
to
7cf6277
Compare
@jroesch Will you please review latest change? |
Great looks good, thanks! |
Allow binding a CallNode to a var to support op as loop var.
@zhiics @srkreddy1238 @jroesch Could you please review?