Skip to content

Commit 454ede4

Browse files
committed
add test cases
1 parent f24a7b2 commit 454ede4

File tree

12 files changed

+364
-94
lines changed

12 files changed

+364
-94
lines changed

include/tvm/relax/dataflow_pattern.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class NotPattern;
5656
class ShapePattern;
5757
class TypePattern;
5858
class DataTypePattern;
59+
class TargetPattern;
5960
class AttrPattern;
6061
class SameShapeConstraint;
6162

@@ -120,6 +121,8 @@ class DFPattern : public ObjectRef {
120121
TVM_DLL DataTypePattern HasDtype(const std::string& dtype) const;
121122
/*! \brief Syntatic Sugar for creating a ShapePattern */
122123
TVM_DLL ShapePattern HasShape(const Array<PrimExpr>& shape) const;
124+
/*! \brief Syntactic sugar for creating a TargetPattern with a target */
125+
TVM_DLL TargetPattern HasTarget(const Target& target) const;
123126
/*! \brief Syntatic Sugar for creating a ShapePattern */
124127
TVM_DLL SameShapeConstraint HasSameShapeAs(const DFPattern& other) const;
125128
/*! \brief Syntatic Sugar for duplicating the current pattern */
@@ -843,6 +846,34 @@ class DataTypePattern : public DFPattern {
843846
TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode);
844847
};
845848

849+
/*!
850+
* \brief A pattern that asserting a root pattern has a certain target.
851+
* \sa TargetPattern
852+
*/
853+
class TargetPatternNode : public DFPatternNode {
854+
public:
855+
DFPattern pattern; /*!< The root pattern to match */
856+
Target target; /*!< The target to match */
857+
858+
void VisitAttrs(tvm::AttrVisitor* v) {
859+
v->Visit("pattern", &pattern);
860+
v->Visit("target", &target);
861+
}
862+
863+
static constexpr const char* _type_key = "relax.dpl.TargetPattern";
864+
TVM_DECLARE_FINAL_OBJECT_INFO(TargetPatternNode, DFPatternNode);
865+
};
866+
867+
/*!
868+
* \brief Managed reference to TargetPatternNode.
869+
* \sa TargetPatternNode
870+
*/
871+
class TargetPattern : public DFPattern {
872+
public:
873+
TVM_DLL TargetPattern(DFPattern pattern, Target target);
874+
TVM_DEFINE_OBJECT_REF_METHODS(TargetPattern, DFPattern, TargetPatternNode);
875+
};
876+
846877
/*!
847878
* \brief A pattern that asserting a root pattern has certain attributes.
848879
* \sa AttrPattern

include/tvm/relax/dataflow_pattern_functor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
9191
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
9292
virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
9393
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
94+
virtual R VisitDFPattern_(const TargetPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
9495
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
9596
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
9697
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT;
@@ -127,6 +128,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
127128
RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
128129
RELAX_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
129130
RELAX_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
131+
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TargetPatternNode);
130132
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
131133
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
132134
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
@@ -161,6 +163,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const DFPattern&)> {
161163
void VisitDFPattern_(const ExprPatternNode* op) override;
162164
void VisitDFPattern_(const FunctionPatternNode* op) override;
163165
void VisitDFPattern_(const ShapePatternNode* op) override;
166+
void VisitDFPattern_(const TargetPatternNode* op) override;
164167
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
165168
void VisitDFPattern_(const TuplePatternNode* op) override;
166169
void VisitDFPattern_(const TypePatternNode* op) override;

python/tvm/relax/backend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@
1818

1919
from . import contrib
2020
from .pattern_registry import get_pattern, get_patterns_with_prefix
21+
from .dispatch_ops import DispatchOps

python/tvm/relax/backend/dispatch_ops.py

Lines changed: 78 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tvm.relax.dpl import is_op, rewrite_call, wildcard, has_target
2525

2626
from ..transform import function_pass
27+
from tvm.topi.utils import prod, swap
2728

2829

2930
@function_pass(opt_level=0)
@@ -35,10 +36,21 @@ class DispatchOps:
3536
def __init__(self):
3637
self.input = wildcard()
3738
# cumsum on cpu will be legalized
39+
self.cumsum_cpu = is_op("relax.cumsum")(self.input) & has_target("llvm")
3840
self.cumsum_gpu = is_op("relax.cumsum")(self.input) & has_target("cuda")
39-
self.sort_gpu = is_op("relax.sort")(self.input) & has_target("cuda")
4041
self.sort_cpu = is_op("relax.sort")(self.input) & has_target("llvm")
41-
self.pattern = self.cumsum_gpu | self.sort_gpu | self.sort_cpu
42+
self.sort_gpu = is_op("relax.sort")(self.input) & has_target("cuda")
43+
# if no target is specified, default will be on GPU
44+
self.sort = is_op("relax.sort")(self.input)
45+
self.cumsum = is_op("relax.cumsum")(self.input)
46+
self.pattern = (
47+
self.cumsum_gpu
48+
| self.cumsum_cpu
49+
| self.sort_gpu
50+
| self.sort_cpu
51+
| self.sort
52+
| self.cumsum
53+
)
4254

4355
def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRModule:
4456
"""
@@ -64,55 +76,81 @@ def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> IRM
6476
continue
6577

6678
def rewriter(expr, matches):
67-
print("got here 70, expr: ", expr)
6879
arg = matches[self.input]
69-
print("got arg: ", arg)
7080

71-
if self.cumsum_gpu in matches:
72-
print("86 matches[self.no_op_reshape]: ", matches[self.cumsum_gpu])
73-
return relax.call_dps_packed(
81+
if self.cumsum_gpu in matches or (
82+
self.cumsum in matches and self.cumsum_cpu not in matches
83+
):
84+
axis = matches[self.cumsum_gpu].attrs.axis
85+
output_dtype = matches[self.cumsum_gpu].attrs.dtype
86+
if output_dtype is None:
87+
output_dtype = out_sinfo.dtype
88+
out_sinfo = arg.struct_info
89+
if axis is None:
90+
axis = 0
91+
new_shape = (prod(arg.struct_info.shape),)
92+
arg = relax.op.reshape(arg, new_shape)
93+
out_sinfo = relax.TensorStructInfo(
94+
new_shape, output_dtype, out_sinfo.vdevice
95+
)
96+
return relax.op.call_dps_packed(
7497
"tvm.contrib.thrust.sum_scan",
75-
[arg],
76-
out_sinfo=arg.struct_info,
77-
)
78-
elif self.sort_gpu in matches:
79-
print("86 matches[self.no_op_reshape]: ", matches[self.sort_gpu])
80-
return relax.call_dps_packed(
81-
"tvm.contrib.thrust.sort",
82-
[arg],
83-
out_sinfo=arg.struct_info,
98+
[arg, int(axis)],
99+
out_sinfo=out_sinfo,
84100
)
101+
85102
elif self.sort_cpu in matches:
103+
axis = int(matches[self.sort_cpu].attrs.axis)
104+
is_ascend = int(matches[self.sort_cpu].attrs.is_ascend)
105+
out_sinfo = arg.struct_info
106+
if axis is None:
107+
axis = 0
108+
new_shape = (prod(arg.struct_info.shape),)
109+
out_sinfo = relax.TensorStructInfo(
110+
new_shape, arg.struct_info.dtype, out_sinfo.vdevice
111+
)
112+
86113
return relax.call_dps_packed(
87114
"tvm.contrib.sort.sort",
88-
[arg],
89-
out_sinfo=arg.struct_info,
115+
[arg, axis, is_ascend],
116+
out_sinfo=out_sinfo,
90117
)
91118

119+
elif self.sort_gpu in matches or self.sort in matches:
120+
axis = matches[self.sort_gpu].attrs.axis
121+
if axis is None:
122+
axis = -1
123+
axis = int(axis)
124+
125+
is_ascend = matches[self.sort_gpu].attrs.is_ascend
126+
if is_ascend is None:
127+
is_ascend = True
128+
out_sinfo = arg.struct_info
129+
ndim = arg.struct_info.ndim
130+
131+
axis = ndim + axis if axis < 0 else axis
132+
if axis != ndim - 1:
133+
# Prepare for sorting along axis -1.
134+
axes = swap(list(range(ndim)), axis)
135+
arg = relax.op.permute_dims(arg, axes)
136+
new_shape = [out_sinfo.shape[i] for i in axes]
137+
out_sinfo = relax.TensorStructInfo(
138+
new_shape, out_sinfo.dtype, out_sinfo.vdevice
139+
)
140+
141+
out = relax.op.call_dps_packed(
142+
"tvm.contrib.thrust.sort",
143+
[arg, int(is_ascend)],
144+
out_sinfo=out_sinfo,
145+
)
146+
if axis != ndim - 1:
147+
# Prepare for sorting along axis -1.
148+
axes = swap(list(range(ndim)), axis)
149+
out = relax.op.permute_dims(out, axes)
150+
return out
151+
92152
return expr
93153

94154
updated_func = rewrite_call(self.pattern, rewriter, func)
95155

96156
return updated_func
97-
98-
99-
# Option 0): add a global dict for it: {op, target, condition, dps_packed},
100-
# condition is some specific setting like the value of k in topk
101-
# Q: how to work it with pattern match?
102-
#
103-
# Option 1): normal python mod pass, straightforward, but not easy to hack like topk
104-
#
105-
# Option 2): c++ pass, not easy to be updated. Don't go
106-
#
107-
# How to handle with target? don't require RealizeVDevice, just specify the vdevice in inputs
108-
# but vdevice is necessary, we could have default for it
109-
#
110-
# Sample map
111-
# cumsum - cpu => ignore for legalization
112-
# cumsum - gpu => relax.call_dps_packed(
113-
# "tvm.contrib.thrust.sum_scan",
114-
# [data],
115-
# out_sinfo=data.struct_info,
116-
# )
117-
# f32_233 = wildcard().has_shape((2, 3, 3)) & has_dtype("float32") # and pattern
118-
# is_op("nn.conv2d")(xp, yp).has_attr({"kernel_size": [4, 3]}).match(conv2d)

python/tvm/relax/dpl/pattern.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,22 @@ def has_shape(self, shape: List[PrimExpr]) -> "ShapePattern":
177177
raise ValueError("has_shape takes a list or tuple as input.")
178178
return ShapePattern(pattern=self, shape=shape)
179179

180+
def has_target(self, target: tvm.target.Target) -> "TargetPattern":
    """
    Constrain this pattern to match only values placed on the given target.

    Parameters
    ----------
    target: Target
        The target to match

    Returns
    -------
    result: TargetPattern
        The resulting TargetPattern
    """
    return has_target(target, self)
195+
180196
def match(self, expr, var2val: Optional[Dict[Var, Expr]] = None) -> bool:
181197
"""
182198
Match a relax.Expr syntactically
@@ -626,6 +642,23 @@ def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]):
626642
self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape) # type: ignore
627643

628644

645+
@register_df_node
class TargetPattern(DFPattern):
    """A pattern that matches another pattern with a certain target

    Parameters
    ----------
    pattern: tvm.relax.dpl.DFPattern
        The input pattern the target constraint is applied to.

    target: tvm.target.Target
        The target to match.
    """

    def __init__(self, pattern: "DFPattern", target: tvm.target.Target):
        self.__init_handle_by_constructor__(ffi.TargetPattern, pattern, target)  # type: ignore
660+
661+
629662
@register_df_node
630663
class SameShapeConstraint(DFConstraint):
631664
"""A pattern that requires a set of patterns to have the same shape
@@ -837,6 +870,30 @@ def has_dtype(dtype: str, pattern: DFPattern = None) -> DataTypePattern:
837870
return DataTypePattern(pattern, dtype)
838871

839872

873+
def has_target(target: tvm.target.Target, pattern: Optional[DFPattern] = None) -> TargetPattern:
    """
    Syntactic sugar for creating a TargetPattern

    Parameters
    ----------
    target: Target
        The target to match; a dict or str is converted to a Target first.

    pattern: tvm.relax.dpl.DFPattern
        The pattern the target constraint is applied to; defaults to a
        wildcard, which matches any expression on the given target.

    Returns
    -------
    result: tvm.relax.dpl.TargetPattern
        The resulting TargetPattern
    """
    if pattern is None:
        pattern = wildcard()
    if isinstance(target, (dict, str)):
        target = tvm.target.Target(tvm.runtime.convert(target))
    return TargetPattern(pattern, target)
895+
896+
840897
def is_shape(shape: List[tvm.ir.PrimExpr]) -> "PrimArrPattern":
841898
"""
842899
Directly matches a shape which is an array of PrimExpr

src/relax/ir/dataflow_matcher.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,15 @@ bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr
525525
return false;
526526
}
527527

528+
bool DFPatternMatcher::VisitDFPattern_(const TargetPatternNode* op, const Expr& expr) {
529+
const auto* sinfo = GetStructInfoAs<TensorStructInfoNode>(expr);
530+
if (sinfo->vdevice.defined()) {
531+
VDevice vdev = sinfo->vdevice.value();
532+
return vdev->target->kind == op->target->kind;
533+
}
534+
return false;
535+
}
536+
528537
bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
529538
// We don't jump for var pattern, as there's no need to access its value to judge it.
530539
if (const auto* var_node = expr.as<VarNode>()) {

src/relax/ir/dataflow_matcher_impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
5757
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
5858
bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) override;
5959
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
60+
bool VisitDFPattern_(const TargetPatternNode* op, const Expr& expr) override;
6061
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) override;
6162
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
6263
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;

src/relax/ir/dataflow_pattern.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,21 @@ RELAX_PATTERN_PRINTER_DEF(DataTypePatternNode, [](auto p, auto node) {
315315
p->stream << "DataTypePattern(" << node->pattern << " has dtype " << node->dtype << ")";
316316
});
317317

318+
TVM_REGISTER_NODE_TYPE(TargetPatternNode);
319+
TargetPattern::TargetPattern(DFPattern pattern, Target target) {
320+
ObjectPtr<TargetPatternNode> n = make_object<TargetPatternNode>();
321+
n->pattern = std::move(pattern);
322+
n->target = std::move(target);
323+
data_ = std::move(n);
324+
}
325+
TVM_REGISTER_GLOBAL("relax.dpl.TargetPattern")
326+
.set_body_typed([](DFPattern pattern, Target target) {
327+
return TargetPattern(pattern, target);
328+
});
329+
RELAX_PATTERN_PRINTER_DEF(TargetPatternNode, [](auto p, auto node) {
330+
p->stream << "TargetPattern(" << node->pattern << " has target " << node->target << ")";
331+
});
332+
318333
TVM_REGISTER_NODE_TYPE(AttrPatternNode);
319334
AttrPattern::AttrPattern(DFPattern pattern, DictAttrs attrs) {
320335
ObjectPtr<AttrPatternNode> n = make_object<AttrPatternNode>();
@@ -367,6 +382,9 @@ class DFPatternDuplicator : public DFPatternFunctor<DFPattern(const DFPattern&)>
367382
DFPattern VisitDFPattern_(const DataTypePatternNode* op) override {
368383
return DataTypePattern(op->pattern, op->dtype);
369384
}
385+
DFPattern VisitDFPattern_(const TargetPatternNode* op) override {
386+
return TargetPattern(op->pattern, op->target);
387+
}
370388
DFPattern VisitDFPattern_(const FunctionPatternNode* op) override {
371389
return FunctionPattern(op->params, op->body);
372390
}
@@ -410,6 +428,9 @@ DataTypePattern DFPattern::HasDtype(const std::string& dtype) const {
410428
ShapePattern DFPattern::HasShape(const Array<PrimExpr>& shape) const {
411429
return ShapePattern(*this, shape);
412430
}
431+
// Syntactic sugar: constrain this pattern to match only on the given target.
TargetPattern DFPattern::HasTarget(const Target& target) const {
  return TargetPattern(*this, target);
}
413434

414435
DFPattern::operator PatternSeq() const { return PatternSeq{{*this}}; }
415436

src/relax/ir/dataflow_pattern_functor.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) {
6363
VisitDFPattern(op->pattern);
6464
}
6565

66+
// Recurse into the wrapped sub-pattern; the target field needs no traversal.
void DFPatternVisitor::VisitDFPattern_(const TargetPatternNode* op) {
  VisitDFPattern(op->pattern);
}
69+
6670
void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}
6771

6872
void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) {

src/script/printer/relax/utils.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifie
109109
kind_index++;
110110
}
111111
}
112-
LOG(WARNING) << "The VDevice was not found in the global_infos map: " << vdevice;
113112
return -1;
114113
}
115114

0 commit comments

Comments
 (0)