Skip to content

Commit 19b2252

Browse files
authored
[Enhancement] Introduce PassConfig TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE to enable aggressive shared memory reuse (#602)
* [Enhancement] Add aggressive shared memory merge option in memory allocation
  - Introduced a new configuration option `tl.enable_aggressive_shared_memory_merge` to enable aggressive merging of shared memory allocations.
  - Updated the `SharedMemLinearAccessPatternFinder` class to support an aggressive merge strategy, allowing for improved memory reuse.
  - Modified the `MergeSharedMemoryAllocations` function to incorporate the new merging strategy based on the configuration.
  - Enhanced the `PassConfigKey` enumeration to include the new aggressive merge option, ensuring it can be configured appropriately.
* lint fix
* [Enhancement] Add aggressive shared memory merge configuration option
  - Introduced a new configuration option `kEnableAggressiveSharedMemoryMerge` to enable aggressive merging of shared memory allocations, enhancing memory management capabilities.
* [Enhancement] Update MergeSharedMemoryAllocations to support aggressive merge option
  - Modified the `MergeSharedMemoryAllocations` function to accept an `enable_aggressive_merge` parameter, allowing for more flexible memory management.
  - Introduced a new helper function `should_enable_aggressive_merge` to determine the aggressive merge configuration based on the pass context and target.
  - Updated the relevant calls in the `phase.py` and `__init__.py` files to utilize the new aggressive merge functionality, enhancing the overall memory allocation strategy.
1 parent 1ce4757 commit 19b2252

File tree

6 files changed

+64
-18
lines changed

6 files changed

+64
-18
lines changed

src/op/builtin.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWarpSpecialized, Bool);
2626
TVM_REGISTER_PASS_CONFIG_OPTION(kConfigIndexBitwidth, Integer);
2727
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableDynamicTailSplit, Bool);
2828
TVM_REGISTER_PASS_CONFIG_OPTION(kDynamicAlignment, Integer);
29+
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableAggressiveSharedMemoryMerge, Bool);
2930

3031
#define TIR_DEFINE_TL_BUILTIN(OpName) \
3132
const Op &OpName() { \

src/op/builtin.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ static constexpr const char *kDisableSafeMemoryLegalize =
2828
static constexpr const char *kDisableWarpSpecialized =
2929
"tl.disable_warp_specialized";
3030
static constexpr const char *kConfigIndexBitwidth = "tl.config_index_bitwidth";
31+
static constexpr const char *kEnableAggressiveSharedMemoryMerge =
32+
"tl.enable_aggressive_shared_memory_merge";
3133

3234
/*!
3335
* \brief Whether to disable dynamic tail split

src/transform/merge_shared_memory_allocations.cc

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,11 @@ class AllocateCollector : public StmtExprVisitor {
9595
//
9696
class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
9797
public:
98-
explicit SharedMemLinearAccessPatternFinder(bool is_dynamic = true,
99-
bool verbose = false)
100-
: is_dynamic_(is_dynamic), verbose_(verbose) {}
98+
explicit SharedMemLinearAccessPatternFinder(
99+
bool is_dynamic = true, bool enable_aggressive_merge = false,
100+
bool verbose = false)
101+
: is_dynamic_(is_dynamic),
102+
enable_aggressive_merge_(enable_aggressive_merge), verbose_(verbose) {}
101103
/*! \brief record the touch list of statement. */
102104
struct StmtEntry {
103105
// The statement
@@ -151,9 +153,15 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
151153
ICHECK_LT(it->second.level, scope_.size());
152154
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
153155
// With aggressive merge enabled, record the touch at the innermost scope
// (scope_.size() - 1); otherwise record it at the buffer's allocation level.
154-
scope_[it->second.level].touched.push_back(buf);
156+
auto enable_aggressive_merge = enable_aggressive_merge_;
157+
if (enable_aggressive_merge) {
158+
scope_[scope_.size() - 1].touched.push_back(buf);
159+
} else {
160+
scope_[it->second.level].touched.push_back(buf);
161+
}
155162
}
156163
}
164+
157165
StmtEntry e = scope_.back();
158166
scope_.pop_back();
159167
if (e.touched.size() != 0) {
@@ -185,7 +193,12 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
185193
ICHECK_LT(it->second.level, scope_.size())
186194
<< "Load memory in places other than store.";
187195
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
188-
scope_[it->second.level].touched.push_back(buf);
196+
auto enable_aggressive_merge = enable_aggressive_merge_;
197+
if (enable_aggressive_merge) {
198+
scope_[scope_.size() - 1].touched.push_back(buf);
199+
} else {
200+
scope_[it->second.level].touched.push_back(buf);
201+
}
189202
}
190203
}
191204
}
@@ -196,7 +209,12 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
196209
if (it != alloc_info_.end() && it->second.alloc) {
197210
ICHECK_LT(it->second.level, scope_.size());
198211
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
199-
scope_[it->second.level].touched.push_back(buf);
212+
auto enable_aggressive_merge = enable_aggressive_merge_;
213+
if (enable_aggressive_merge) {
214+
scope_[scope_.size() - 1].touched.push_back(buf);
215+
} else {
216+
scope_[it->second.level].touched.push_back(buf);
217+
}
200218
}
201219
}
202220
}
@@ -284,6 +302,8 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
284302
}
285303
// Whether to do dynamic analysis.
286304
bool is_dynamic_{true};
305+
// Whether to do aggressive merge.
306+
bool enable_aggressive_merge_{false};
287307
// Whether to do verbose logging.
288308
bool verbose_{false};
289309
// Whether already in thread env.
@@ -317,8 +337,9 @@ class SharedMemoryRewriter : public StmtExprMutator {
317337
* \param stmt the statement
318338
*/
319339
void PlanReuse(const Stmt &stmt, bool is_dynamic = true,
320-
bool verbose = false) {
321-
SharedMemLinearAccessPatternFinder finder(is_dynamic, verbose);
340+
bool enable_aggressive_merge = false, bool verbose = false) {
341+
SharedMemLinearAccessPatternFinder finder(is_dynamic,
342+
enable_aggressive_merge, verbose);
322343
finder(stmt);
323344
this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
324345
this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
@@ -956,6 +977,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
956977
}
957978
// Whether to enable dynamic analysis.
958979
bool is_dynamic_{true};
980+
959981
// Whether to enable verbose logging.
960982
bool verbose_{false};
961983
// The var for the merged buffer
@@ -985,18 +1007,19 @@ class SharedMemoryRewriter : public StmtExprMutator {
9851007
};
9861008

9871009
Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem,
1010+
bool enable_aggressive_merge,
9881011
bool verbose = false) {
9891012
AllocateCollector collector;
9901013
collector(stmt);
9911014
if (collector.dyn_shmem_allocs_.size() > 1) {
9921015
SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_, true, verbose);
993-
rewriter.PlanReuse(stmt);
1016+
rewriter.PlanReuse(stmt, true, enable_aggressive_merge);
9941017
stmt = rewriter(std::move(stmt));
9951018
}
9961019
if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
9971020
SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false,
9981021
verbose);
999-
rewriter.PlanReuse(stmt, false);
1022+
rewriter.PlanReuse(stmt, false, enable_aggressive_merge);
10001023
stmt = rewriter(std::move(stmt));
10011024
}
10021025
return stmt;
@@ -1006,17 +1029,18 @@ using namespace tir::transform;
10061029

10071030
namespace transform {
10081031

1009-
Pass MergeSharedMemoryAllocations() {
1010-
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
1032+
Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false) {
1033+
auto pass_func = [enable_aggressive_merge](PrimFunc f, IRModule m,
1034+
PassContext ctx) {
10111035
bool merge_static_smem =
10121036
ctx->GetConfig<Bool>("tir.merge_static_smem", Bool(false)).value();
10131037
bool debug_merge_shared_memory_allocations =
10141038
ctx->GetConfig<Bool>(kDebugMergeSharedMemoryAllocations, Bool(false))
10151039
.value();
10161040
auto *n = f.CopyOnWrite();
1017-
n->body =
1018-
tl::MergeSharedMemoryAllocations(std::move(n->body), merge_static_smem,
1019-
debug_merge_shared_memory_allocations);
1041+
n->body = tl::MergeSharedMemoryAllocations(
1042+
std::move(n->body), merge_static_smem, enable_aggressive_merge,
1043+
debug_merge_shared_memory_allocations);
10201044
return f;
10211045
};
10221046
return CreatePrimFuncPass(pass_func, 0, "tl.MergeSharedMemoryAllocations",

tilelang/engine/phase.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,20 @@ def allow_global_thread_synchronization(pass_ctx: Optional[PassContext] = None)
5050
return enable_global_thread_sync
5151

5252

53+
def should_enable_aggressive_merge(pass_ctx: Optional[PassContext] = None,
54+
target: Optional[Target] = None) -> bool:
55+
if pass_ctx is None:
56+
pass_ctx = tilelang.transform.get_pass_context()
57+
enable_aggressive_merge = bool(
58+
pass_ctx.config.get(tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, False))
59+
if allow_warp_specialized(pass_ctx=pass_ctx, target=target):
60+
# This is a workaround to avoid the bug in the MergeSharedMemoryAllocations pass
61+
# when warp specialization is enabled, as different warp threads may access different
62+
# buffers, but the liveness analysis is hard because we need to do pipelining.
63+
enable_aggressive_merge = False
64+
return enable_aggressive_merge
65+
66+
5367
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
5468
# Bind the target device information to the module
5569
mod = tir.transform.BindTarget(target)(mod)
@@ -151,7 +165,9 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
151165
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
152166
mod = tir.transform.SplitHostDevice()(mod)
153167

154-
mod = tilelang.transform.MergeSharedMemoryAllocations()(mod)
168+
mod = tilelang.transform.MergeSharedMemoryAllocations(
169+
enable_aggressive_merge=should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target))(
170+
mod)
155171

156172
mod = tilelang.transform.ThreadSync("shared")(mod)
157173
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)

tilelang/transform/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,15 +335,15 @@ def EliminateStorageSyncForMBarrier():
335335
return _ffi_api.EliminateStorageSyncForMBarrier() # type: ignore
336336

337337

338-
def MergeSharedMemoryAllocations():
338+
def MergeSharedMemoryAllocations(enable_aggressive_merge: bool = False):
339339
"""MergeSharedMemoryAllocations
340340
341341
Returns
342342
-------
343343
fpass : tvm.transform.Pass
344344
The result pass
345345
"""
346-
return _ffi_api.MergeSharedMemoryAllocations() # type: ignore
346+
return _ffi_api.MergeSharedMemoryAllocations(enable_aggressive_merge) # type: ignore
347347

348348

349349
def LowerL2Persistent():

tilelang/transform/pass_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ class PassConfigKey(str, Enum):
3232
TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations"
3333
"""Enable debug information for merge shared memory allocations. Default: False"""
3434

35+
TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE = "tl.enable_aggressive_shared_memory_merge"
36+
"""Enable aggressive merge of shared memory allocations. Default: False"""
37+
3538
# TIR related configs
3639
TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
3740
"""Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""

0 commit comments

Comments
 (0)