Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace attr {
static constexpr const char *kPaddingMap = "padding_map";
static constexpr const char *kWarpSpecializationScope =
"kWarpSpecializationScope";
static constexpr const char *kCustomWarpSpecialization = "kCustomWarpSpecialization";
} // namespace attr

static constexpr const char *kDebugMergeSharedMemoryAllocations =
Expand Down
19 changes: 14 additions & 5 deletions src/transform/annotate_warp_group_reg_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class SetMaxNRegCollector : public StmtExprVisitor {
static Array<IntImm> Collect(const PrimFunc &f) {
SetMaxNRegCollector collector;
collector(f->body);
if (collector.warp_specialized_) {
return Array<IntImm>({});
}
return collector.has_no_set_max_nreg_
? Array<IntImm>({IntImm(DataType::Int(32), -1),
IntImm(DataType::Int(32), -1)})
Expand All @@ -43,21 +46,27 @@ class SetMaxNRegCollector : public StmtExprVisitor {
StmtExprVisitor::VisitStmt_(op);
}

void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == attr::kCustomWarpSpecialization) {
warp_specialized_ = true;
}
StmtExprVisitor::VisitStmt_(op);
}

Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), 0)};
bool has_no_set_max_nreg_ = false;
bool warp_specialized_ = false;
};

class SetMaxNRegInjector : public StmtExprMutator {
public:
static PrimFunc Inject(PrimFunc f) {
bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
if (warp_specialized) {
// Should handle set_max_nreg when using hand-written warp specialized
return f;
}
auto T = SetMaxNRegInjector();
T.nreg_ = SetMaxNRegCollector::Collect(f);
if (T.nreg_.size() == 0) {
return f;
}
f.CopyOnWrite()->body = T(f->body);
return f;
}
Expand Down
5 changes: 4 additions & 1 deletion src/transform/warp_specialized_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1283,8 +1283,11 @@ tvm::transform::Pass WarpSpecialized() {
if (!warp_specialized) {
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
disable_shuffle_elect);
} else {
ObjectRef node = String("default");
f.CopyOnWrite()->body = AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body);
return f;
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}
Expand Down
Loading