Skip to content

Commit

Permalink
[BugFix][TensorIR] Non-positive constant input factors for split (#…
Browse files Browse the repository at this point in the history
…9805)

* Update docs of GetProducers/GetConsumers

* Fix split for non-positive factors
  • Loading branch information
MasterJH5574 authored Dec 27, 2021
1 parent 218d291 commit 2c654b5
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 10 deletions.
14 changes: 8 additions & 6 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,17 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) = 0;
/*!
* \brief Get the producer of a specific block
* \brief Get the producer of a specific block, under the same block scope
* \param block_rv The block in the query
* \return A list of blocks, the producers of the given block
* \return A list of blocks, the producers of the given block under the same scope of the given
* block
*/
virtual Array<BlockRV> GetProducers(const BlockRV& block_rv) = 0;
/*!
* \brief Get the consumers of a specific block
* \brief Get the consumers of a specific block, under the same block scope
* \param block_rv The block to be queried
* \return A list of blocks, the consumers of the given block
* \return A list of blocks, the consumers of the given block under the same scope of the given
* block
*/
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
/******** Schedule: Transform loops ********/
Expand All @@ -266,8 +268,8 @@ class ScheduleNode : public runtime::Object {
* 1) The loop can't have annotation or thread binding.
* 2) The loop must start with 0.
* \param loop_rv The loop to be split
* \param factors The tiling factors, and at most one of which is -1, which means that
* factor is inferred.
* \param factors The positive tiling factors, and at most one of which is `NullOpt`, which means
* that factor is inferred.
* \return The new loops after split
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def split(
Potential inputs are:
- None
- ExprRV
- Non-negative constant integers
- Positive constant integers
Returns
-------
Expand Down
33 changes: 30 additions & 3 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,31 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
IRModule mod_;
For loop_;
};

class NonPositiveFactorError : public ScheduleError {
public:
explicit NonPositiveFactorError(IRModule mod, int64_t factor, size_t idx)
: mod_(std::move(mod)), factor_(factor), idx_(idx) {}

String FastErrorString() const final {
return "ScheduleError: All the constant factors are required to be positive. However, some "
"constant input factor is zero or negative.";
}
String DetailRenderTemplate() const final {
std::ostringstream os;
os << "All the constant factors are required to be positive. However, the factor at position "
<< idx_ << " is " << factor_;
return os.str();
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {}; }

private:
IRModule mod_;
int64_t factor_;
size_t idx_;
};

// Prepare for the splitting
StmtSRef loop_sref = this->GetSRef(loop_rv);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
Expand All @@ -389,13 +414,15 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
for (size_t i = 0; i < factor_rvs.size(); i++) {
if (!factor_rvs[i].defined()) {
factors.push_back(Integer(-1));
if (infer_index == -1) {
infer_index = i;
} else {
if (infer_index != -1) {
throw NotSingleInferFactorError(state_->mod);
}
infer_index = i;
} else {
PrimExpr factor = this->Get(factor_rvs[i].value());
if (is_const_int(factor) && !is_positive_const(factor)) {
throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i);
}
factors.push_back(factor);
tot_length *= factor;
}
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,18 @@ def test_split_with_opaque_access():
verify_trace_roundtrip(sch=sch, mod=opaque_access)


def test_split_with_non_positive_factors():
sch = tir.Schedule(elementwise, debug_mask="all")
block_b = sch.get_block("B")
i, j, k = sch.get_loops(block_b)
with pytest.raises(tvm.tir.ScheduleError):
sch.split(i, factors=[-2, -64])
with pytest.raises(tvm.tir.ScheduleError):
sch.split(j, factors=[0, None])
with pytest.raises(tvm.tir.ScheduleError):
sch.split(k, factors=[None, -16])


def test_fuse_split_fail_with_thread_binding():
sch = tir.Schedule(elementwise_with_thread_binding, debug_mask="all")
block_b = sch.get_block("B")
Expand Down

0 comments on commit 2c654b5

Please sign in to comment.