Commit 02a60d0

[TOP] Add dense, batchnorm (#22)
* [TOP] Add dense, batchnorm
* update tvm
1 parent b37e5c2 commit 02a60d0

14 files changed: +401 -213 lines changed


nnvm/include/nnvm/compiler/op_attr_types.h

Lines changed: 5 additions & 2 deletions
@@ -44,11 +44,14 @@ using TOpPattern = int;
  * \brief Computation description interface
  * \param attrs The attribute of the node.
  * \param inputs The input tensors(placeholders)
+ * \param out_info Tensors holding shape/type information about output,
+ *                 these are always placeholders.
  * \return The output description of the tensor.
  */
 using FTVMCompute = std::function<
-  Array<Tensor>
-  (const NodeAttrs& attrs, const Array<Tensor>& inputs)>;
+  Array<Tensor>(const NodeAttrs& attrs,
+                const Array<Tensor>& inputs,
+                const Array<Tensor>& out_info)>;
 
 /*!
  * \brief Build the computation schedule for
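For illustration only (not part of the commit): a compute function written against the new three-argument FTVMCompute interface can be exercised directly with TVM placeholders. The op and tensor names below are hypothetical, and the snippet assumes the 2017-era tvm Python API (tvm.placeholder, tvm.compute, tvm.create_schedule, tvm.lower).

import tvm

def compute_copy(attrs, inputs, out_info):
    """Identity compute using the new signature: the output shape is read
    from out_info, the way ops such as broadcast_to and reshape need to."""
    x = inputs[0]
    return [tvm.compute(out_info[0].shape, lambda *i: x(*i), name="copy")]

# Placeholders standing in for what graph_fuse.cc would pass.
data = tvm.placeholder((2, 3), name="data")
out_hint = tvm.placeholder((2, 3), name="out_hint")
outs = compute_copy(None, [data], [out_hint])
s = tvm.create_schedule([t.op for t in outs])
print(tvm.lower(s, [data, outs[0]], simple_mode=True))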

nnvm/python/nnvm/compiler/build_module.py

Lines changed: 11 additions & 2 deletions
@@ -115,9 +115,12 @@ def optimize(graph, shape, dtype="float32"):
     """
     # pylint: disable=unused-argument
     cfg = BuildConfig.current
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph.apply("InferShape")
+    if graph.json_attr("shape_num_unknown_nodes"):
+        raise ValueError("InferShape fails..")
     if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
-        graph = graph_attr.set_shape_inputs(graph, shape)
-        graph = graph.apply(["InferShape", "SimplifyBatchNormInference"])
+        graph = graph.apply("SimplifyBatchNormInference")
     return graph
 
 
@@ -164,6 +167,12 @@ def build(graph, target, shape, dtype="float32", params=None):
     cfg = BuildConfig.current
     graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
     shape, dtype = _update_shape_dtype(shape, dtype, params)
+    # Initial pass do shape type inference
+    ishape, _ = graph_util.infer_shape(graph, **shape)
+    shape.update(zip(graph.index.input_names, ishape))
+    if not isinstance(dtype, str):
+        idtype, _ = graph_util.infer_dtype(graph, **dtype)
+        dtype.update(zip(graph.index.input_names, idtype))
     # Apply optimization
     graph = optimize(graph, shape, dtype)
     # Precompute prune
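As a rough usage sketch (not taken from the commit), the helper the new build() code relies on can also be called directly: graph_util.infer_shape fills in parameter shapes starting from just the data shape, which is what lets build() register inferred shapes for all inputs before optimization. The symbol names and the (8, 32) input shape are invented for illustration.

import nnvm.symbol as sym
import nnvm.graph as graph
from nnvm.compiler import graph_util

x = sym.Variable("data")
y = sym.dense(data=x, units=16, use_bias=False)

g = graph.create(y)
# Returns (input_shapes, output_shapes); the dense weight shape is inferred
# from data=(8, 32) and units=16.
in_shapes, out_shapes = graph_util.infer_shape(g, data=(8, 32))
print(dict(zip(g.index.input_names, in_shapes)), out_shapes)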

nnvm/python/nnvm/compiler/registry.py

Lines changed: 3 additions & 1 deletion
@@ -5,8 +5,10 @@
 class OpPattern(object):
     ELEM_WISE = 0
     BROADCAST = 1
+    # Complex means we can fuse elemwise to it
     COMPLEX = 2
-    EXTERN = 2
+    # Extern means the op is not fusable
+    EXTERN = 3
 
 _register_compute = tvm.get_global_func("nnvm._register_compute")
 _register_schedule = tvm.get_global_func("nnvm._register_schedule")

nnvm/python/nnvm/top/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,3 +2,4 @@
 from .attr_dict import AttrDict
 from . import tensor
 from . import nn
+from . import transform

nnvm/python/nnvm/top/nn.py

Lines changed: 39 additions & 10 deletions
@@ -1,30 +1,37 @@
+# pylint: disable=invalid-name, unused-argument
 """Definition of nn ops"""
 from __future__ import absolute_import
 
 import tvm
 import topi
 from topi.util import get_const_int
-from .tensor import schedule_elemwise
+from .tensor import _fschedule_broadcast
 from ..compiler import registry as reg
 from ..compiler import OpPattern
 
 # relu
 @reg.register_compute("relu")
-def compute_relu(_, inputs):
+def compute_relu(attrs, inputs, _):
     """Compute definition of relu"""
     return topi.nn.relu(inputs[0])
 
-@reg.register_schedule("relu")
-def schedule_relu(_, outs, target):
-    """Schedule definition of relu"""
-    return schedule_elemwise(_, outs, target)
-
+reg.register_schedule("relu", _fschedule_broadcast)
 reg.register_pattern("relu", OpPattern.ELEM_WISE)
 
 
+# flatten
+@reg.register_compute("flatten")
+def compute_flatten(attrs, inputs, _):
+    """Compute definition of flatten"""
+    return topi.nn.flatten(inputs[0])
+
+reg.register_schedule("flatten", _fschedule_broadcast)
+reg.register_pattern("flatten", OpPattern.COMPLEX)
+
+
 # softmax
 @reg.register_compute("softmax")
-def compute_softmax(attrs, inputs):
+def compute_softmax(attrs, inputs, _):
     """Compute definition of softmax"""
     axis = attrs.get_int("axis")
     assert axis == -1, "only support axis == -1 for now"
@@ -38,12 +45,34 @@ def schedule_softmax(_, outs, target):
     # naive schedule
     return tvm.create_schedule([x.op for x in outs])
 
-reg.register_pattern("softmax", OpPattern.COMPLEX)
+# Mark softmax as extern as we do not fuse it in all cases
+reg.register_pattern("softmax", OpPattern.EXTERN)
+
+
+# dense
+@reg.register_compute("dense")
+def compute_dense(attrs, inputs, _):
+    """Compute definition of dense"""
+    if attrs.get_bool("use_bias"):
+        return topi.nn.fully_connected_with_bias(
+            inputs[0], inputs[1], inputs[2])
+    return topi.nn.fully_connected(inputs[0], inputs[1])
+
+@reg.register_schedule("dense")
+def schedule_dense(_, outs, target):
+    """Schedule definition of dense"""
+    if target == "cuda":
+        raise ValueError("fully_connected not yet implemented")
+    # naive schedule
+    return tvm.create_schedule([x.op for x in outs])
+
+# register extern for now, change me when fusion is enabled.
+reg.register_pattern("dense", OpPattern.EXTERN)
 
 
 # conv
 @reg.register_compute("conv2d")
-def compute_conv2d(attrs, inputs):
+def compute_conv2d(attrs, inputs, _):
     """Compute definition of conv2d"""
     padding = attrs.get_int_tuple("padding")
     strides = attrs.get_int_tuple("strides")
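For context, and as an assumption rather than a quote of the topi source: topi.nn.fully_connected is the usual dense layer Y = X · Wᵀ with the weight laid out as (units, input_dim), plus an optional bias. A standalone sketch of that computation in plain tvm, with made-up sizes:

import tvm

batch, in_dim, units = 8, 32, 16
x = tvm.placeholder((batch, in_dim), name="x")
w = tvm.placeholder((units, in_dim), name="w")   # assumed (units, in_dim) layout
b = tvm.placeholder((units,), name="b")
k = tvm.reduce_axis((0, in_dim), name="k")
matmul = tvm.compute((batch, units),
                     lambda i, j: tvm.sum(x[i, k] * w[j, k], axis=k),
                     name="matmul")
out = tvm.compute((batch, units), lambda i, j: matmul[i, j] + b[j], name="dense")
s = tvm.create_schedule(out.op)   # same naive schedule as schedule_dense above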

nnvm/python/nnvm/top/tensor.py

Lines changed: 93 additions & 28 deletions
@@ -1,4 +1,4 @@
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name, unused-argument
 """Tensor ops"""
 from __future__ import absolute_import
 
@@ -8,15 +8,6 @@
 from ..compiler import registry as reg
 from ..compiler import OpPattern
 
-def schedule_elemwise(_, outs, target):
-    """Generic schedule for elemwise operation"""
-    if target == "cuda":
-        return topi.cuda.schedule_elemwise(outs)
-    assert target.startswith("llvm")
-    s = tvm.create_schedule([x.op for x in outs])
-    tvm.schedule.AutoInlineInjective(s)
-    return s
-
 def _schedule_broadcast(_, outs, target):
     """Generic schedule for binary bcast"""
     if target == "cuda":
@@ -29,66 +20,140 @@ def _schedule_broadcast(_, outs, target):
 def _compute_binary_scalar(f):
     """auxiliary function"""
     @tvm.tag_scope("ewise")
-    def _compute(attrs, x):
+    def _compute(attrs, x, _):
         x = x[0]
         scalar = attrs.get_float("scalar")
         scalar = tvm.const(scalar, x.dtype)
         return tvm.compute(x.shape, lambda *i: f(x(*i), scalar))
     return _compute
 
 
+def _compute_unary(f):
+    """auxiliary function"""
+    def _compute(attrs, x, _):
+        return f(x[0])
+    return _compute
+
+
+def _compute_binary(f):
+    """auxiliary function"""
+    def _compute(attrs, x, _):
+        return f(x[0], x[1])
+    return _compute
+
+
 _fschedule_broadcast = tvm.convert(_schedule_broadcast)
 
 # exp
-reg.register_compute("exp",
-                     lambda _, x: topi.exp(x[0]))
+reg.register_compute("exp", _compute_unary(topi.exp))
 reg.register_pattern("exp", OpPattern.ELEM_WISE)
 reg.register_schedule("exp", _fschedule_broadcast)
 
+# sqrt
+reg.register_compute("sqrt", _compute_unary(topi.sqrt))
+reg.register_pattern("sqrt", OpPattern.ELEM_WISE)
+reg.register_schedule("sqrt", _fschedule_broadcast)
+
 # log
-reg.register_compute("log",
-                     lambda _, x: topi.log(x[0]))
+reg.register_compute("log", _compute_unary(topi.log))
 reg.register_pattern("log", OpPattern.ELEM_WISE)
 reg.register_schedule("log", _fschedule_broadcast)
 
 # tanh
-reg.register_compute("tanh",
-                     lambda _, x: topi.tanh(x[0]))
+reg.register_compute("tanh", _compute_unary(topi.tanh))
 reg.register_pattern("tanh", OpPattern.ELEM_WISE)
 reg.register_schedule("tanh", _fschedule_broadcast)
 
+# negative
+reg.register_compute("negative", _compute_unary(topi.negative))
+reg.register_pattern("negative", OpPattern.ELEM_WISE)
+reg.register_schedule("negative", _fschedule_broadcast)
+
 # sigmoid
-reg.register_compute("sigmoid",
-                     lambda _, x: topi.sigmoid(x[0]))
+reg.register_compute("sigmoid", _compute_unary(topi.sigmoid))
 reg.register_pattern("sigmoid", OpPattern.ELEM_WISE)
 reg.register_schedule("sigmoid", _fschedule_broadcast)
 
-# add scalar
+# add_scalar
 reg.register_compute("__add_scalar__",
                      _compute_binary_scalar(lambda x, y: x + y))
 reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
 reg.register_schedule("__add_scalar__", _fschedule_broadcast)
 
+# sub_scalar
+reg.register_compute("__sub_scalar__",
+                     _compute_binary_scalar(lambda x, y: x - y))
+reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
+
+# rsub_scalar
+reg.register_compute("__rsub_scalar__",
+                     _compute_binary_scalar(lambda x, y: y - x))
+reg.register_pattern("__rsub_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__rsub_scalar__", _fschedule_broadcast)
+
+# mul_scalar
+reg.register_compute("__mul_scalar__",
+                     _compute_binary_scalar(lambda x, y: x * y))
+reg.register_pattern("__mul_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__mul_scalar__", _fschedule_broadcast)
+
+# div_scalar
+reg.register_compute("__div_scalar__",
+                     _compute_binary_scalar(lambda x, y: x / y))
+reg.register_pattern("__div_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__div_scalar__", _fschedule_broadcast)
+
+# rdiv_scalar
+reg.register_compute("__rdiv_scalar__",
+                     _compute_binary_scalar(lambda x, y: y / x))
+reg.register_pattern("__rdiv_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__rdiv_scalar__", _fschedule_broadcast)
+
+# elemwise_add
+reg.register_compute("elemwise_add", _compute_binary(topi.broadcast_add))
+reg.register_pattern("elemwise_add", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_add", _fschedule_broadcast)
+
+# elemwise_sub
+reg.register_compute("elemwise_sub", _compute_binary(topi.broadcast_sub))
+reg.register_pattern("elemwise_sub", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_sub", _fschedule_broadcast)
+
+# elemwise_mul
+reg.register_compute("elemwise_mul", _compute_binary(topi.broadcast_mul))
+reg.register_pattern("elemwise_mul", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_mul", _fschedule_broadcast)
+
+# elemwise_div
+reg.register_compute("elemwise_div", _compute_binary(topi.broadcast_div))
+reg.register_pattern("elemwise_div", OpPattern.BROADCAST)
+reg.register_schedule("elemwise_div", _fschedule_broadcast)
+
 # broadcast_add
-reg.register_compute("broadcast_add",
-                     lambda _, x: topi.broadcast_add(x[0], x[1]))
+reg.register_compute("broadcast_add", _compute_binary(topi.broadcast_add))
 reg.register_pattern("broadcast_add", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_add", _fschedule_broadcast)
 
 # broadcast_sub
-reg.register_compute("broadcast_sub",
-                     lambda _, x: topi.broadcast_sub(x[0], x[1]))
+reg.register_compute("broadcast_sub", _compute_binary(topi.broadcast_sub))
 reg.register_pattern("broadcast_sub", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_sub", _fschedule_broadcast)
 
 # broadcast_mul
-reg.register_compute("broadcast_mul",
-                     lambda _, x: topi.broadcast_mul(x[0], x[1]))
+reg.register_compute("broadcast_mul", _compute_binary(topi.broadcast_mul))
 reg.register_pattern("broadcast_mul", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_mul", _fschedule_broadcast)
 
 # broadcast_div
-reg.register_compute("broadcast_div",
-                     lambda _, x: topi.broadcast_div(x[0], x[1]))
+reg.register_compute("broadcast_div", _compute_binary(topi.broadcast_div))
 reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_div", _fschedule_broadcast)
+
+# broadcast_to
+@reg.register_compute("broadcast_to")
+def compute_broadcast_to(attrs, inputs, out_info):
+    """Compute definition of broadcast_to"""
+    return topi.broadcast_to(inputs[0], shape=out_info[0].shape)
+reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_to", _fschedule_broadcast)
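A small sanity sketch (not in the commit) of what the new helpers produce: _compute_unary(topi.exp) is just a closure with the (attrs, inputs, out_info) signature that ignores everything but inputs[0]. The helper is restated here so the snippet runs standalone.

import tvm
import topi

def _compute_unary(f):                      # same shape as the helper above
    def _compute(attrs, x, _):
        return f(x[0])
    return _compute

fexp = _compute_unary(topi.exp)
a = tvm.placeholder((4,), name="a")
out = fexp(None, [a], None)                 # attrs and out_info are unused for exp
s = tvm.create_schedule(out.op)
print(tvm.lower(s, [a, out], simple_mode=True))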

nnvm/python/nnvm/top/transform.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+# pylint: disable=invalid-name, unused-argument
+"""Tensor transformation ops"""
+from __future__ import absolute_import
+
+import tvm
+from .tensor import _fschedule_broadcast
+from ..compiler import registry as reg
+from ..compiler import OpPattern
+
+# Need add reshape, transpose
+
+def _flatten_index(indices, shape):
+    """flatten the index to 1D"""
+    idx = 0
+    for i, value in enumerate(shape):
+        if i != 0:
+            idx *= value
+        idx = idx + indices[i]
+    return idx
+
+# reshape
+@reg.register_compute("reshape")
+def compute_reshape(attrs, inputs, out_info):
+    """Compute definition of reshape"""
+    # TODO(sxj) add support for general reshape
+    assert len(inputs[0].shape) == 1, "Only support 1d input for now"
+    oshape = out_info[0].shape
+    x = inputs[0]
+    return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
+reg.register_pattern("reshape", OpPattern.COMPLEX)
+reg.register_schedule("reshape", _fschedule_broadcast)
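A quick check of the row-major flattening that compute_reshape relies on, redone in plain Python integers (illustration only, not part of the commit):

def flatten_index(indices, shape):
    """Same recurrence as _flatten_index above, on plain ints."""
    idx = 0
    for i, value in enumerate(shape):
        if i != 0:
            idx *= value
        idx = idx + indices[i]
    return idx

# For shape (2, 3, 4) the row-major strides are (12, 4, 1),
# so index (1, 2, 3) maps to 1*12 + 2*4 + 3 = 23.
assert flatten_index((1, 2, 3), (2, 3, 4)) == 23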

nnvm/src/compiler/graph_fuse.cc

Lines changed: 14 additions & 3 deletions
@@ -261,7 +261,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
     if (inode.source->is_variable()) continue;
     int root_id = group_vec[nid];
     FuseEntry& fe = fuse_vec[root_id];
-    Array<Tensor> inputs;
+    Array<Tensor> inputs, out_info;
     // input loading
     for (const auto& e : inode.inputs) {
       if (group_vec[e.node_id] != root_id) {
@@ -274,11 +274,21 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
         inputs.push_back(t);
       }
     }
+    // output hint
+    for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
+      Array<Expr> shape;
+      for (int64_t x : shape_vec[idx.entry_id(nid, i)]) {
+        CHECK_LE(x, static_cast<int64_t>(std::numeric_limits<int>::max()));
+        shape.push_back(make_const(Int(32), x));
+      }
+      out_info.push_back(
+          placeholder(shape,
+                      TVMType2Type(dltype_vec[idx.entry_id(nid, i)])));
+    }
     // get default
     Array<Tensor> out = fcompute[inode.source->op()](
-        inode.source->attrs, inputs);
+        inode.source->attrs, inputs, out_info);
     CHECK_EQ(out.size(), inode.source->num_outputs());
-
     // schedule on root node, and use master's schedule
     if (nid != root_id) {
       for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
@@ -312,6 +322,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
       }
     }
   }
+
   tvm::runtime::Module module = fbuild(funcs, target);
   // Final step: Remap the node, with given attribute
   const nnvm::Op* tvm_op = nnvm::Op::Get("tvm_op");
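In Python terms, the out_info hint built above is nothing more than one placeholder per node output carrying the inferred shape and dtype; a rough analogue (the shape and dtype values are invented for illustration):

import tvm

# Pretend these came from the graph's shape/dtype attributes for one node.
inferred_shapes = [(1, 16)]
inferred_dtypes = ["float32"]

out_info = [tvm.placeholder(s, dtype=t, name="out%d" % i)
            for i, (s, t) in enumerate(zip(inferred_shapes, inferred_dtypes))]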
