Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -779,29 +779,33 @@ class Schedule : public runtime::ObjectRef {
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \param enable_check Whether to enable prequisite checks for schedule primitives.
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed includes:
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level);
int debug_mask, ScheduleErrorRenderLevel error_render_level,
bool enable_check = true);
/*!
* \brief Construct a traced concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param seed The seed value for schedule's random state
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
* \param enable_check Whether to enable prequisite checks for schedule primitives.
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed include:
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level);
int debug_mask, ScheduleErrorRenderLevel error_render_level,
bool enable_check = true);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

Expand Down
9 changes: 8 additions & 1 deletion include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ enum ScheduleDebugMask : uint32_t {
* 3) The dependency information of each block scope (block_info)
* 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
* 5) A debug flag, if set, extra checking is enabled (debug_mask)
* 6) A check flag, if set, enable prequisite check for schedule primitives (enable_check)
*/
class ScheduleStateNode : public Object {
public:
Expand All @@ -100,12 +101,17 @@ class ScheduleStateNode : public Object {
* \sa ScheduleDebugMask
*/
int debug_mask;
/*!
* \brief Whether to enable prequisite checks for schedule primitives.
*/
bool enable_check;

void VisitAttrs(AttrVisitor* v) {
v->Visit("mod", &mod);
// `block_info` is not visited
// `stmt2ref` is not visited
v->Visit("debug_mask", &debug_mask);
v->Visit("enable_check", &enable_check);
}
/*!
* \brief Replace the part of the AST, as being pointed to by `src_sref`,
Expand Down Expand Up @@ -194,8 +200,9 @@ class ScheduleState : public ObjectRef {
* \param mod The IRModule to be scheduled
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param enable_check Whether enables prerequisite checks for schedule primitives.
*/
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0);
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0, bool enable_check = true);

/*! \return The mutable pointer to the ScheduleStateNode */
ScheduleStateNode* get() const { return static_cast<ScheduleStateNode*>(data_.get()); }
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def _parse_error_render_level(error_render_level: str) -> int:
return _ERROR_RENDER_LEVEL.get(error_render_level)


def _parse_enable_checks(enable_checks: bool) -> bool:
if not isinstance(enable_checks, bool):
raise TypeError(
"enable_checks only accepts bool value, got {} instead".format(type(enable_checks))
)
return enable_checks


def _parse_seed(seed: Optional[int]) -> int:
if seed is None:
return -1
Expand Down Expand Up @@ -114,6 +122,7 @@ def __init__(
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
enable_check: bool = True,
) -> None:
"""Construct a TensorIR schedule class from an IRModule

Expand All @@ -137,6 +146,10 @@ def __init__(
- "detail": Render a detailed error message, with the TIR and error locations printed
- "fast: Show a simple error message without rendering or string manipulation
- "none": Do not show any error message.
enable_check : bool = True
Whether we enable prerequisite checks for schedule primitives or not
- true: perform prerequisite check before applying schedules.
- false: do not perform check before applying schedules, but still raise error if schedule fails.

Note
----
Expand All @@ -151,6 +164,7 @@ def __init__(
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
_parse_enable_checks(enable_check),
)

@staticmethod
Expand All @@ -160,13 +174,15 @@ def _create_non_traced(
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
enable_check: bool = True,
) -> "Schedule":
"""Construct a non-traced TensorIR schedule class from an IRModule."""
return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
_parse_enable_checks(enable_check),
)

########## Utilities ##########
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Definition of a scope that is a stage pipeline:
}
}
// Step 2. Handle `require_stage_pipeline`
if (require_stage_pipeline) {
if (require_stage_pipeline && self->enable_check) {
bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
if (stage_pipeline == false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
Expand Down
6 changes: 4 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ namespace tvm {
namespace tir {

Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level) {
int debug_mask, ScheduleErrorRenderLevel error_render_level,
bool enable_check) {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
n->state_ = ScheduleState(mod, debug_mask);
n->state_ = ScheduleState(mod, debug_mask, enable_check);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
Expand Down Expand Up @@ -60,6 +61,7 @@ class ScheduleCopier {
n->block_info = copier.Copy(src_state->block_info);
n->stmt2ref = copier.Copy(src_state->stmt2ref);
n->debug_mask = src_state->debug_mask;
n->enable_check = src_state->enable_check;
*new_state = ScheduleState(std::move(n));
*new_symbol_table = copier.Copy(self->symbol_table_);
}
Expand Down
4 changes: 3 additions & 1 deletion src/tir/schedule/primitive/for_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
* parallelized/vectorized/bound.
*/
// Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow.
CheckSubtreeCompactDataflow(self, loop_sref);
if (self->enable_check) {
CheckSubtreeCompactDataflow(self, loop_sref);
}

// Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
// underlying block.
Expand Down
30 changes: 18 additions & 12 deletions src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,17 +188,19 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
// Get the outer loops from high to low
Array<StmtSRef> loops = GetLoops(block_sref);
const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
// Cond 0. Check loop_sref is an ancestor of block_sref
if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
"decompose_reduction");
}
// Cond 1. Check block is reduction
StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
/*require_stage_pipeline=*/false);
CheckReductionBlock(self, block_sref, scope_root_sref);
// Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
if (self->enable_check) {
// Cond 0. Check loop_sref is an ancestor of block_sref
if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
"decompose_reduction");
}
// Cond 1. Check block is reduction
CheckReductionBlock(self, block_sref, scope_root_sref);
// Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
}
// IR Manipulation
ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
Expand Down Expand Up @@ -1176,7 +1178,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
const Block& block = block_realize->block;
StmtSRef scope_root = GetScopeRoot(self, block_sref, //
/*require_stage_pipeline=*/true);
CheckReductionBlock(self, block_sref, scope_root);
if (self->enable_check) {
CheckReductionBlock(self, block_sref, scope_root);
}
const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref);
if (rf_loop->kind != ForKind::kSerial) {
throw NotSerialLoopKindError(self->mod, GetRef<For>(rf_loop));
Expand All @@ -1199,8 +1203,10 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
// - the outermost loop should have the reduction block as its first child block;
// - the outermost loop that is touched by some reduction block iters can only have one child
// block.
LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars,
reduce_loop_vars);
if (self->enable_check) {
LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars,
reduce_loop_vars);
}

// Step 5. Get the `init` identity and the `update` combiner of the reduction. Extract the
// commutative reducer, combiner lhs and combiner rhs from the reduction identity and the
Expand Down
10 changes: 6 additions & 4 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,17 @@ TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV
TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); });
TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
.set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, int error_render_level) -> Schedule {
int debug_mask, int error_render_level, bool enable_check) -> Schedule {
return Schedule::Concrete(mod, debug_mask, seed,
static_cast<ScheduleErrorRenderLevel>(error_render_level));
static_cast<ScheduleErrorRenderLevel>(error_render_level),
enable_check);
});
TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule")
.set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, int error_render_level) -> Schedule {
int debug_mask, int error_render_level, bool enable_check) -> Schedule {
return Schedule::Traced(mod, seed, debug_mask,
static_cast<ScheduleErrorRenderLevel>(error_render_level));
static_cast<ScheduleErrorRenderLevel>(error_render_level),
enable_check);
});

/******** (FFI) Lookup random variables ********/
Expand Down
23 changes: 16 additions & 7 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -402,16 +402,21 @@ class BlockInfoCollector : private StmtVisitor {
class StateCreator : private StmtVisitor {
public:
/*!
* \brief The entry function
* \param self The schedule state to be completed
* \brief ScheduleState Creator
* \param mod The module being scheduled.
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param enable_check Whether to enable prequisite checks for schedule primitives.
*/
static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask, bool enable_check) {
ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
ScheduleStateNode* self = n.get();
// Set `n->mod`
n->mod = std::move(mod);
// Set `n->debug_mask`
n->debug_mask = debug_mask;
// Set `n->enable_check`
n->enable_check = enable_check;
// Set `n->stmt2ref` and `n->block_info`
StateCreator creator(self);
for (const auto& kv : n->mod->functions) {
Expand All @@ -426,6 +431,10 @@ class StateCreator : private StmtVisitor {
}

private:
/*!
* \brief The entry function
* \param self The schedule state to be completed
*/
explicit StateCreator(ScheduleStateNode* self) : self_(self) {}

/*!
Expand Down Expand Up @@ -481,9 +490,9 @@ class StateCreator : private StmtVisitor {

/**************** Constructor ****************/

ScheduleState::ScheduleState(IRModule mod, int debug_mask) {
ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) {
CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported";
data_ = StateCreator::Create(mod, debug_mask);
data_ = StateCreator::Create(mod, debug_mask, enable_check);
}

/**************** Replace ****************/
Expand Down Expand Up @@ -1108,8 +1117,8 @@ TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& self, const StmtSRef& bl

TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState")
.set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState {
return ScheduleState(mod, debug_mask);
.set_body_typed([](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState {
return ScheduleState(mod, debug_mask, enable_check);
});
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope")
.set_body_method<ScheduleState>(&ScheduleStateNode::GetBlockScope);
Expand Down
5 changes: 3 additions & 2 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ namespace tvm {
namespace tir {

Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level) {
int debug_mask, ScheduleErrorRenderLevel error_render_level,
bool enable_check) {
ObjectPtr<TracedScheduleNode> n = make_object<TracedScheduleNode>();
n->state_ = ScheduleState(mod, debug_mask);
n->state_ = ScheduleState(mod, debug_mask, enable_check);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
Expand Down