Skip to content

Commit 8b00522

Browse files
authored
[Refactor] Update TVM subproject and refactor BlockNode handling in warp_specialized_rewriter.cc (#812)
* [Feature] Introduce custom warp specialization attribute and enhance warp group register allocation - Added a new attribute `kCustomWarpSpecialization` to support custom warp specialization in the TileLang framework. - Updated the `Collect` method in `SetMaxNRegCollector` to handle cases where warp specialization is detected, returning an empty array accordingly. - Enhanced the `SetMaxNRegInjector` to skip processing when no registers are needed, improving efficiency. - Modified the `WarpSpecialized` pass to include the new attribute in the function body when warp specialization is enabled, ensuring proper handling in transformations. * lint * lint
1 parent 0b3683b commit 8b00522

File tree

3 files changed

+21
-6
lines changed

3 files changed

+21
-6
lines changed

src/op/builtin.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ namespace attr {
2525
static constexpr const char *kPaddingMap = "padding_map";
2626
static constexpr const char *kWarpSpecializationScope =
2727
"kWarpSpecializationScope";
28+
static constexpr const char *kCustomWarpSpecialization =
29+
"kCustomWarpSpecialization";
2830
} // namespace attr
2931

3032
static constexpr const char *kDebugMergeSharedMemoryAllocations =

src/transform/annotate_warp_group_reg_alloc.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ class SetMaxNRegCollector : public StmtExprVisitor {
1717
static Array<IntImm> Collect(const PrimFunc &f) {
1818
SetMaxNRegCollector collector;
1919
collector(f->body);
20+
if (collector.warp_specialized_) {
21+
return Array<IntImm>({});
22+
}
2023
return collector.has_no_set_max_nreg_
2124
? Array<IntImm>({IntImm(DataType::Int(32), -1),
2225
IntImm(DataType::Int(32), -1)})
@@ -43,21 +46,27 @@ class SetMaxNRegCollector : public StmtExprVisitor {
4346
StmtExprVisitor::VisitStmt_(op);
4447
}
4548

49+
void VisitStmt_(const AttrStmtNode *op) final {
50+
if (op->attr_key == attr::kCustomWarpSpecialization) {
51+
warp_specialized_ = true;
52+
}
53+
StmtExprVisitor::VisitStmt_(op);
54+
}
55+
4656
Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
4757
IntImm(DataType::Int(32), 0)};
4858
bool has_no_set_max_nreg_ = false;
59+
bool warp_specialized_ = false;
4960
};
5061

5162
class SetMaxNRegInjector : public StmtExprMutator {
5263
public:
5364
static PrimFunc Inject(PrimFunc f) {
54-
bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
55-
if (warp_specialized) {
56-
// Should handle set_max_nreg when using hand-written warp specialized
57-
return f;
58-
}
5965
auto T = SetMaxNRegInjector();
6066
T.nreg_ = SetMaxNRegCollector::Collect(f);
67+
if (T.nreg_.empty()) {
68+
return f;
69+
}
6170
f.CopyOnWrite()->body = T(f->body);
6271
return f;
6372
}

src/transform/warp_specialized_rewriter.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1283,8 +1283,12 @@ tvm::transform::Pass WarpSpecialized() {
12831283
if (!warp_specialized) {
12841284
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
12851285
disable_shuffle_elect);
1286+
} else {
1287+
ObjectRef node = String("default");
1288+
f.CopyOnWrite()->body =
1289+
AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body);
1290+
return f;
12861291
}
1287-
return f;
12881292
};
12891293
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
12901294
}

0 commit comments

Comments
 (0)