Skip to content

Commit a69dba8

Browse files
Lunderbergblackkker
authored andcommitted
[Arith] Simplification of ceil, log2, and left_shift (apache#11646)
* [TIR] Simplify expressions using tir.ceil and tir.log2 These expressions are introduced in `topi.math.ceil_log2`, and can otherwise be propagated through to the generated kernel. * [Arith] Added left shift handling to ConstIntBoundsAnalyzer Previously, only right shift was handled. These left shifts are used in the `cuda.sort` implementation. * Update to avoid left shift of negative numbers * Updated rewriting of log2(x) to only occur in ceil(log2(x)) Per @wrongtest's request, to avoid rounding differences between different devices. * Avoid assumptions made of negative arguments to left-shift * Recognize bounds of int(ceil(log2(arg)))
1 parent aa5b8c6 commit a69dba8

File tree

3 files changed

+222
-2
lines changed

3 files changed

+222
-2
lines changed

src/arith/const_int_bound.cc

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,17 @@ class ConstIntBoundAnalyzer::Impl
177177
}
178178

179179
Entry VisitExpr_(const CastNode* op) final {
180-
Entry a = VisitExpr(op->value);
180+
Entry a;
181+
182+
// int(ceil(log2(cast(n,"float64")))) is used as the
183+
// implementation of topi.math.ceil_log2, and appears in iteration
184+
// bounds.
185+
if (auto opt = FindCeilLog2Arg(op)) {
186+
a = CeilLog2Bounds(opt.value());
187+
} else {
188+
a = VisitExpr(op->value);
189+
}
190+
181191
Entry b = Everything(op->dtype);
182192
return Intersect(a, b);
183193
}
@@ -314,6 +324,8 @@ class ConstIntBoundAnalyzer::Impl
314324

315325
if (op->op.same_as(tir::builtin::shift_right())) {
316326
return VisitRightShift(op);
327+
} else if (op->op.same_as(tir::builtin::shift_left())) {
328+
return VisitLeftShift(op);
317329
} else if (op->op.same_as(tir::builtin::bitwise_and())) {
318330
return VisitBitwiseAnd(op);
319331
} else {
@@ -341,6 +353,20 @@ class ConstIntBoundAnalyzer::Impl
341353
}
342354
}
343355

356+
Entry VisitLeftShift(const CallNode* op) {
357+
Entry a = VisitExpr(op->args[0]);
358+
Entry b = VisitExpr(op->args[1]);
359+
360+
if (a.min_value < 0 || b.min_value < 0) {
361+
// If either operand can negative, we may run into undefined
362+
// behavior for some targets. In these cases, avoid making any
363+
// assumptions about the result.
364+
return Everything(op->dtype);
365+
}
366+
367+
return BinaryOpBoundary(a, b, InfAwareLeftShift);
368+
}
369+
344370
Entry VisitRightShift(const CallNode* op) {
345371
Entry a = VisitExpr(op->args[0]);
346372
Entry b = VisitExpr(op->args[1]);
@@ -509,7 +535,33 @@ class ConstIntBoundAnalyzer::Impl
509535
return floordiv(x, y);
510536
}
511537
/*!
512-
* \brief Compute x / y, aware of inf.
538+
* \brief Compute x << y, aware of inf.
539+
* \param x The left operand.
540+
* \param y The right operand.
541+
* \return the result.
542+
*/
543+
static int64_t InfAwareLeftShift(int64_t x, int64_t y) {
544+
if (x == kPosInf || x == kNegInf) return x;
545+
546+
// Can be replaced with std::bit_width in C++20
547+
auto bit_width = [](int64_t as_signed) {
548+
uint64_t val = std::abs(as_signed);
549+
int num_bits = 0;
550+
while (val) {
551+
++num_bits;
552+
val >>= 1;
553+
}
554+
return num_bits;
555+
};
556+
int x_bits = bit_width(x);
557+
if (x_bits + y < 64) {
558+
return x << y;
559+
} else {
560+
return kPosInf;
561+
}
562+
}
563+
/*!
564+
* \brief Compute x >> y, aware of inf.
513565
* \param x The left operand.
514566
* \param y The right operand.
515567
* \return the result.
@@ -609,6 +661,46 @@ class ConstIntBoundAnalyzer::Impl
609661
}
610662
return {};
611663
}
664+
665+
/*!
666+
* \brief Extract the argument from int(ceil(log2(arg)))
667+
*
668+
* This expression is used as the implementation of
669+
* topi.math.ceil_log2, and can appear in iteration bounds.
670+
*/
671+
static Optional<PrimExpr> FindCeilLog2Arg(const CastNode* op) {
672+
if (op->dtype.is_int()) {
673+
if (auto as_call = op->value.as<CallNode>()) {
674+
if (as_call->op.same_as(Op::Get("tir.ceil"))) {
675+
PrimExpr ceil_arg = as_call->args[0];
676+
if (auto arg_call = ceil_arg.as<CallNode>()) {
677+
if (arg_call->op.same_as(Op::Get("tir.log2"))) {
678+
PrimExpr log_arg = arg_call->args[0];
679+
return log_arg;
680+
}
681+
}
682+
}
683+
}
684+
}
685+
return NullOpt;
686+
}
687+
688+
/*! \brief Propagate constraints through ceil(log2(arg))
689+
*
690+
* Helper function for CastNode visitor
691+
*/
692+
Entry CeilLog2Bounds(PrimExpr arg) {
693+
if (auto as_float = arg.as<FloatImmNode>()) {
694+
// A cast from int to float may have already been simplified
695+
// out. Normally we don't inspect floating-point arguments, but here we can
696+
int64_t val = std::ceil(std::log2(as_float->value));
697+
return MakeBound(val, val);
698+
} else {
699+
Entry arg_bounds = VisitExpr(arg);
700+
return MakeBound(std::ceil(std::log2(arg_bounds.min_value)),
701+
std::ceil(std::log2(arg_bounds.max_value)));
702+
}
703+
}
612704
};
613705

614706
ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const {

src/arith/rewrite_simplify.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,13 +1640,34 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
16401640
// the operator overload will eagerly constant fold.
16411641
return op->args[0] << op->args[1];
16421642
}
1643+
} else if (op->op.same_as(Op::Get("tir.ceil"))) {
1644+
PrimExpr ceil_arg = op->args[0];
1645+
if (auto arg_int = op->args[0].as<IntImmNode>()) {
1646+
return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value));
1647+
} else if (auto arg_float = ceil_arg.as<FloatImmNode>()) {
1648+
return cast(op->dtype, FloatImm(arg_float->dtype, std::ceil(arg_float->value)));
1649+
} else if (auto arg_call = ceil_arg.as<CallNode>()) {
1650+
// ceil(log2(cast(n,"float64"))) is used as the implementation of
1651+
// topi.math.ceil_log2, and appears in iteration bounds.
1652+
if (arg_call->op.same_as(Op::Get("tir.log2"))) {
1653+
PrimExpr log_arg = arg_call->args[0];
1654+
if (auto as_float = log_arg.as<FloatImmNode>()) {
1655+
// ceil(log2(n)) can be simplified, and should produce the
1656+
// same integer result regardless of the target's rounding
1657+
// conventions.
1658+
return FloatImm(op->dtype, std::ceil(std::log2(as_float->value)));
1659+
}
1660+
}
1661+
}
16431662
}
1663+
16441664
if (op->op.same_as(tir::builtin::likely())) {
16451665
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
16461666
if (auto match = TryMatchLiteralConstraint(op->args[0])) {
16471667
return match.value();
16481668
}
16491669
}
1670+
16501671
return ret;
16511672
}
16521673

tests/python/unittest/test_tir_transform_simplify.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,5 +391,112 @@ def expected(A: T.Buffer[(16, 16), "int32"], n: T.int32):
391391
A[i, j] = 2
392392

393393

394+
class TestCeilLog2Int(BaseBeforeAfter):
395+
"""Simplify expressions resulting from topi.math.ceil_log2"""
396+
397+
@T.prim_func
398+
def before(A: T.Buffer[1, "int32"]):
399+
A[0] = T.cast(
400+
T.ceil(T.log2(T.cast(14, "float64"), dtype="float64"), dtype="float64"), dtype="int32"
401+
)
402+
403+
@T.prim_func
404+
def expected(A: T.Buffer[1, "int32"]):
405+
A[0] = 4
406+
407+
408+
class TestLeftCeilLog2LowerBound(BaseBeforeAfter):
409+
"""Integer bounds are propagated through topi.math.ceil_log2"""
410+
411+
@T.prim_func
412+
def before(A: T.Buffer[16, "float32"]):
413+
for i in T.serial(16):
414+
x = T.cast(
415+
T.ceil(T.log2(T.cast(i + 1024 + 1, "float64"), dtype="float64"), dtype="float64"),
416+
dtype="int32",
417+
)
418+
if x == 11:
419+
A[i] = 0.0
420+
421+
@T.prim_func
422+
def expected(A: T.Buffer[16, "float32"]):
423+
for i in T.serial(16):
424+
A[i] = 0.0
425+
426+
427+
class TestLeftShiftLowerBound(BaseBeforeAfter):
428+
"""Integer bounds are propagated through left shift
429+
430+
min(1 << i) = 1 << min(i)
431+
= 1 << 0
432+
= 1
433+
"""
434+
435+
@T.prim_func
436+
def before(A: T.Buffer[16, "float32"]):
437+
for i in T.serial(16):
438+
if T.shift_left(1, i, dtype="int32") >= 1:
439+
A[i] = 0.0
440+
441+
@T.prim_func
442+
def expected(A: T.Buffer[16, "float32"]):
443+
for i in T.serial(16):
444+
A[i] = 0.0
445+
446+
447+
class TestLeftShiftUpperBound(BaseBeforeAfter):
448+
"""Integer bounds are propagated through left shift
449+
450+
max(31 << i) = 31 << max(i)
451+
= 31 << 15
452+
= 1015808
453+
"""
454+
455+
@T.prim_func
456+
def before(A: T.Buffer[16, "float32"]):
457+
for i in T.serial(16):
458+
if T.shift_left(31, i, dtype="int32") <= 1015808:
459+
A[i] = 0.0
460+
461+
@T.prim_func
462+
def expected(A: T.Buffer[16, "float32"]):
463+
for i in T.serial(16):
464+
A[i] = 0.0
465+
466+
467+
class TestLeftShiftOfNegativeValue(BaseBeforeAfter):
468+
"""No const int bounds of left shift of negative value.
469+
470+
This is target dependent, and does not currently have a specified
471+
behavior in TIR. For example, in CodeGenC, this generates C code
472+
with undefined behavior.
473+
"""
474+
475+
@T.prim_func
476+
def before(A: T.Buffer[16, "float32"]):
477+
for i in T.serial(16):
478+
if -64 <= T.shift_left(-i, 4, dtype="int32"):
479+
A[i] = 0.0
480+
481+
expected = before
482+
483+
484+
class TestLeftShiftByNegativeValue(BaseBeforeAfter):
485+
"""No const int bounds of left shift by negative bit count.
486+
487+
This is target dependent, and does not currently have a specified
488+
behavior in TIR. For example, in CodeGenC, this generates C code
489+
with undefined behavior.
490+
"""
491+
492+
@T.prim_func
493+
def before(A: T.Buffer[16, "float32"]):
494+
for i in T.serial(16):
495+
if T.shift_left(16, -i, dtype="int32") <= 16:
496+
A[i] = 0.0
497+
498+
expected = before
499+
500+
394501
if __name__ == "__main__":
395502
tvm.testing.main()

0 commit comments

Comments
 (0)