Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -779,29 +779,32 @@ class Schedule : public runtime::ObjectRef {
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \param enable_check Whether to enable some prequisite checks for schedule primitives, it's
* user's duty to guarantee the schedule correctness if we disable the checks.
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed includes:
* 1) VerifySRefTree
* 2) VerifyCachedFlags
* \note The checks performed includes: 1) VerifySRefTree 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level);
int debug_mask, ScheduleErrorRenderLevel error_render_level,
bool enable_check = true);
/*!
* \brief Construct a traced concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param seed The seed value for schedule's random state
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \param enable_check Whether to enable prequisite checks for schedule primitives.
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed include:
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level);
int debug_mask, ScheduleErrorRenderLevel error_render_level,
bool enable_check = true);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

Expand Down
9 changes: 8 additions & 1 deletion include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ enum ScheduleDebugMask : uint32_t {
* 3) The dependency information of each block scope (block_info)
* 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
* 5) A debug flag, if set, extra checking is enabled (debug_mask)
* 6) A check flag, if set, enable prequisite check for schedule primitives (enable_check)
*/
class ScheduleStateNode : public Object {
public:
Expand All @@ -100,12 +101,17 @@ class ScheduleStateNode : public Object {
* \sa ScheduleDebugMask
*/
int debug_mask;
/*!
* \brief Whether to enable prequisite checks for schedule primitives.
*/
bool enable_check;

void VisitAttrs(AttrVisitor* v) {
v->Visit("mod", &mod);
// `block_info` is not visited
// `stmt2ref` is not visited
v->Visit("debug_mask", &debug_mask);
v->Visit("enable_check", &enable_check);
}
/*!
* \brief Replace the part of the AST, as being pointed to by `src_sref`,
Expand Down Expand Up @@ -194,8 +200,9 @@ class ScheduleState : public ObjectRef {
* \param mod The IRModule to be scheduled
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param enable_check Whether enables prerequisite checks for schedule primitives.
*/
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0);
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0, bool enable_check = true);

/*! \return The mutable pointer to the ScheduleStateNode */
ScheduleStateNode* get() const { return static_cast<ScheduleStateNode*>(data_.get()); }
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def _parse_error_render_level(error_render_level: str) -> int:
return _ERROR_RENDER_LEVEL.get(error_render_level)


def _parse_enable_checks(enable_checks: bool) -> bool:
if not isinstance(enable_checks, bool):
raise TypeError(
"enable_checks only accepts bool value, got {} instead".format(type(enable_checks))
)
return enable_checks


def _parse_seed(seed: Optional[int]) -> int:
if seed is None:
return -1
Expand Down Expand Up @@ -114,6 +122,7 @@ def __init__(
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
enable_check: bool = True,
) -> None:
"""Construct a TensorIR schedule class from an IRModule

Expand All @@ -137,6 +146,15 @@ def __init__(
- "detail": Render a detailed error message, with the TIR and error locations printed
- "fast: Show a simple error message without rendering or string manipulation
- "none": Do not show any error message.
enable_check : bool = True
The default schedule checks are too strict and might prevent us performing some valid
schedules. `enable_check` is an argument to control whether we enable prerequisite
checks for some schedule primitives or not:
- true: perform prerequisite check before applying some schedules.
- false: do not perform some check before applying schedules, but still raise error
if schedule fails.

It's user duty to guarantee schedule correctness if `enable_check` is set to `False`.

Note
----
Expand All @@ -151,6 +169,7 @@ def __init__(
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
_parse_enable_checks(enable_check),
)

@staticmethod
Expand All @@ -160,13 +179,15 @@ def _create_non_traced(
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
enable_check: bool = True,
) -> "Schedule":
"""Construct a non-traced TensorIR schedule class from an IRModule."""
return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
_parse_enable_checks(enable_check),
)

########## Utilities ##########
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/tir/schedule/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ def _parse_debug_mask(debug_mask: Union[str, int]) -> int:
return debug_mask


def _parse_enable_checks(enable_checks: bool) -> bool:
if not isinstance(enable_checks, bool):
raise TypeError(
"enable_checks only accepts bool value, got {} instead".format(type(enable_checks))
)
return enable_checks


@register_object("tir.ScheduleState")
class ScheduleState(Object):
"""The state of scheduling, which exposes a `Replace` method as
Expand All @@ -81,6 +89,7 @@ class ScheduleState(Object):
3) The dependency information of each block scope (block_info)
4) A reverse mapping from the AST nodes to that in the sref tree (get_sref)
5) A debug flag, if set, extra checking is enabled (debug_mask)
6) A enable check flag, if False, some prerequisite checks are disabled.

Parameters
----------
Expand All @@ -89,6 +98,9 @@ class ScheduleState(Object):
debug_mask : int
Do extra correctness checking after the object construction
and each time after calling the Replace method.
enable_check : bool
Indicates whether we enable prerequisite checks for some schedule primitives or not,
defaults to `True`.
"""

mod: IRModule
Expand All @@ -99,6 +111,7 @@ def __init__(
mod: Union[PrimFunc, IRModule],
*,
debug_mask: Union[str, int] = "none",
enable_check: bool = True,
) -> None:
"""Construct a schedule state from an IRModule or a PrimFunc

Expand All @@ -118,6 +131,7 @@ def __init__(
_ffi_api.ScheduleState, # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_debug_mask(debug_mask),
_parse_enable_checks(enable_check),
)

def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]:
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Definition of a scope that is a stage pipeline:
}
}
// Step 2. Handle `require_stage_pipeline`
if (require_stage_pipeline) {
if (require_stage_pipeline && self->enable_check) {
bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
if (stage_pipeline == false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
Expand Down
6 changes: 4 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ namespace tvm {
namespace tir {

Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level) {
int debug_mask, ScheduleErrorRenderLevel error_render_level,
bool enable_check) {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
n->state_ = ScheduleState(mod, debug_mask);
n->state_ = ScheduleState(mod, debug_mask, enable_check);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
Expand Down Expand Up @@ -60,6 +61,7 @@ class ScheduleCopier {
n->block_info = copier.Copy(src_state->block_info);
n->stmt2ref = copier.Copy(src_state->stmt2ref);
n->debug_mask = src_state->debug_mask;
n->enable_check = src_state->enable_check;
*new_state = ScheduleState(std::move(n));
*new_symbol_table = copier.Copy(self->symbol_table_);
}
Expand Down
4 changes: 3 additions & 1 deletion src/tir/schedule/primitive/for_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
* parallelized/vectorized/bound.
*/
// Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow.
CheckSubtreeCompactDataflow(self, loop_sref);
if (self->enable_check) {
CheckSubtreeCompactDataflow(self, loop_sref);
}

// Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
// underlying block.
Expand Down
30 changes: 18 additions & 12 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,19 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
// Get the outer loops from high to low
Array<StmtSRef> loops = GetLoops(block_sref);
const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
// Cond 0. Check loop_sref is an ancestor of block_sref
if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
"decompose_reduction");
}
// Cond 1. Check block is reduction
StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false);
CheckReductionBlock(self, block_sref, scope_root_sref);
// Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
if (self->enable_check) {
// Cond 0. Check loop_sref is an ancestor of block_sref
if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
"decompose_reduction");
}
// Cond 1. Check block is reduction
CheckReductionBlock(self, block_sref, scope_root_sref);
// Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
}
// IR Manipulation
ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
Expand Down Expand Up @@ -1176,7 +1178,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
const Block& block = block_realize->block;
StmtSRef scope_root = GetScopeRoot(self, block_sref, //
/*require_stage_pipeline=*/true);
CheckReductionBlock(self, block_sref, scope_root);
if (self->enable_check) {
CheckReductionBlock(self, block_sref, scope_root);
}
const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref);
if (rf_loop->kind != ForKind::kSerial) {
throw NotSerialLoopKindError(self->mod, GetRef<For>(rf_loop));
Expand All @@ -1199,8 +1203,10 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
// - the outermost loop should have the reduction block as its first child block;
// - the outermost loop that is touched by some reduction block iters can only have one child
// block.
LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars,
reduce_loop_vars);
if (self->enable_check) {
LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars,
reduce_loop_vars);
}

// Step 5. Get the `init` identity and the `update` combiner of the reduction. Extract the
// commutative reducer, combiner lhs and combiner rhs from the reduction identity and the
Expand Down
10 changes: 6 additions & 4 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,17 @@ TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV
TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); });
TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
.set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, int error_render_level) -> Schedule {
int debug_mask, int error_render_level, bool enable_check) -> Schedule {
return Schedule::Concrete(mod, debug_mask, seed,
static_cast<ScheduleErrorRenderLevel>(error_render_level));
static_cast<ScheduleErrorRenderLevel>(error_render_level),
enable_check);
});
TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule")
.set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, int error_render_level) -> Schedule {
int debug_mask, int error_render_level, bool enable_check) -> Schedule {
return Schedule::Traced(mod, seed, debug_mask,
static_cast<ScheduleErrorRenderLevel>(error_render_level));
static_cast<ScheduleErrorRenderLevel>(error_render_level),
enable_check);
});

/******** (FFI) Lookup random variables ********/
Expand Down
23 changes: 16 additions & 7 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,16 +402,21 @@ class BlockInfoCollector : private StmtVisitor {
class StateCreator : private StmtVisitor {
public:
/*!
* \brief The entry function
* \param self The schedule state to be completed
* \brief ScheduleState Creator
* \param mod The module being scheduled.
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param enable_check Whether to enable prequisite checks for schedule primitives.
*/
static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask, bool enable_check) {
ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
ScheduleStateNode* self = n.get();
// Set `n->mod`
n->mod = std::move(mod);
// Set `n->debug_mask`
n->debug_mask = debug_mask;
// Set `n->enable_check`
n->enable_check = enable_check;
// Set `n->stmt2ref` and `n->block_info`
StateCreator creator(self);
for (const auto& kv : n->mod->functions) {
Expand All @@ -426,6 +431,10 @@ class StateCreator : private StmtVisitor {
}

private:
/*!
* \brief The entry function
* \param self The schedule state to be completed
*/
explicit StateCreator(ScheduleStateNode* self) : self_(self) {}

/*!
Expand Down Expand Up @@ -481,9 +490,9 @@ class StateCreator : private StmtVisitor {

/**************** Constructor ****************/

ScheduleState::ScheduleState(IRModule mod, int debug_mask) {
ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) {
CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported";
data_ = StateCreator::Create(mod, debug_mask);
data_ = StateCreator::Create(mod, debug_mask, enable_check);
}

/**************** Replace ****************/
Expand Down Expand Up @@ -1108,8 +1117,8 @@ TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& self, const StmtSRef& bl

TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState")
.set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState {
return ScheduleState(mod, debug_mask);
.set_body_typed([](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState {
return ScheduleState(mod, debug_mask, enable_check);
});
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope")
.set_body_method<ScheduleState>(&ScheduleStateNode::GetBlockScope);
Expand Down
5 changes: 3 additions & 2 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ namespace tvm {
namespace tir {

Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level) {
int debug_mask, ScheduleErrorRenderLevel error_render_level,
bool enable_check) {
ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
n->state_ = ScheduleState(mod, debug_mask);
n->state_ = ScheduleState(mod, debug_mask, enable_check);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
Expand Down