Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,14 @@ TVM_DLL Pass PointerValueTypeRewrite();
*/
TVM_DLL Pass HoistIfThenElse();

/*!
* \brief Hoist loop-invariant IfThenElse nodes to
* outside the elligible loops.
*
* \return The pass.
*/
TVM_DLL Pass HoistExpression();

/*!
* \brief Lower cross-thread reduction from thread
* bindings to intrinsic function calls.
Expand Down
70 changes: 70 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"""Wrapping existing transformations."""
# pylint: disable=invalid-name
from typing import Optional
import enum

from . import _ffi_api
from . import function_pass as _fpass

Expand Down Expand Up @@ -612,6 +614,74 @@ def HoistIfThenElse(variant: Optional[str] = None):
return _ffi_api.HoistIfThenElse() # type: ignore


class HoistedConditionals(enum.Flag):
"""Flags for use in HoistExpressionConfig.conditional_types

Each bitflag represents a type of expression that should be
hoisted to the outermost loop possible.
"""

Never = 0
""" No hoisting of conditionals """

IfElseStmt = 1
""" If set, look for hoist candidates in IfElseStmt """

IfElseExpr = 2
""" If set, look for hoist candidates in tir.if_then_else """

BooleanExpression = 4
""" If set, look for hoist candidates in all boolean expressions """

UsingBlockVar = 8
""" If set, allow hoisting of conditionals that use a block variable (e.g. threadIdx.x) """

All = IfElseStmt | IfElseExpr | BooleanExpression | UsingBlockVar
""" Enable all hoisting of conditionals"""


class HoistedLetBindings(enum.Flag):
"""Flags for use in HoistExpressionConfig.let_binding_types

Each bitflag represents a type of let binding expression that should be
hoisted to the outermost loop possible.
"""

Never = 0
""" No hoisting of let bindings """

RequiredByConditional = 1
""" Bindings that are used by a hoisted conditional """

LetStmt = 2
""" Bindings occuring in LetStmt """

LetExpr = 4
""" Bindings occuring in Let expressions """

All = RequiredByConditional | LetStmt | LetExpr
""" Enable all hoisting of let bindings """


def HoistExpression():
"""Generalized verison of HoistIfThenElse.

Hoist loop-invariant expressions to outside the eligible loops.
Searches for expressions in:

* LetStmt bindings
* IfThenElse conditions
* Boolean operators

Returns
-------
fpass : tvm.transform.Pass
The result pass

"""
return _ffi_api.HoistExpression() # type: ignore


def LowerCrossThreadReduction():
"""Lower cross-thread reduction from thread bindings to
intrinsic function calls.
Expand Down
8 changes: 7 additions & 1 deletion src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) {
size_t old_literal_size = literal_constraints_.size();
// we will compare the already simplified result with the constraint,
// so simplify the constarint as well
// so simplify the constraint as well
PrimExpr new_constraint = operator()(constraint);
for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint)) {
if (SideEffect(subconstraint) <= CallEffectKind::kPure) {
Expand Down Expand Up @@ -1652,6 +1652,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const VarNode* op) {
Var var = GetRef<Var>(op);
if (op->dtype == DataType::Bool()) {
if (auto match = TryMatchLiteralConstraint(var)) {
return match.value();
}
}

auto it = var_map_.find(var);
if (it != var_map_.end()) {
return it->second;
Expand Down
Loading