Skip to content

Commit

Permalink
[Relay] add Tuple pattern (apache#3596)
Browse files Browse the repository at this point in the history
* implement tuple pattern

* add tuple pattern

* lint;

* lint

* lint

* fix error

* fix

* add test
  • Loading branch information
MarisaKirisame authored and wweic committed Sep 16, 2019
1 parent e84969c commit e7c3ff1
Show file tree
Hide file tree
Showing 18 changed files with 340 additions and 42 deletions.
23 changes: 23 additions & 0 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,29 @@ class PatternConstructorNode : public PatternNode {

RELAY_DEFINE_NODE_REF(PatternConstructor, PatternConstructorNode, Pattern);

/*! \brief A tuple pattern. Matches a tuple, binds recursively. */
class PatternTuple;
/*! \brief PatternVar container node */
class PatternTupleNode : public PatternNode {
public:
/*! Sub-patterns to match against each value of the tuple. */
tvm::Array<Pattern> patterns;

PatternTupleNode() {}

TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("patterns", &patterns);
v->Visit("span", &span);
}

static constexpr const char* _type_key = "relay.PatternTuple";
TVM_DECLARE_NODE_TYPE_INFO(PatternTupleNode, PatternNode);
};

RELAY_DEFINE_NODE_REF(PatternTuple, PatternTupleNode, Pattern);

/*!
* \brief Stores all data for an Algebraic Data Type (ADT).
*
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/relay/pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternConstructorNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPattern_(const PatternTupleNode* op,
Args... args) PATTERN_FUNCTOR_DEFAULT;
virtual R VisitPatternDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -112,6 +114,7 @@ class PatternFunctor<R(const Pattern& n, Args...)> {
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternWildcardNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternVarNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternConstructorNode);
RELAY_PATTERN_FUNCTOR_DISPATCH(PatternTupleNode);
return vtable;
}
};
Expand All @@ -127,6 +130,7 @@ class PatternVisitor : public ::tvm::relay::PatternFunctor<void(const Pattern& n
void VisitPattern_(const PatternWildcardNode* op) override;
void VisitPattern_(const PatternVarNode* op) override;
void VisitPattern_(const PatternConstructorNode* op) override;
void VisitPattern_(const PatternTupleNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitVar(const Var& v);
virtual void VisitConstructor(const Constructor& c);
Expand All @@ -144,6 +148,7 @@ class PatternMutator
Pattern VisitPattern_(const PatternWildcardNode* op) override;
Pattern VisitPattern_(const PatternVarNode* op) override;
Pattern VisitPattern_(const PatternConstructorNode* op) override;
Pattern VisitPattern_(const PatternTupleNode* op) override;
/*! \brief Used to visit the types inside of patterns.
*
* Can be overloaded to transform the types in arbitrary
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
PatternWildcard = adt.PatternWildcard
PatternVar = adt.PatternVar
PatternConstructor = adt.PatternConstructor
PatternTuple = adt.PatternTuple
Constructor = adt.Constructor
TypeData = adt.TypeData
Clause = adt.Clause
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/relay/adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,29 @@ def __init__(self, constructor, patterns=None):
self.__init_handle_by_constructor__(_make.PatternConstructor, constructor, patterns)


@register_relay_node
class PatternTuple(Pattern):
"""Constructor pattern in Relay: Matches a tuple, binds recursively."""

def __init__(self, patterns=None):
"""Construct a tuple pattern.
Parameters
----------
patterns: Optional[List[Pattern]]
Optional subpatterns: for each field of the constructor,
match to the given subpattern (treated as a variable pattern by default).
Returns
-------
wildcard: PatternWildcard
a wildcard pattern.
"""
if patterns is None:
patterns = []
self.__init_handle_by_constructor__(_make.PatternTuple, patterns)


@register_relay_node
class Constructor(Expr):
"""Relay ADT constructor."""
Expand Down
15 changes: 8 additions & 7 deletions python/tvm/relay/prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .expr import Var, Function, GlobalVar, Let, If, Tuple, TupleGetItem, const
from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard
from .adt import PatternConstructor, PatternVar, PatternWildcard, PatternTuple
from .parser import fromtext
__PRELUDE_PATH__ = os.path.dirname(os.path.realpath(__file__))
from .module import Module
Expand Down Expand Up @@ -239,18 +239,19 @@ def define_list_zip(self):
self.zip = GlobalVar("zip")
a = TypeVar("a")
b = TypeVar("b")
nil_case = Clause(PatternConstructor(self.nil), self.nil())
l1 = Var("l1")
l2 = Var("l2")
h1 = Var("h1")
h2 = Var("h2")
t1 = Var("t1")
t2 = Var("t2")
inner_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h2), PatternVar(t2)]),
self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
outer_cons_case = Clause(PatternConstructor(self.cons, [PatternVar(h1), PatternVar(t1)]),
Match(l2, [nil_case, inner_cons_case]))
self.mod[self.zip] = Function([l1, l2], Match(l1, [nil_case, outer_cons_case]),
cons_case = Clause(PatternTuple([PatternConstructor(self.cons,
[PatternVar(h1), PatternVar(t1)]),
PatternConstructor(self.cons,
[PatternVar(h2), PatternVar(t2)])]),
self.cons(Tuple([h1, h2]), self.zip(t1, t2)))
nil_case = Clause(PatternWildcard(), self.nil())
self.mod[self.zip] = Function([l1, l2], Match(Tuple([l1, l2]), [cons_case, nil_case]),
self.l(TupleType([a, b])), [a, b])


Expand Down
16 changes: 10 additions & 6 deletions python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,14 +311,18 @@ def create_match_check(self, pattern: Pattern, data):
if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
return NameConstant(True)

# constructor patterns check whether the constructors match
# and also the matches of any nested patterns
conds = []

# equiv: (arg.tag == patern_constructor.tag)
conds = [ast.Compare(ast.Attribute(data, 'tag', Load()),
[ast.Eq()],
[ast.Num(pattern.constructor.tag)])]
if isinstance(pattern, relay.PatternConstructor):
# constructor patterns check whether the constructors match
# and also the matches of any nested patterns

# equiv: (arg.tag == patern_constructor.tag)
conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()),
[ast.Eq()],
[ast.Num(pattern.constructor.tag)]))

assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
# now check for any nested patterns
for i in range(len(pattern.patterns)):
nested_pat = pattern.patterns[i]
Expand Down
14 changes: 13 additions & 1 deletion src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file src/tvm/relay/interpreter.cc
* \brief An interpreter for the Relay IR.
*/
Expand Down Expand Up @@ -708,6 +708,18 @@ class Interpreter :
return false;
}

bool VisitPattern_(const PatternTupleNode* op, const Value& v) final {
const TupleValueNode* tvn = v.as<TupleValueNode>();
CHECK(tvn) << "need to be a tuple for match";
CHECK_EQ(op->patterns.size(), tvn->fields.size());
for (size_t i = 0; i < op->patterns.size(); ++i) {
if (!VisitPattern(op->patterns[i], tvn->fields[i])) {
return false;
}
}
return true;
}

bool VisitPattern_(const PatternWildcardNode* op, const Value& v) final {
return true;
}
Expand Down
18 changes: 13 additions & 5 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,19 +152,27 @@ TreeNodePtr BuildDecisionTreeFromPattern(MatchValuePtr data,
auto pattern = GetRef<PatternVar>(pat);
auto cond = std::make_shared<VarBinding>(pattern->var, data);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pat = pattern.as<PatternConstructorNode>();
auto pattern = GetRef<PatternConstructor>(pat);
auto tag = pattern->constructor->tag;
} else if (auto pcn = pattern.as<PatternConstructorNode>()) {
auto tag = pcn->constructor->tag;

size_t field_index = 0;
for (auto& p : pattern->patterns) {
for (auto& p : pcn->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
auto cond = std::make_shared<TagCompare>(data, tag);
return TreeBranchNode::Make(cond, then_branch, else_branch);
} else {
auto pt = pattern.as<PatternTupleNode>();
CHECK(pt) << "unhandled case: " << pattern;
size_t field_index = 0;
for (auto& p : pt->patterns) {
auto d = std::make_shared<AccessField>(data, field_index);
then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
field_index++;
}
return then_branch;
}
}

Expand Down
17 changes: 17 additions & 0 deletions src/relay/ir/adt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< ", " << node->patterns << ")";
});

PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
NodePtr<PatternTupleNode> n = make_node<PatternTupleNode>();
n->patterns = std::move(patterns);
return PatternTuple(n);
}

TVM_REGISTER_NODE_TYPE(PatternTupleNode);

TVM_REGISTER_API("relay._make.PatternTuple")
.set_body_typed(PatternTupleNode::make);

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternTupleNode>([](const PatternTupleNode* node,
tvm::IRPrinter* p) {
p->stream << "PatternTupleNode(" << node->patterns << ")";
});

Constructor ConstructorNode::make(std::string name_hint,
tvm::Array<Type> inputs,
GlobalTypeVar belong_to) {
Expand Down
17 changes: 16 additions & 1 deletion src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ class AlphaEqualHandler:
}

bool PatternEqual(const Pattern& lhs, const Pattern& rhs) {
return VisitPattern(lhs, rhs);
return Compare(VisitPattern(lhs, rhs), lhs, rhs);
}

bool VisitPattern_(const PatternWildcardNode* lhs, const Pattern& other) final {
Expand Down Expand Up @@ -523,6 +523,21 @@ class AlphaEqualHandler:
return true;
}

bool VisitPattern_(const PatternTupleNode* lhs, const Pattern& other) final {
const auto* rhs = other.as<PatternTupleNode>();
if (rhs == nullptr
|| lhs->patterns.size() != rhs->patterns.size()) {
return false;
}

for (size_t i = 0; i < lhs->patterns.size(); i++) {
if (!PatternEqual(lhs->patterns[i], rhs->patterns[i])) {
return false;
}
}
return true;
}

bool VisitExpr_(const MatchNode* lhs, const Expr& other) final {
const MatchNode* rhs = other.as<MatchNode>();

Expand Down
8 changes: 8 additions & 0 deletions src/relay/ir/hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,14 @@ class RelayHashHandler:
return hash;
}

size_t VisitPattern_(const PatternTupleNode* ptn) final {
size_t hash = std::hash<std::string>()(PatternTupleNode::_type_key);
for (const auto& p : ptn->patterns) {
hash = Combine(hash, PatternHash(p));
}
return hash;
}

size_t VisitPattern_(const PatternVarNode* pvn) final {
size_t hash = std::hash<std::string>()(PatternVarNode::_type_key);
hash = Combine(hash, BindVar(pvn->var));
Expand Down
22 changes: 18 additions & 4 deletions src/relay/ir/pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -18,8 +18,8 @@
*/

/*!
* Copyright (c) 2018 by Contributors
* \file src/tvm/relay/pattern_functor.cc
* Copyright (c) 2019 by Contributors
* \file src/relay/ir/pattern_functor.cc
* \brief Implementations of visitors and mutators for ADT patterns.
*/

Expand Down Expand Up @@ -48,6 +48,14 @@ Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
return PatternConstructorNode::make(VisitConstructor(op->constructor), pat);
}

Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) {
std::vector<Pattern> pat;
for (const auto& p : op->patterns) {
pat.push_back(VisitPattern(p));
}
return PatternTupleNode::make(pat);
}

Type PatternMutator::VisitType(const Type& t) {
return t;
}
Expand Down Expand Up @@ -78,6 +86,12 @@ void PatternVisitor::VisitPattern_(const PatternConstructorNode* op) {
}
}

void PatternVisitor::VisitPattern_(const PatternTupleNode* op) {
for (const auto& p : op->patterns) {
VisitPattern(p);
}
}

void PatternVisitor::VisitType(const Type& t) { }

void PatternVisitor::VisitVar(const Var& v) {
Expand Down
Loading

0 comments on commit e7c3ff1

Please sign in to comment.