4 changes: 2 additions & 2 deletions include/tvm/arith/analyzer.h
@@ -649,9 +649,9 @@ class TVM_DLL Analyzer {
void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
* is created and bound to a range.
*
* Each var can only be binded once.
* Each var can only be bound once.
*
* \param var The variable.
* \param range The range we bind to.
10 changes: 7 additions & 3 deletions python/tvm/arith/analyzer.py
@@ -17,8 +17,12 @@
# pylint: disable=invalid-name
"""Arithmetic data structure and utility"""
from enum import IntEnum
from typing import Union

import tvm._ffi
from tvm import tir, ir
from tvm.runtime import Object

from . import _ffi_api


@@ -227,16 +231,16 @@ def can_prove(self, expr, strength=ProofStrength.DEFAULT):
"""
return self._can_prove(expr, strength)

def bind(self, var, expr):
def bind(self, var: tir.Var, expr: Union[tir.PrimExpr, ir.Range]):
"""Bind a variable to the expression.

Parameters
----------
var : tvm.tir.Var
The variable.

expr : PrimExpr
The expression.
expr : Union[tir.PrimExpr, ir.Range]
The expression or the range to bind to.
"""
return self._bind(var, expr)
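A minimal usage sketch of the widened signature (not part of the diff; assumes a TVM build with this change): bind now accepts either a PrimExpr or a Range.

import tvm
from tvm import te

analyzer = tvm.arith.Analyzer()
x, y = te.var("x"), te.var("y")
analyzer.bind(y, tvm.ir.Range(0, 8))  # bind y to a Range: y lies in [0, 8)
analyzer.bind(x, y + 1)               # bind x to a PrimExpr
assert analyzer.can_prove(y < 8)      # follows directly from the Range binding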

58 changes: 54 additions & 4 deletions src/arith/canonical_simplify.cc
@@ -492,7 +492,7 @@ class SumExprNode : public CanonicalExprNode {
if (lhs->div_mode < rhs->div_mode) return false;
// tie.
// TODO(tvm-team) We might consider index as the last comparison point,
// after we make deep comparator more derministic.
// after we make deep comparator more deterministic.
// Specifically, we can consider comparing names of vars and break ties with address.
return false;
};
@@ -607,6 +607,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
PrimExpr VisitExpr_(const FloorModNode* op) final;
PrimExpr VisitExpr_(const ReduceNode* op) final;
PrimExpr VisitExpr_(const CastNode* op) final;
PrimExpr VisitExpr_(const LTNode* op) final;

private:
/*!
@@ -636,7 +637,7 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
SumExpr* out_non_divisible);
/*!
* \brief Pattern match and check whether lhs is fully divisible by
* rhs using prod pattern simiplification expressions.
* rhs using prod pattern simplification expressions.
*
* The following two relations holds for floordiv/mod and truncdiv/mod
* Note that the relation do not hold for euclidean divide and mod.
@@ -1158,7 +1159,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
if (TryCompare(temp, cval) == CompareResult::kLT) {
return temp;
} else {
// contonue to use logic below.
// continue to use logic below.
a = extra;
psum = a.as<SumExprNode>();
ICHECK(psum != nullptr);
@@ -1227,7 +1228,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
analyzer_->CanProveGreaterEqual(temp, 0)) {
return temp;
} else {
// contonue to use logic below.
// continue to use logic below.
a = extra;
psum = a.as<SumExprNode>();
ICHECK(psum != nullptr);
@@ -1386,6 +1387,55 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const CastNode* op) {
return Rewriter::VisitExpr_(op);
}

PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const LTNode* op) {
// First convert a < b into a - b < 0
PrimExpr expr = this->CanonicalMutate(op->a - op->b);
// Case: x0 * s0 + x1 * s1 + ... + xn + c < 0, let d = gcd(s0, s1, ..., s{n-1}, c)
// 1. if we can prove -d < xn < d, then we can simplify
// the expression to x0 * (s0/d) + x1 * (s1/d) + ... + x{n-1} * (s{n-1}/d) < -c/d,
// e.g. `x * 8 + y < 16` where `y` \in [0, 8) can be simplified to `x < 2`
// 2. if xn matches the pattern yn % m, where m % d == 0, convert it to yn // d % (m/d),
// e.g. `x1 * 64 + (x2 * 8 + x3) % 64 < 120` with `x3` \in [0, 8) can be simplified to
// `x1 * 8 + (x2 * 8 + x3) // 8 % 8 < 15` ==> `x1 * 8 + x2 % 8 < 15`

if (const auto* lhs = expr.as<SumExprNode>()) {
int64_t gcd = lhs->base;
bool has_non_one_scale = false;
for (const SplitExpr& split_expr : lhs->args) {
if (split_expr->scale > 1 || split_expr->scale < -1) {
has_non_one_scale = true;
gcd = ZeroAwareGCD(gcd, std::abs(split_expr->scale));
}
}
// Skip if gcd == 1 or all s_n are 1
if (!has_non_one_scale || gcd <= 1) {
return Rewriter::VisitExpr_(op);
}
SumExpr divisible, extra;
SeparateDivisibleParts(lhs, gcd, &divisible, &extra);
DataType dtype = divisible->dtype;
ICHECK(extra->dtype == dtype);
PrimExpr normal_extra = extra->Normalize();
if (this->analyzer_->CanProve(normal_extra < make_const(dtype, gcd)) &&
this->analyzer_->CanProve(normal_extra > make_const(dtype, -gcd))) {
// Case 1. -d < xn < d
divisible.CopyOnWrite()->DivideBy(gcd);
return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
} else if (extra->args.size() == 1 &&
extra->args[0]->upper_factor % (gcd * extra->args[0]->lower_factor) == 0) {
// Case 2. xn == yn % m, where m % d == 0
divisible.CopyOnWrite()->DivideBy(gcd);
const auto split_expr = extra->args[0];
int64_t lower_factor = gcd * extra->args[0]->lower_factor;
PrimExpr extra_expr = floormod(floordiv(split_expr->index, lower_factor),
floordiv(split_expr->upper_factor, lower_factor));
return Rewriter::VisitExpr(divisible->Normalize() + extra_expr < make_zero(dtype));
}
}

return Rewriter::VisitExpr_(op);
}

PrimExpr CanonicalSimplifier::operator()(const PrimExpr& expr) {
return impl_->CanonicalSimplify(expr);
}
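As a quick illustration of case 1 above (a sketch, not part of the diff; assumes a TVM build that includes this change), the simplification can be observed through the Python Analyzer API; case 2 is exercised by the new test below.

import tvm
from tvm import te

analyzer = tvm.arith.Analyzer()
x, y = te.var("x"), te.var("y")
analyzer.bind(y, tvm.ir.Range(0, 8))  # y is known to lie in [0, 8)
# gcd of the non-unit scales and the constant is 8, and -8 < y < 8, so case 1 applies
print(analyzer.canonical_simplify(x * 8 + y < 16))  # expected: x < 2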
40 changes: 40 additions & 0 deletions tests/python/unittest/test_arith_canonical_simplify.py
@@ -422,5 +422,45 @@ def test_floormod_two():
ck.verify(flm(x * 10 + 1 + y * 2 + 2, 2), 1)


def test_simplify_le():
ck = CanonicalChecker()
# Case 1. Ignore the extra expr if it is smaller than the divisor
x, y, z = te.var("x"), te.var("y"), te.var("z")
ck.analyzer.bind(y, tvm.ir.Range(0, 8))
ck.analyzer.bind(z, tvm.ir.Range(0, 2))
ck.verify(x * 8 + y < 16, x < 2)
ck.verify(x * 8 + z * 4 < 16, x < 2)
ck.verify(x * 8 + z * 4 < 16, x < 2)

# TODO: Not sure why `-2 < x` gets converted to `x > -2`; use an explicit simplify here.
ck.verify(x * -8 + y < 16, ck.analyzer.rewrite_simplify(-2 < x))
ck.verify(x * -8 + z * 4 < 16, ck.analyzer.rewrite_simplify(-2 < x))

ck.verify(x * 8 + y + z < 16, x * 8 + y + z < 16)
ck.verify(x * 8 + y - z < 16, x < 2)

n = te.size_var("n")
ck.verify(x * 8 + y < n, x * 8 + y < n)

# Case 2. Simplify the extra expr
x1, x2, ty, tx, vec = (
tvm.te.var("x1"),
tvm.te.var("x2"),
tvm.te.var("ty"),
tvm.te.var("tx"),
tvm.te.var("vec"),
)
ck.analyzer.bind(x1, tvm.ir.Range(0, 2))
ck.analyzer.bind(x2, tvm.ir.Range(0, 3))
ck.analyzer.bind(ty, tvm.ir.Range(0, 8))
ck.analyzer.bind(tx, tvm.ir.Range(0, 32))
ck.analyzer.bind(vec, tvm.ir.Range(0, 8))
ck.verify(
x1 * 5632 + (((x2 * 8 + ty) * 32 + tx) * 8 + vec) % 5632 < 11008,
x1 * 22 + (x2 * 8 + ty) % 22 < 43,
)
ck.verify(tx // 2 % 8 + vec < 8, tx % 16 // 2 + vec < 8)


if __name__ == "__main__":
tvm.testing.main()
@@ -63,9 +63,9 @@ def Move_PUV0(a: T.handle, b: T.handle) -> None:
vj = T.axis.spatial(1024, i0_j0_fused % 64 * 32 + j1 * 8 + j2)
vk = T.axis.spatial(1024, k0 * 32 + k1_fused)
T.where(
i0_j0_fused // 64 * 16 + i1 * 4 + i2 < 1024
and i0_j0_fused % 64 * 32 + j1 * 8 + j2 < 1024
and k0 * 32 + k1_fused < 1024
i0_j0_fused < 4064
and i0_j0_fused % 64 < 32
and k0 < 32
)
T.reads([A[vi, vj, vk]])
T.writes([B[vi, vj, vk]])
@@ -458,7 +458,7 @@ def nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Bu
ci = T.axis.spatial(2048, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) // 1900 + ax0)
p = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) % 1900 + ax1)
eps, nu = T.axis.remap("SS", [ax2, ax3])
T.where(i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2 < 3891200)
T.where(i2_i3_fused_0 * 256 + i2_i3_fused_1 < 3800)
T.reads(p0[p // 950, ci, p % 950 // 38 * 2 + eps - 1, p % 38 * 2 + nu - 1])
T.writes(input_tile_local[ci, p, eps, nu])
T.block_attr({"schedule_rule": "None"})
@@ -484,7 +484,7 @@ def nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Bu
v0, v1 = T.axis.remap("SS", [ax0, ax1])
v2 = T.axis.spatial(2048, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) // 1900 + ax2)
v3 = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) % 1900 + ax3)
T.where(i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2 < 3891200)
T.where(i2_i3_fused_0 * 256 + i2_i3_fused_1 < 3800)
T.reads(data_pack_local[v0, v1, v2, v3])
T.writes(data_pack[v0, v1, v2, v3])
data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3]
3 changes: 2 additions & 1 deletion tests/python/unittest/test_target_codegen_c_host.py
@@ -64,7 +64,8 @@ def test_add_pipeline():
s[C].pragma(xo1, "parallel_launch_point")
s[C].pragma(xo2, "parallel_stride_pattern")
s[C].pragma(xo2, "parallel_barrier_when_finish")
s[C].vectorize(xi)
# FIXME(tvm-team): vector operators are not supported for codegen to C yet
Member Author:
This case is due to the limited support for vectorization in the C codegen.
Before this PR, TVM failed to simplify the branch predicate and therefore skipped the vectorization. This PR strengthens the arithmetic simplification so the loop is actually vectorized, which then fails at the codegen stage.

In short:

  1. It's not related to this PR, as it's a codegen issue.
  2. It's not a regression; the vectorize step was skipped before this PR.

# s[C].vectorize(xi)

def check_c():
# Specifically allow offset to test codepath when offset is available