
Commit a27315c

cblmemo authored and junrushao committed
[MetaSchedule] Introduce Async Pipeline in MultiLevelTiling
This PR introduces an async pipeline into TVM's current MultiLevelTiling rules. It is blocked on apache#13966, since some conv2d workloads use `tir.if_then_else` to pad the input to the correct size, and this PR applies async copy to such copy statements.

1. Add a subrule in `src/meta_schedule/schedule_rule/multi_level_tiling.h/.cc` that annotates async copy for mlt (multi-level tiling).

On CUDA Core, this PR brings a performance boost of around 1 TFLOP/s in most conv2d test cases and 1–2 TFLOP/s in most GEMM test cases. All generated code, scripts, and traces are available at https://github.com/Rainy-Memory/tvm-async-rule-benchmark.

Currently tested at commit `afbfb7aa7e43732cb716f8e443df696110be6afc` on the conv2d NHWC workload, with an RTX 3080 GPU. Throughput numbers below are in GFLOP/s.

Workload: Conv2d NHWC

| Shape | Mainline TVM | Mainline TVM with Async |
|-|-|-|
| N=1_H=224_W=224_C=3_K=64_R=7_S=7_STR=2_PAD=3_DIL=1 | 13838.05219 | 14687.89452 |
| N=1_H=56_W=56_C=64_K=64_R=1_S=1_STR=1_PAD=0_DIL=1 | 5398.305085 | 5613.892553 |
| N=1_H=56_W=56_C=64_K=64_R=3_S=3_STR=1_PAD=1_DIL=1 | 11652.96825 | 13157.88249 |
| N=1_H=56_W=56_C=64_K=256_R=1_S=1_STR=1_PAD=0_DIL=1 | 10638.8309 | 11674.68499 |
| N=1_H=56_W=56_C=256_K=64_R=1_S=1_STR=1_PAD=0_DIL=1 | 8692.32829 | 9469.264089 |
| N=1_H=56_W=56_C=256_K=128_R=1_S=1_STR=2_PAD=0_DIL=1 | 4685.767442 | 5698.19634 |
| N=1_H=28_W=28_C=128_K=128_R=3_S=3_STR=1_PAD=1_DIL=1 | 9872.787087 | 10404.60405 |
| N=1_H=28_W=28_C=128_K=512_R=1_S=1_STR=1_PAD=0_DIL=1 | 9974.281496 | 10073.31657 |
| N=1_H=28_W=28_C=512_K=128_R=1_S=1_STR=1_PAD=0_DIL=1 | 7075.866932 | 8564.572712 |
| N=1_H=28_W=28_C=512_K=256_R=1_S=1_STR=2_PAD=0_DIL=1 | 3648.330914 | 4021.923142 |
| N=1_H=14_W=14_C=256_K=256_R=3_S=3_STR=1_PAD=1_DIL=1 | 8192.954618 | 9160.182054 |
| N=1_H=14_W=14_C=256_K=1024_R=1_S=1_STR=1_PAD=0_DIL=1 | 8008.870153 | 9362.825279 |
| N=1_H=14_W=14_C=1024_K=256_R=1_S=1_STR=1_PAD=0_DIL=1 | 5210.062241 | 6051.208379 |
| N=1_H=14_W=14_C=1024_K=512_R=1_S=1_STR=2_PAD=0_DIL=1 | 2550.787202 | 3587.902938 |
| N=1_H=7_W=7_C=512_K=512_R=3_S=3_STR=1_PAD=1_DIL=1 | 4350.626084 | 5432.788068 |
| N=1_H=7_W=7_C=512_K=2048_R=1_S=1_STR=1_PAD=0_DIL=1 | 6672.068026 | 7663.725217 |
| N=1_H=7_W=7_C=2048_K=512_R=1_S=1_STR=1_PAD=0_DIL=1 | 3142.564263 | 4297.988014 |

Workload: GEMM NN

| Shape | Mainline TVM | Mainline TVM with Async |
|-|-|-|
| M=512_N=256_K=640 | 8678.46 | 10607.37 |
| M=512_N=384_K=256 | 8109.13 | 10290.72 |
| M=512_N=512_K=512 | 11419.83 | 14000.86 |
| M=512_N=3072_K=768 | 19709.39 | 18351.61 |
| M=512_N=768_K=3072 | 12844.59 | 13730.88 |
| M=896_N=896_K=896 | 16149.91 | 16131.39 |
| M=1024_N=1024_K=1024 | 18842.11 | 19662.8 |
| M=1152_N=1152_K=1152 | 15386.79 | 16736.1 |
| M=1536_N=1536_K=1536 | 18522.67 | 18872.06 |
| M=2048_N=2048_K=2048 | 19515.42 | 18874.85 |
| M=3072_N=3072_K=3072 | 19233.9 | 19291.42 |
| M=4096_N=4096_K=4096 | 17122.17 | 19259.01 |
1 parent d7253fb · commit a27315c

File tree: 2 files changed, +59 -0 lines
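At its core, the new subrule (shown in the diff below) branches each tiled candidate into pipelined variants by attaching TVM's software-pipeline annotations to the fused reduction loop. The following is a minimal sketch of those three annotation calls, assuming an existing `tir::Schedule sch` and fused reduction loop `r_loop_fused` (both hypothetical stand-ins for the state the rule actually carries):

```cpp
// Minimal sketch (not part of the diff): how the three software-pipeline
// annotations are attached for a candidate pipeline depth `stage` (4 or 5).
// `sch` and `r_loop_fused` are assumed to exist; in the rule itself they come
// from the search state and from fusing the reduction tiles.
#include <tvm/ir/expr.h>
#include <tvm/runtime/container/array.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/stmt.h>

void AnnotateAsyncPipeline(tvm::tir::Schedule sch, tvm::tir::LoopRV r_loop_fused, int stage) {
  using tvm::Array;
  using tvm::Integer;
  namespace attr = tvm::tir::attr;
  // Pipeline stage of each of the three statements in the loop body
  // (two shared-memory copies, then the compute block).
  sch->Annotate(r_loop_fused, attr::software_pipeline_stage, Array<Integer>{0, 0, stage - 2});
  // Emission order of the three statements inside the pipelined loop.
  sch->Annotate(r_loop_fused, attr::software_pipeline_order, Array<Integer>{0, 1, 2});
  // Mark stage 0 (the copies) as asynchronous, lowering to cp.async on sm_80+.
  sch->Annotate(r_loop_fused, attr::software_pipeline_async_stages, Array<Integer>{0});
}
```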

src/meta_schedule/schedule_rule/multi_level_tiling.cc

Lines changed: 55 additions & 0 deletions
```diff
@@ -87,6 +87,21 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
       TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
     }
   }
+  if (Optional<String> opt_sm = context->target.value()->GetAttr<String>("arch")) {
+    std::string sm = opt_sm.value();
+    if (support::StartsWith(sm, "sm_")) {
+      sm = sm.substr(3);
+      try {
+        // only sm_80 or higher supports async memcopy
+        if (std::stoi(sm) >= 80) {
+          this->stages.insert(this->stages.end(), {4, 5});
+        }
+      } catch (const std::invalid_argument& e) {
+        LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
+                     << ". Details: " << e.what();
+      }
+    }
+  }
   logger = context->logger;
 }
```
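This arch gate only enables the subrule on Ampere or newer (compute capability ≥ 8.0), where the `cp.async` hardware async copy instruction first appears. A standalone sketch of the same parsing logic, using only the standard library (the helper name `AsyncStagesForArch` is hypothetical, and `support::StartsWith` is swapped for `std::string::rfind`):

```cpp
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Returns the async pipeline stages to try for a given CUDA arch string,
// mirroring the rule above: {4, 5} on sm_80 or newer, empty otherwise.
std::vector<int> AsyncStagesForArch(const std::string& arch) {
  if (arch.rfind("sm_", 0) != 0) return {};  // not an "sm_XX" arch string
  try {
    if (std::stoi(arch.substr(3)) >= 80) return {4, 5};
  } catch (const std::invalid_argument&) {
    // Unparsable compute capability: leave the async pipeline disabled.
  }
  return {};
}

int main() {
  for (std::string arch : {"sm_70", "sm_80", "sm_86", "bogus"}) {
    std::cout << arch << " -> " << AsyncStagesForArch(arch).size() << " stage option(s)\n";
  }
}
```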

```diff
@@ -115,6 +130,9 @@ std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states
   states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
+  states = SubRule(std::move(states), [&](State state) {
+    return AddAsyncPipeline(std::move(state));
+  });
   return states;
 }
```
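For readers unfamiliar with the rule machinery: `SubRule` is essentially a flat-map over candidate states, so a subrule can keep a state (`return {state}`), prune it (`return {}`), or branch it into several variants, which is what `AddAsyncPipeline` does below. A self-contained sketch of that pattern, with a placeholder `State` that simplifies TVM's actual definition:

```cpp
#include <functional>
#include <utility>
#include <vector>

struct State { int id; };  // placeholder for the real schedule state

// Flat-map each state through `f`, concatenating the results. Returning
// {state} keeps a state unchanged; returning several states branches the
// search; returning {} prunes it.
std::vector<State> SubRule(std::vector<State> states,
                           std::function<std::vector<State>(State)> f) {
  std::vector<State> results;
  for (State& state : states) {
    for (State& result : f(std::move(state))) {
      results.push_back(std::move(result));
    }
  }
  return results;
}
```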

```diff
@@ -280,6 +298,43 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
   return results;
 }

+std::vector<State> MultiLevelTilingNode::AddAsyncPipeline(State state) const {
+  // For an arch that does not support the async pipeline, this->stages will be an empty vector
+  if (r_indices_.size() < 1 || this->stages.empty()) {
+    return {state};
+  }
+  // Currently only supports the default config used by ScheduleRule::DefaultCUDA
+  // @see src/meta_schedule/schedule_rule/schedule_rule.cc
+  // Check that the reduce loop body contains exactly 3 for loops,
+  // so it matches the annotation array sizes in the code below.
+  tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back());
+  const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref);
+  Array<tir::Stmt> seq = Downcast<tir::SeqStmt>(r_for_loop->body)->seq;
+  if (seq.size() != 3) {
+    return {state};
+  }
+  for (auto& stmt : seq) {
+    if (!stmt.as<tir::ForNode>()) {
+      return {state};
+    }
+  }
+
+  LoopRV r_loop_fused = state->sch->Fuse(state->tiles[r_indices_[0]]);
+  std::vector<State> ret;
+  ret.push_back(state);
+  for (int stage : this->stages) {
+    State new_state = state->Copy();
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage,
+                             Array<Integer>{0, 0, stage - 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order,
+                             Array<Integer>{0, 1, 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages,
+                             Array<Integer>{0});
+    ret.push_back(std::move(new_state));
+  }
+  return ret;
+}
+
 void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
                                                        const tir::BlockRV& block) const {
   // Filter out invalid vector lanes according to the data type.
```
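Concretely, the `stage - 2` arithmetic sets how far the async copies run ahead of the compute (our reading of TVM's software-pipeline annotations, where stage indices shift statements across loop iterations). With `stages = {4, 5}`, the rule emits two pipelined variants per state; the toy program below just prints the annotation arrays each variant receives:

```cpp
#include <iostream>
#include <vector>

int main() {
  // Mirrors the loop in AddAsyncPipeline: each candidate `stage` yields one
  // annotated copy of the state. The copies sit in pipeline stage 0; the
  // compute block is placed (stage - 2) iterations behind them.
  for (int stage : std::vector<int>{4, 5}) {
    std::cout << "stage = " << stage << ":\n"
              << "  software_pipeline_stage        = [0, 0, " << stage - 2 << "]\n"
              << "  software_pipeline_order        = [0, 1, 2]\n"
              << "  software_pipeline_async_stages = [0]\n";
  }
  return 0;
}
```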

src/meta_schedule/schedule_rule/multi_level_tiling.h

Lines changed: 4 additions & 0 deletions
```diff
@@ -148,6 +148,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   std::vector<State> TileLoopNest(State state) const;
   // SubRule 3. add read cache
   std::vector<State> AddReadReuse(State state) const;
+  // SubRule 4. add async pipeline
+  std::vector<State> AddAsyncPipeline(State state) const;

   // Do nothing; Inherited from ScheduleRuleNode
   void InitializeWithTuneContext(const TuneContext& context) final;
@@ -192,6 +194,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   int thread_warp_size_;
   /*! \brief The maximum number of threads to be used */
   int max_threads_per_block_;
+  /*! \brief All available async pipeline stages. */
+  std::vector<int> stages;
   /*! \brief The logging function */
   PackedFunc logger;
   /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */
```
