Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 42 additions & 14 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,23 @@ class ExprMutator(ExprFunctor):
def visit_function(self, fn):
new_params = [self.visit(x) for x in fn.params]
new_body = self.visit(fn.body)
if new_params == list(fn.params) and new_body == fn.body:
return fn
return FunctionWithFields(fn, list(new_params), new_body)

def visit_let(self, let):
new_var = self.visit(let.var)
new_val = self.visit(let.value)
new_body = self.visit(let.body)
if new_var == let.var and new_val == let.value and new_body == let.body:
return let
return Let(new_var, new_val, new_body)

def visit_call(self, call):
new_fn = self.visit(call.op)
new_args = [self.visit(arg) for arg in call.args]
if new_fn == call.op and new_args == list(call.args):
return call
return Call(new_fn, new_args, call.attrs, call.type_args, call.span)

def visit_var(self, var):
Expand All @@ -224,16 +230,28 @@ def visit_global_id(self, global_var):
return global_var

def visit_if(self, ite):
return If(self.visit(ite.cond), self.visit(ite.true_branch), self.visit(ite.false_branch))
new_cond = self.visit(ite.cond)
new_true_branch = self.visit(ite.true_branch)
new_false_branch = self.visit(ite.false_branch)
if (
new_cond == ite.cond
and new_true_branch == ite.true_branch
and new_false_branch == ite.false_branch
):
return ite
return If(new_cond, new_true_branch, new_false_branch)

def visit_tuple(self, tup):
return Tuple([self.visit(field) for field in tup.fields], tup.span)
new_fields = [self.visit(field) for field in tup.fields]
if new_fields == list(tup.fields):
return tup
return Tuple(new_fields, tup.span)

def visit_tuple_getitem(self, op):
tuple_value = self.visit(op.tuple_value)
if not tuple_value.same_as(op.tuple_value):
return TupleGetItem(tuple_value, op.index)
return op
new_tuple_value = self.visit(op.tuple_value)
if new_tuple_value == op.tuple_value:
return op
return TupleGetItem(new_tuple_value, op.index)

def visit_global_var(self, gvar):
return gvar
Expand All @@ -248,17 +266,27 @@ def visit_constructor(self, con):
return con

def visit_match(self, m):
return Match(
self.visit(m.data),
[Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses],
complete=m.complete,
)
new_data = self.visit(m.data)
new_clauses = [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses]
if new_data == m.data and all(x.rhs == y.rhs for x, y in zip(new_clauses, m.clauses)):
return m
return Match(new_data, new_clauses, complete=m.complete)

def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))
new_value = self.visit(r.value)
if new_value == r.value:
return r
return RefCreate(new_value)

def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))
new_ref = self.visit(r.ref)
new_value = self.visit(r.value)
if new_ref == r.ref and new_value == r.value:
return r
return RefWrite(new_ref, new_value)

def visit_ref_read(self, r):
return RefRead(self.visit(r.ref))
new_ref = self.visit(r.ref)
if new_ref == r.ref:
return r
return RefRead(new_ref)
2 changes: 1 addition & 1 deletion tests/python/relay/test_expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def check_visit(expr):
ev.visit(expr)

em = ExprMutator()
assert em.visit(expr)
assert expr == em.visit(expr)


def test_constant():
Expand Down