Skip to content

Commit 542274d

Browse files
authored
[Schedule] Add an optional argument disable_checks for Schedule (#14281)
# Motivation Currently, some of the schedule checks are too strict, which makes it hard to schedule some workloads such as FlashAttention, whose reduction is two-stage and does not strictly follow our standard. This PR adds an optional argument `disable_checks` which mutes some checks (in the code changes below the flag is exposed as `enable_check`, defaulting to `True`). The argument defaults to `False` and we can enable it whenever we want to disable some `soft` checks (by `soft` we mean that if we violate such checks, the schedule is not necessarily invalid, whereas if we violate `hard` checks the schedule step is invalid). In the future, we should collect the `soft` and `hard` checks for all schedule primitives. This PR serves FlashAttention and only covers `bind` and some reduction primitives for now.
1 parent a5ed21d commit 542274d

File tree

11 files changed

+102
-35
lines changed

11 files changed

+102
-35
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -779,29 +779,32 @@ class Schedule : public runtime::ObjectRef {
779779
* \param debug_mask Do extra correctness checking after the class creation
780780
* and each time after calling the Replace method.
781781
* \param error_render_level The level of error rendering
782+
* \param enable_check Whether to enable some prerequisite checks for schedule primitives, it's
783+
* user's duty to guarantee the schedule correctness if we disable the checks.
782784
* \return The concrete schedule created
783785
* \sa ScheduleDebugMask
784-
* \note The checks performed includes:
785-
* 1) VerifySRefTree
786-
* 2) VerifyCachedFlags
786+
* \note The checks performed include: 1) VerifySRefTree 2) VerifyCachedFlags
787787
*/
788788
TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
789-
int debug_mask, ScheduleErrorRenderLevel error_render_level);
789+
int debug_mask, ScheduleErrorRenderLevel error_render_level,
790+
bool enable_check = true);
790791
/*!
791792
* \brief Construct a traced concrete TensorIR schedule from an IRModule
792793
* \param mod The IRModule to be scheduled
793794
* \param seed The seed value for schedule's random state
794795
* \param debug_mask Do extra correctness checking after the class creation
795796
* and each time after calling the Replace method.
796797
* \param error_render_level The level of error rendering
798+
* \param enable_check Whether to enable prerequisite checks for schedule primitives.
797799
* \return The concrete schedule created
798800
* \sa ScheduleDebugMask
799801
* \note The checks performed include:
800802
* 1) VerifySRefTree
801803
* 2) VerifyCachedFlags
802804
*/
803805
TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
804-
int debug_mask, ScheduleErrorRenderLevel error_render_level);
806+
int debug_mask, ScheduleErrorRenderLevel error_render_level,
807+
bool enable_check = true);
805808
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
806809
};
807810

include/tvm/tir/schedule/state.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ enum ScheduleDebugMask : uint32_t {
8181
* 3) The dependency information of each block scope (block_info)
8282
* 4) A reverse mapping from the AST nodes to that in the sref tree (stmt2ref)
8383
* 5) A debug flag, if set, extra checking is enabled (debug_mask)
84+
* 6) A check flag, if set, prerequisite checking for schedule primitives is enabled (enable_check)
8485
*/
8586
class ScheduleStateNode : public Object {
8687
public:
@@ -100,12 +101,17 @@ class ScheduleStateNode : public Object {
100101
* \sa ScheduleDebugMask
101102
*/
102103
int debug_mask;
104+
/*!
105+
* \brief Whether to enable prerequisite checks for schedule primitives.
106+
*/
107+
bool enable_check;
103108

104109
void VisitAttrs(AttrVisitor* v) {
105110
v->Visit("mod", &mod);
106111
// `block_info` is not visited
107112
// `stmt2ref` is not visited
108113
v->Visit("debug_mask", &debug_mask);
114+
v->Visit("enable_check", &enable_check);
109115
}
110116
/*!
111117
* \brief Replace the part of the AST, as being pointed to by `src_sref`,
@@ -194,8 +200,9 @@ class ScheduleState : public ObjectRef {
194200
* \param mod The IRModule to be scheduled
195201
* \param debug_mask Do extra correctness checking after the class creation
196202
* and each time after calling the Replace method.
203+
* \param enable_check Whether to enable prerequisite checks for schedule primitives.
197204
*/
198-
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0);
205+
TVM_DLL explicit ScheduleState(IRModule mod, int debug_mask = 0, bool enable_check = true);
199206

200207
/*! \return The mutable pointer to the ScheduleStateNode */
201208
ScheduleStateNode* get() const { return static_cast<ScheduleStateNode*>(data_.get()); }

python/tvm/tir/schedule/schedule.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ def _parse_error_render_level(error_render_level: str) -> int:
8181
return _ERROR_RENDER_LEVEL.get(error_render_level)
8282

8383

84+
def _parse_enable_checks(enable_checks: bool) -> bool:
85+
if not isinstance(enable_checks, bool):
86+
raise TypeError(
87+
"enable_checks only accepts bool value, got {} instead".format(type(enable_checks))
88+
)
89+
return enable_checks
90+
91+
8492
def _parse_seed(seed: Optional[int]) -> int:
8593
if seed is None:
8694
return -1
@@ -114,6 +122,7 @@ def __init__(
114122
seed: Optional[int] = None,
115123
debug_mask: Union[str, int] = "none",
116124
error_render_level: str = "detail",
125+
enable_check: bool = True,
117126
) -> None:
118127
"""Construct a TensorIR schedule class from an IRModule
119128
@@ -137,6 +146,15 @@ def __init__(
137146
- "detail": Render a detailed error message, with the TIR and error locations printed
138147
- "fast": Show a simple error message without rendering or string manipulation
139148
- "none": Do not show any error message.
149+
enable_check : bool = True
150+
The default schedule checks are too strict and might prevent us from performing some valid
151+
schedules. `enable_check` is an argument to control whether we enable prerequisite
152+
checks for some schedule primitives or not:
153+
- true: perform prerequisite check before applying some schedules.
154+
- false: do not perform some checks before applying schedules, but still raise an error
155+
if schedule fails.
156+
157+
It's the user's duty to guarantee schedule correctness if `enable_check` is set to `False`.
140158
141159
Note
142160
----
@@ -151,6 +169,7 @@ def __init__(
151169
_parse_seed(seed),
152170
_parse_debug_mask(debug_mask),
153171
_parse_error_render_level(error_render_level),
172+
_parse_enable_checks(enable_check),
154173
)
155174

156175
@staticmethod
@@ -160,13 +179,15 @@ def _create_non_traced(
160179
seed: Optional[int] = None,
161180
debug_mask: Union[str, int] = "none",
162181
error_render_level: str = "detail",
182+
enable_check: bool = True,
163183
) -> "Schedule":
164184
"""Construct a non-traced TensorIR schedule class from an IRModule."""
165185
return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member
166186
_parse_mod(mod),
167187
_parse_seed(seed),
168188
_parse_debug_mask(debug_mask),
169189
_parse_error_render_level(error_render_level),
190+
_parse_enable_checks(enable_check),
170191
)
171192

172193
########## Utilities ##########

python/tvm/tir/schedule/state.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ def _parse_debug_mask(debug_mask: Union[str, int]) -> int:
7070
return debug_mask
7171

7272

73+
def _parse_enable_checks(enable_checks: bool) -> bool:
74+
if not isinstance(enable_checks, bool):
75+
raise TypeError(
76+
"enable_checks only accepts bool value, got {} instead".format(type(enable_checks))
77+
)
78+
return enable_checks
79+
80+
7381
@register_object("tir.ScheduleState")
7482
class ScheduleState(Object):
7583
"""The state of scheduling, which exposes a `Replace` method as
@@ -81,6 +89,7 @@ class ScheduleState(Object):
8189
3) The dependency information of each block scope (block_info)
8290
4) A reverse mapping from the AST nodes to that in the sref tree (get_sref)
8391
5) A debug flag, if set, extra checking is enabled (debug_mask)
92+
6) An enable-check flag; if False, some prerequisite checks are disabled (enable_check)
8493
8594
Parameters
8695
----------
@@ -89,6 +98,9 @@ class ScheduleState(Object):
8998
debug_mask : int
9099
Do extra correctness checking after the object construction
91100
and each time after calling the Replace method.
101+
enable_check : bool
102+
Indicates whether we enable prerequisite checks for some schedule primitives or not,
103+
defaults to `True`.
92104
"""
93105

94106
mod: IRModule
@@ -99,6 +111,7 @@ def __init__(
99111
mod: Union[PrimFunc, IRModule],
100112
*,
101113
debug_mask: Union[str, int] = "none",
114+
enable_check: bool = True,
102115
) -> None:
103116
"""Construct a schedule state from an IRModule or a PrimFunc
104117
@@ -118,6 +131,7 @@ def __init__(
118131
_ffi_api.ScheduleState, # type: ignore # pylint: disable=no-member
119132
_parse_mod(mod),
120133
_parse_debug_mask(debug_mask),
134+
_parse_enable_checks(enable_check),
121135
)
122136

123137
def get_sref(self, stmt: Union[Block, For]) -> Optional[StmtSRef]:

src/tir/schedule/analysis/analysis.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ Definition of a scope that is a stage pipeline:
103103
}
104104
}
105105
// Step 2. Handle `require_stage_pipeline`
106-
if (require_stage_pipeline) {
106+
if (require_stage_pipeline && self->enable_check) {
107107
bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
108108
if (stage_pipeline == false) {
109109
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);

src/tir/schedule/concrete_schedule.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ namespace tvm {
2424
namespace tir {
2525

2626
Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
27-
int debug_mask, ScheduleErrorRenderLevel error_render_level) {
27+
int debug_mask, ScheduleErrorRenderLevel error_render_level,
28+
bool enable_check) {
2829
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
29-
n->state_ = ScheduleState(mod, debug_mask);
30+
n->state_ = ScheduleState(mod, debug_mask, enable_check);
3031
n->error_render_level_ = error_render_level;
3132
n->symbol_table_ = {};
3233
n->analyzer_ = std::make_unique<arith::Analyzer>();
@@ -60,6 +61,7 @@ class ScheduleCopier {
6061
n->block_info = copier.Copy(src_state->block_info);
6162
n->stmt2ref = copier.Copy(src_state->stmt2ref);
6263
n->debug_mask = src_state->debug_mask;
64+
n->enable_check = src_state->enable_check;
6365
*new_state = ScheduleState(std::move(n));
6466
*new_symbol_table = copier.Copy(self->symbol_table_);
6567
}

src/tir/schedule/primitive/for_kind.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
157157
* parallelized/vectorized/bound.
158158
*/
159159
// Step 1. Check whether the subtree rooted from the `loop` in sref tree has compact data flow.
160-
CheckSubtreeCompactDataflow(self, loop_sref);
160+
if (self->enable_check) {
161+
CheckSubtreeCompactDataflow(self, loop_sref);
162+
}
161163

162164
// Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
163165
// underlying block.

src/tir/schedule/primitive/reduction.cc

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,19 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
188188
// Get the outer loops from high to low
189189
Array<StmtSRef> loops = GetLoops(block_sref);
190190
const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
191-
// Cond 0. Check loop_sref is an ancestor of block_sref
192-
if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
193-
throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
194-
"decompose_reduction");
195-
}
196-
// Cond 1. Check block is reduction
197191
StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
198192
/*require_stage_pipeline=*/false);
199-
CheckReductionBlock(self, block_sref, scope_root_sref);
200-
// Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
201-
LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
193+
if (self->enable_check) {
194+
// Cond 0. Check loop_sref is an ancestor of block_sref
195+
if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
196+
throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block),
197+
"decompose_reduction");
198+
}
199+
// Cond 1. Check block is reduction
200+
CheckReductionBlock(self, block_sref, scope_root_sref);
201+
// Cond 2. Check 'loop' is higher than all the loops related to block var of type reduction
202+
LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
203+
}
202204
// IR Manipulation
203205
ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
204206
ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
@@ -1176,7 +1178,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
11761178
const Block& block = block_realize->block;
11771179
StmtSRef scope_root = GetScopeRoot(self, block_sref, //
11781180
/*require_stage_pipeline=*/true);
1179-
CheckReductionBlock(self, block_sref, scope_root);
1181+
if (self->enable_check) {
1182+
CheckReductionBlock(self, block_sref, scope_root);
1183+
}
11801184
const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop_sref);
11811185
if (rf_loop->kind != ForKind::kSerial) {
11821186
throw NotSerialLoopKindError(self->mod, GetRef<For>(rf_loop));
@@ -1199,8 +1203,10 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
11991203
// - the outermost loop should have the reduction block as its first child block;
12001204
// - the outermost loop that is touched by some reduction block iters can only have one child
12011205
// block.
1202-
LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars,
1203-
reduce_loop_vars);
1206+
if (self->enable_check) {
1207+
LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars,
1208+
reduce_loop_vars);
1209+
}
12041210

12051211
// Step 5. Get the `init` identity and the `update` combiner of the reduction. Extract the
12061212
// commutative reducer, combiner lhs and combiner rhs from the reduction identity and the

src/tir/schedule/schedule.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,17 @@ TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV
6565
TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); });
6666
TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule")
6767
.set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
68-
int debug_mask, int error_render_level) -> Schedule {
68+
int debug_mask, int error_render_level, bool enable_check) -> Schedule {
6969
return Schedule::Concrete(mod, debug_mask, seed,
70-
static_cast<ScheduleErrorRenderLevel>(error_render_level));
70+
static_cast<ScheduleErrorRenderLevel>(error_render_level),
71+
enable_check);
7172
});
7273
TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule")
7374
.set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed,
74-
int debug_mask, int error_render_level) -> Schedule {
75+
int debug_mask, int error_render_level, bool enable_check) -> Schedule {
7576
return Schedule::Traced(mod, seed, debug_mask,
76-
static_cast<ScheduleErrorRenderLevel>(error_render_level));
77+
static_cast<ScheduleErrorRenderLevel>(error_render_level),
78+
enable_check);
7779
});
7880

7981
/******** (FFI) Lookup random variables ********/

src/tir/schedule/state.cc

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -402,16 +402,21 @@ class BlockInfoCollector : private StmtVisitor {
402402
class StateCreator : private StmtVisitor {
403403
public:
404404
/*!
405-
* \brief The entry function
406-
* \param self The schedule state to be completed
405+
* \brief ScheduleState Creator
406+
* \param mod The module being scheduled.
407+
* \param debug_mask Do extra correctness checking after the class creation
408+
* and each time after calling the Replace method.
409+
* \param enable_check Whether to enable prerequisite checks for schedule primitives.
407410
*/
408-
static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
411+
static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask, bool enable_check) {
409412
ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
410413
ScheduleStateNode* self = n.get();
411414
// Set `n->mod`
412415
n->mod = std::move(mod);
413416
// Set `n->debug_mask`
414417
n->debug_mask = debug_mask;
418+
// Set `n->enable_check`
419+
n->enable_check = enable_check;
415420
// Set `n->stmt2ref` and `n->block_info`
416421
StateCreator creator(self);
417422
for (const auto& kv : n->mod->functions) {
@@ -426,6 +431,10 @@ class StateCreator : private StmtVisitor {
426431
}
427432

428433
private:
434+
/*!
435+
* \brief The entry function
436+
* \param self The schedule state to be completed
437+
*/
429438
explicit StateCreator(ScheduleStateNode* self) : self_(self) {}
430439

431440
/*!
@@ -481,9 +490,9 @@ class StateCreator : private StmtVisitor {
481490

482491
/**************** Constructor ****************/
483492

484-
ScheduleState::ScheduleState(IRModule mod, int debug_mask) {
493+
ScheduleState::ScheduleState(IRModule mod, int debug_mask, bool enable_check) {
485494
CHECK_GE(debug_mask, -1) << "ValueError: negative `debug_mask` other than -1 is not supported";
486-
data_ = StateCreator::Create(mod, debug_mask);
495+
data_ = StateCreator::Create(mod, debug_mask, enable_check);
487496
}
488497

489498
/**************** Replace ****************/
@@ -1108,8 +1117,8 @@ TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& self, const StmtSRef& bl
11081117

11091118
TVM_REGISTER_NODE_TYPE(ScheduleStateNode);
11101119
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleState")
1111-
.set_body_typed([](IRModule mod, int debug_mask) -> ScheduleState {
1112-
return ScheduleState(mod, debug_mask);
1120+
.set_body_typed([](IRModule mod, int debug_mask, bool enable_check) -> ScheduleState {
1121+
return ScheduleState(mod, debug_mask, enable_check);
11131122
});
11141123
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStateGetBlockScope")
11151124
.set_body_method<ScheduleState>(&ScheduleStateNode::GetBlockScope);

0 commit comments

Comments
 (0)