Commit ff8351a

[SCAN/Refactor] Refactor scan interface, enable fix point analysis.

1 parent 5198c10

20 files changed: +963 −373 lines

include/tvm/operation.h

Lines changed: 1 addition & 3 deletions

@@ -152,14 +152,12 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
 /*!
  * \brief Construct new tensors by scan over scan_axis.
  *
- * \param scan_axis The iteration representing the scan.
  * \param init The intialize tensor of first K steps.
  * \param update The update tensor indicated the updated result after each timestamp.
  * \param state_placeholder The placeholder for the states.
  * \param name The optional name of the tensor.
  */
-Array<Tensor> scan(IterVar scan_axis,
-                   Array<Tensor> init,
+Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    std::string name = "scan");

include/tvm/schedule.h

Lines changed: 3 additions & 1 deletion

@@ -26,7 +26,9 @@ enum AttachType : int {
   kNone = 0,
   kRoot = 1,
   kInline = 2,
-  kScope = 3
+  kInlinedAlready = 3,
+  kScope = 4,
+  kScanUpdate = 5
 };

 /*! \brief IterVar type */

include/tvm/tensor.h

Lines changed: 2 additions & 0 deletions

@@ -175,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
   virtual Type output_dtype(size_t i) const = 0;
   /*! \return shape of i-th output */
   virtual Array<Expr> output_shape(size_t i) const = 0;
+
+  static constexpr const char* _type_key = "Operation";
 };

 // Implementations of inline functions

python/tvm/api.py

Lines changed: 4 additions & 7 deletions

@@ -140,14 +140,11 @@ def compute(shape, fcompute, name="compute"):
     return op_node.output(0)


-def scan(axis, init, update, state_placeholder, name="scan"):
+def scan(init, update, state_placeholder, name="scan"):
     """Construct new tensors by scanning over axis.

     Parameters
     ----------
-    axis: IterVar
-        The scanning axis.
-
     init: Tensor or list of Tensor
         The initial condition of first init.shape[0] timestamps

@@ -170,12 +167,11 @@ def scan(axis, init, update, state_placeholder, name="scan"):
     # The following code is equivalent to numpy.cumsum
     m = tvm.Var("m")
     n = tvm.Var("n")
-    t = tvm.IterVar((1, m), name="t")
     X = tvm.placeholder((m, n), name="X")
     s_state = tvm.placeholder((m, n))
     s_init = tvm.compute((1, n), lambda _, i: X[0, i])
-    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
-    res = tvm.scan(t, s_init, s_update, s_state)
+    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
+    res = tvm.scan(s_init, s_update, s_state)
     """
     if isinstance(init, _tensor.Tensor):
         init = [init]
@@ -185,6 +181,7 @@ def scan(axis, init, update, state_placeholder, name="scan"):
         state_placeholder = [state_placeholder]
     if len(init) != len(update) or len(init) != len(state_placeholder):
         raise ValueError("init, update, state_placeholder must have same length")
+    axis = IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name)
     op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
     res = [op.output(i) for i in range(len(update))]
     return (res[0] if len(res) == 1 else res)
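For reference, a short sketch of the new call pattern based on the cumsum example in the docstring above (illustrative only; it assumes the tvm Python API at this revision):

    import tvm

    m = tvm.Var("m")
    n = tvm.Var("n")
    X = tvm.placeholder((m, n), name="X")
    s_state = tvm.placeholder((m, n))
    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
    # update now spans the full (m, n) state, with time as the leading axis,
    # instead of a single (n,) slice indexed by an external scan IterVar.
    s_update = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + X[t, i])
    # No IterVar argument: scan() derives the axis internally as
    # IterVar((s_init.shape[0], s_update.shape[0]), "scan.idx"), i.e. t in [1, m).
    res = tvm.scan(s_init, s_update, s_state)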

python/tvm/build.py

Lines changed: 2 additions & 1 deletion

@@ -63,7 +63,8 @@ def build(sch,
             arg_list.append(x)
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
-    # lowering
+    # normalize schedule first
+    sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)

src/api/api_schedule.cc

Lines changed: 3 additions & 0 deletions

@@ -34,6 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
 REGISTER_SCHEDULE_PASS1(InferBound);
 REGISTER_SCHEDULE_PASS1(CreateReadGraph);
 REGISTER_SCHEDULE_PASS2(PostDFSOrder);
+REGISTER_SCHEDULE_PASS1(ScanGetBody);
+REGISTER_SCHEDULE_PASS1(CreateAttachPath);
+REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
 REGISTER_SCHEDULE_PASS2(ScheduleOps);

 } // namespace schedule

src/arithmetic/int_set.cc

Lines changed: 9 additions & 1 deletion

@@ -166,7 +166,15 @@ IntSet Union(const Array<IntSet>& set) {
   if (set.size() == 1) return set[0];
   Interval x = set[0].cover_interval().as<IntervalSet>()->i;
   for (size_t i = 1; i < set.size(); ++i) {
-    x.include(set[i].cover_interval().as<IntervalSet>()->i);
+    IntSet s = set[i].cover_interval();
+    const Interval& y = s.as<IntervalSet>()->i;
+    if (can_prove(x.max + 1 >= y.min)) {
+      x.max = y.max;
+    } else if (can_prove(y.max + 1 >= x.min)) {
+      x.min = y.min;
+    } else {
+      x.include(y);
+    }
   }
   return IntervalSet::make(x);
 }
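The rewritten Union loop fuses an interval into the running cover when it can prove the two overlap or touch, and only falls back to the generic include() widening otherwise, presumably to keep the merged bound in a simpler form than a symbolic max/min. A schematic Python sketch of the same rule over concrete integer bounds (illustrative only; the real code compares symbolic Exprs with can_prove, and these helper names are not the C++ API):

    def union_cover(intervals):
        """Fold (min, max) integer intervals the way the new loop folds Interval x."""
        x_min, x_max = intervals[0]
        for y_min, y_max in intervals[1:]:
            if x_max + 1 >= y_min:       # y starts no later than just past x: extend right
                x_max = y_max
            elif y_max + 1 >= x_min:     # y ends no earlier than just before x: extend left
                x_min = y_min
            else:                        # neither side provable: fall back to include()
                x_min, x_max = min(x_min, y_min), max(x_max, y_max)
        return (x_min, x_max)

    # [0, 3] and [4, 7] touch (3 + 1 >= 4), so the first branch fuses them into [0, 7].
    assert union_cover([(0, 3), (4, 7)]) == (0, 7)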

src/lang/operation.cc

Lines changed: 15 additions & 13 deletions

@@ -51,8 +51,6 @@ Operation PlaceholderOpNode::make(std::string name,
   return Operation(n);
 }

-
-
 Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
   return PlaceholderOpNode::make(name, shape, dtype).output(0);
 }
@@ -162,24 +160,25 @@ Operation ScanOpNode::make(std::string name,
         << " scan_axis.dom.min + scan_axis.dom.extent";
     CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
         << "The dimension of init need to match state_placeholder";
-    CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
+    CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
         << "The update.ndim need to be state_placeholder.ndim - 1";
     for (size_t k = 0; k < update[i].ndim(); ++k) {
       CHECK(prove_equal(
-          update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
-      // setup spatial axis
-      std::ostringstream spatial_name;
-      spatial_name << name << ".out" << i << ".i" << k + 1;
-      n->spatial_axis_.push_back(
-          IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
-                  spatial_name.str()));
+          update[i]->shape[k], state_placeholder[i]->shape[k]));
+      if (k != 0) {
+        // setup spatial axis
+        std::ostringstream spatial_name;
+        spatial_name << name << ".out" << i << ".i" << k;
+        n->spatial_axis_.push_back(
+            IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
+                    spatial_name.str()));
+      }
     }
     for (size_t k = 1; k < init[i].ndim(); ++k) {
       CHECK(prove_equal(
           init[i]->shape[k], state_placeholder[i]->shape[k]));
     }
   }
-
   n->name = name;
   n->scan_axis = axis;
   n->init = init;
@@ -188,11 +187,14 @@
   return Operation(n);
 }

-Array<Tensor> scan(IterVar scan_axis,
-                   Array<Tensor> init,
+Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    std::string name) {
+  IterVar scan_axis(
+      Range::make_with_min_extent(
+          init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
+      name + ".idx");
   Operation op = ScanOpNode::make(
       name, scan_axis, init, update, state_placeholder);
   Array<Tensor> res;
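Since the caller no longer passes the scan IterVar, scan() now derives its range from the shapes of init and update. A tiny worked example of that arithmetic with hypothetical concrete shapes (plain Python, just to illustrate the Range built above):

    # Hypothetical shapes: one init timestep, ten timesteps of state, feature width 16.
    init_shape = (1, 16)     # init[0]->shape
    update_shape = (10, 16)  # update[0]->shape

    scan_min = init_shape[0]                       # 1
    scan_extent = update_shape[0] - init_shape[0]  # 9
    # Range::make_with_min_extent(1, 9): the scan axis covers t in [1, 10),
    # i.e. the half-open interval [init.shape[0], update.shape[0]).
    print(scan_min, scan_extent)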

src/pass/inline.cc

Lines changed: 3 additions & 1 deletion

@@ -61,7 +61,9 @@ Stmt Inline(Stmt stmt,
             Expr body) {
   CHECK_EQ(f->num_outputs(), 1)
       << "can only inline output single value operation";
-  return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
+  Stmt ret = IRInline(f, args, body).Mutate(stmt);
+  if (ret.same_as(stmt)) return ret;
+  return ConvertSSA(ret);
 }
 } // namespace ir
 } // namespace tvm
