diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 22febfdfedec..ef5fd637d3cd 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -779,14 +779,15 @@ class Schedule : public runtime::ObjectRef { * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. * \param error_render_level The level of error rendering + * \param enable_check Whether to enable some prerequisite checks for schedule primitives, it's + * the user's duty to guarantee the schedule correctness if we disable the checks. * \return The concrete schedule created * \sa ScheduleDebugMask - * \note The checks performed includes: - * 1) VerifySRefTree - * 2) VerifyCachedFlags + * \note The checks performed include: 1) VerifySRefTree 2) VerifyCachedFlags */ TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, - int debug_mask, ScheduleErrorRenderLevel error_render_level); + int debug_mask, ScheduleErrorRenderLevel error_render_level, + bool enable_check = true); /*! * \brief Construct a traced concrete TensorIR schedule from an IRModule * \param mod The IRModule to be scheduled * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. * \param error_render_level The level of error rendering + * \param enable_check Whether to enable prerequisite checks for schedule primitives. 
* \return The concrete schedule created * \sa ScheduleDebugMask * \note The checks performed include: @@ -801,7 +803,8 @@ * 2) VerifyCachedFlags */ TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, - int debug_mask, ScheduleErrorRenderLevel error_render_level); + int debug_mask, ScheduleErrorRenderLevel error_render_level, + bool enable_check = true); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h index 201d78fe631c..a089de279946 100644 --- a/include/tvm/tir/schedule/state.h +++ b/include/tvm/tir/schedule/state.h @@ -81,6 +81,7 @@ enum ScheduleDebugMask : uint32_t { * 3) The dependency information of each block scope (block_info) * 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref) * 5) A debug flag, if set, extra checking is enabled (debug_mask) + * 6) A check flag, if set, enables prerequisite checks for schedule primitives (enable_check) */ class ScheduleStateNode : public Object { public: @@ -100,12 +101,17 @@ class ScheduleStateNode : public Object { * \sa ScheduleDebugMask */ int debug_mask; + /*! + * \brief Whether to enable prerequisite checks for schedule primitives. + */ + bool enable_check; void VisitAttrs(AttrVisitor* v) { v->Visit("mod", &mod); // `block_info` is not visited // `stmt2ref` is not visited v->Visit("debug_mask", &debug_mask); + v->Visit("enable_check", &enable_check); } /*! * \brief Replace the part of the AST, as being pointed to by `src_sref`, @@ -194,8 +200,9 @@ class ScheduleState : public ObjectRef { * \param mod The IRModule to be scheduled * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. + * \param enable_check Whether to enable prerequisite checks for schedule primitives. 
*/ - TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0); + TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0, bool enable_check = true); /*! \return The mutable pointer to the ScheduleStateNode */ ScheduleStateNode* get() const { return static_cast(data_.get()); } diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 896e2fc48e72..365fb8c2bbbb 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -81,6 +81,14 @@ def _parse_error_render_level(error_render_level: str) -> int: return _ERROR_RENDER_LEVEL.get(error_render_level) +def _parse_enable_checks(enable_checks: bool) -> bool: + if not isinstance(enable_checks, bool): + raise TypeError( + "enable_checks only accepts bool value, got {} instead".format(type(enable_checks)) + ) + return enable_checks + + def _parse_seed(seed: Optional[int]) -> int: if seed is None: return -1 @@ -114,6 +122,7 @@ def __init__( seed: Optional[int] = None, debug_mask: Union[str, int] = "none", error_render_level: str = "detail", + enable_check: bool = True, ) -> None: """Construct a TensorIR schedule class from an IRModule @@ -137,6 +146,15 @@ def __init__( - "detail": Render a detailed error message, with the TIR and error locations printed - "fast: Show a simple error message without rendering or string manipulation - "none": Do not show any error message. + enable_check : bool = True + The default schedule checks are too strict and might prevent us performing some valid + schedules. `enable_check` is an argument to control whether we enable prerequisite + checks for some schedule primitives or not: + - true: perform prerequisite check before applying some schedules. + - false: do not perform some checks before applying schedules, but still raise an error + if the schedule fails. + + It's the user's duty to guarantee schedule correctness if `enable_check` is set to `False`. 
Note ---- @@ -151,6 +169,7 @@ def __init__( _parse_seed(seed), _parse_debug_mask(debug_mask), _parse_error_render_level(error_render_level), + _parse_enable_checks(enable_check), ) @staticmethod @@ -160,6 +179,7 @@ def _create_non_traced( seed: Optional[int] = None, debug_mask: Union[str, int] = "none", error_render_level: str = "detail", + enable_check: bool = True, ) -> "Schedule": """Construct a non-traced TensorIR schedule class from an IRModule.""" return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member @@ -167,6 +187,7 @@ def _create_non_traced( _parse_seed(seed), _parse_debug_mask(debug_mask), _parse_error_render_level(error_render_level), + _parse_enable_checks(enable_check), ) ########## Utilities ########## diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py index fbf21843e7b3..dab84b2fcc6e 100644 --- a/python/tvm/tir/schedule/state.py +++ b/python/tvm/tir/schedule/state.py @@ -70,6 +70,14 @@ def _parse_debug_mask(debug_mask: Union[str, int]) -> int: return debug_mask +def _parse_enable_checks(enable_checks: bool) -> bool: + if not isinstance(enable_checks, bool): + raise TypeError( + "enable_checks only accepts bool value, got {} instead".format(type(enable_checks)) + ) + return enable_checks + + @register_object("tir.ScheduleState") class ScheduleState(Object): """The state of scheduling, which exposes a `Replace` method as @@ -81,6 +89,7 @@ class ScheduleState(Object): 3) The dependency information of each block scope (block_info) 4) A reverse mapping from the AST nodes to that in the sref tree (get_sref) 5) A debug flag, if set, extra checking is enabled (debug_mask) + 6) A enable check flag, if False, some prerequisite checks are disabled. Parameters ---------- @@ -89,6 +98,9 @@ class ScheduleState(Object): debug_mask : int Do extra correctness checking after the object construction and each time after calling the Replace method. 
+ enable_check : bool + Indicates whether we enable prerequisite checks for some schedule primitives or not, + defaults to `True`. """ mod: IRModule @@ -99,6 +111,7 @@ def __init__( mod: Union[PrimFunc, IRModule], *, debug_mask: Union[str, int] = "none", + enable_check: bool = True, ) -> None: """Construct a schedule state from an IRModule or a PrimFunc @@ -118,6 +131,7 @@ def __init__( _ffi_api.ScheduleState, # type: ignore # pylint: disable=no-member _parse_mod(mod), _parse_debug_mask(debug_mask), + _parse_enable_checks(enable_check), ) def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]: diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 744801596ebd..b35d64f125d8 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -103,7 +103,7 @@ Definition of a scope that is a stage pipeline: } } // Step 2. Handle `require_stage_pipeline` - if (require_stage_pipeline) { + if (require_stage_pipeline && self->enable_check) { bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; if (stage_pipeline == false) { const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 5a9dab4854bd..75a8fc0a145e 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -24,9 +24,10 @@ namespace tvm { namespace tir { Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, - int debug_mask, ScheduleErrorRenderLevel error_render_level) { + int debug_mask, ScheduleErrorRenderLevel error_render_level, + bool enable_check) { ObjectPtr n = make_object(); - n->state_ = ScheduleState(mod, debug_mask); + n->state_ = ScheduleState(mod, debug_mask, enable_check); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); @@ -60,6 +61,7 @@ class ScheduleCopier { 
n->block_info = copier.Copy(src_state->block_info); n->stmt2ref = copier.Copy(src_state->stmt2ref); n->debug_mask = src_state->debug_mask; + n->enable_check = src_state->enable_check; *new_state = ScheduleState(std::move(n)); *new_symbol_table = copier.Copy(self->symbol_table_); } diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index cc8cb55fd3fa..02d8866e8e9d 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -157,7 +157,9 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref * parallelized/vectorized/bound. */ // Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow. - CheckSubtreeCompactDataflow(self, loop_sref); + if (self->enable_check) { + CheckSubtreeCompactDataflow(self, loop_sref); + } // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each // underlying block. diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index bb43df1ce914..d39252f3cebe 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -188,17 +188,19 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, // Get the outer loops from high to low Array loops = GetLoops(block_sref); const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get(); - // Cond 0. Check loop_sref is an ancestor of block_sref - if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) { - throw LoopPositionError(self->mod, GetRef(loop), GetRef(block), - "decompose_reduction"); - } - // Cond 1. Check block is reduction StmtSRef scope_root_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - CheckReductionBlock(self, block_sref, scope_root_sref); - // Cond 2. 
Check 'loop' is higher than all the loops related to block var of type reduction - LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref); + if (self->enable_check) { + // Cond 0. Check loop_sref is an ancestor of block_sref + if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) { + throw LoopPositionError(self->mod, GetRef(loop), GetRef(block), + "decompose_reduction"); + } + // Cond 1. Check block is reduction + CheckReductionBlock(self, block_sref, scope_root_sref); + // Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction + LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref); + } // IR Manipulation ObjectPtr init_block = make_object(); ObjectPtr init_realize = make_object(); @@ -1176,7 +1178,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax const Block& block = block_realize->block; StmtSRef scope_root = GetScopeRoot(self, block_sref, // /*require_stage_pipeline=*/true); - CheckReductionBlock(self, block_sref, scope_root); + if (self->enable_check) { + CheckReductionBlock(self, block_sref, scope_root); + } const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref); if (rf_loop->kind != ForKind::kSerial) { throw NotSerialLoopKindError(self->mod, GetRef(rf_loop)); @@ -1199,8 +1203,10 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // - the outermost loop should have the reduction block as its first child block; // - the outermost loop that is touched by some reduction block iters can only have one child // block. - LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars, - reduce_loop_vars); + if (self->enable_check) { + LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars, + reduce_loop_vars); + } // Step 5. Get the `init` identity and the `update` combiner of the reduction. 
Extract the // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index cb8b5a1d7787..dcaa61e1bb20 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -65,15 +65,17 @@ TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, - int debug_mask, int error_render_level) -> Schedule { + int debug_mask, int error_render_level, bool enable_check) -> Schedule { return Schedule::Concrete(mod, debug_mask, seed, - static_cast(error_render_level)); + static_cast(error_render_level), + enable_check); }); TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, - int debug_mask, int error_render_level) -> Schedule { + int debug_mask, int error_render_level, bool enable_check) -> Schedule { return Schedule::Traced(mod, seed, debug_mask, - static_cast(error_render_level)); + static_cast(error_render_level), + enable_check); }); /******** (FFI) Lookup random variables ********/ diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc index a901eff6f2d1..a7a1c0d482cc 100644 --- a/src/tir/schedule/state.cc +++ b/src/tir/schedule/state.cc @@ -402,16 +402,21 @@ class BlockInfoCollector : private StmtVisitor { class StateCreator : private StmtVisitor { public: /*! - * \brief The entry function - * \param self The schedule state to be completed + * \brief ScheduleState Creator + * \param mod The module being scheduled. + * \param debug_mask Do extra correctness checking after the class creation + * and each time after calling the Replace method. 
+ * \param enable_check Whether to enable prerequisite checks for schedule primitives. */ - static ObjectPtr Create(IRModule mod, int debug_mask) { + static ObjectPtr Create(IRModule mod, int debug_mask, bool enable_check) { ObjectPtr n = make_object(); ScheduleStateNode* self = n.get(); // Set `n->mod` n->mod = std::move(mod); // Set `n->debug_mask` n->debug_mask = debug_mask; + // Set `n->enable_check` + n->enable_check = enable_check; // Set `n->stmt2ref` and `n->block_info` StateCreator creator(self); for (const auto& kv : n->mod->functions) { @@ -426,6 +431,10 @@ class StateCreator : private StmtVisitor { } private: + /*! + * \brief The constructor + * \param self The schedule state to be completed + */ explicit StateCreator(ScheduleStateNode* self) : self_(self) {} /*! @@ -481,9 +490,9 @@ /**************** Constructor ****************/ -ScheduleState::ScheduleState(IRModule mod, int debug_mask) { +ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) { CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported"; - data_ = StateCreator::Create(mod, debug_mask); + data_ = StateCreator::Create(mod, debug_mask, enable_check); } /**************** Replace ****************/ @@ -1108,8 +1117,8 @@ TVM_DLL Array GetCachedFlags(const ScheduleState& self, const StmtSRef& bl TVM_REGISTER_NODE_TYPE(ScheduleStateNode); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState") - .set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState { - return ScheduleState(mod, debug_mask); + .set_body_typed([](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState { + return ScheduleState(mod, debug_mask, enable_check); }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope") .set_body_method(&ScheduleStateNode::GetBlockScope); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index a5cb66a0cb44..1ccc82f302dc 100644 --- 
a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -22,9 +22,10 @@ namespace tvm { namespace tir { Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, - int debug_mask, ScheduleErrorRenderLevel error_render_level) { + int debug_mask, ScheduleErrorRenderLevel error_render_level, + bool enable_check) { ObjectPtr n = make_object(); - n->state_ = ScheduleState(mod, debug_mask); + n->state_ = ScheduleState(mod, debug_mask, enable_check); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique();