Skip to content

Commit 36b3097

Browse files
cblmemo, vinx13, junrushao, Siyuan Feng, MasterJH5574
authored
[MetaSchedule] Introducing MemHammer (#14164)
Introducing MemHammer Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Hongyi Jin <[email protected]>
1 parent 0627684 commit 36b3097

File tree

21 files changed

+3687
-0
lines changed

21 files changed

+3687
-0
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,11 @@ class ScheduleNode : public runtime::Object {
466466
*/
467467
virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
468468
BufferIndexType buffer_index_type) = 0;
469+
/******** Schedule: Data movement ********/
470+
virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
471+
const String& storage_scope) = 0;
472+
virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index,
473+
const String& storage_scope) = 0;
469474
/******** Schedule: Compute location ********/
470475
/*!
471476
* \brief Move a producer block under the specific loop, and regenerate the

include/tvm/tir/stmt.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,6 +1609,37 @@ constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layo
16091609
*/
16101610
constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
16111611

1612+
/*!
1613+
* \brief Mark that the block needs a predicate on its block var bounds to be added during lowering
1614+
*/
1615+
constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";
1616+
1617+
/*! \brief Mark that tensor core is enabled in the PrimExpr */
1618+
constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";
1619+
1620+
/*!
1621+
* \brief Mark a block as generated by cache_read or cache_write.
1622+
* 0 means cache_read; 1 means cache_write.
1623+
* \sa meta_schedule_cache_type_read
1624+
* \sa meta_schedule_cache_type_write
1625+
*/
1626+
constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
1627+
1628+
/*! \sa meta_schedule_cache_type */
1629+
constexpr const int meta_schedule_cache_type_read = 0;
1630+
1631+
/*! \sa meta_schedule_cache_type */
1632+
constexpr const int meta_schedule_cache_type_write = 1;
1633+
1634+
/*! \brief Mark auto copy for memhammer */
1635+
constexpr const char* auto_copy = "auto_copy";
1636+
1637+
/*! \brief Mark local stage constraint on data copy */
1638+
constexpr const char* local_stage = "local_stage";
1639+
1640+
/*! \brief Mark vectorization length constraint on block */
1641+
constexpr const char* vector_bytes = "vector_bytes";
1642+
16121643
/*!
16131644
* \brief Mark that a block is executed by a warp. This implies the extent of threadIdx.x is
16141645
* warp size.

include/tvm/tir/transform.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,12 @@ TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);
647647
*/
648648
TVM_DLL Pass ExtractPrimFuncConstants();
649649

650+
/*!
651+
* \brief Automatically do memory optimizations for auto copy blocks
652+
* \return The pass.
653+
*/
654+
TVM_DLL Pass LowerAutoCopy();
655+
650656
/*!
651657
* \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
652658
* \return The pass.

python/tvm/tir/schedule/schedule.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,6 +1687,30 @@ def after_reindex(
16871687
self, block, buffer_index, buffer_index_type_enum
16881688
)
16891689

1690+
########## Schedule: Data movement ##########
1691+
1692+
def read_at(
1693+
self,
1694+
loop: LoopRV,
1695+
block: BlockRV,
1696+
read_buffer_index: int,
1697+
storage_scope: str,
1698+
) -> BlockRV:
1699+
return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member
1700+
self, loop, block, read_buffer_index, storage_scope
1701+
)
1702+
1703+
def write_at(
1704+
self,
1705+
loop: LoopRV,
1706+
block: BlockRV,
1707+
write_buffer_index: int,
1708+
storage_scope: str,
1709+
) -> BlockRV:
1710+
return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member
1711+
self, loop, block, write_buffer_index, storage_scope
1712+
)
1713+
16901714
########## Schedule: Compute location ##########
16911715

16921716
@type_checked

python/tvm/tir/transform/transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,17 @@ def ExtractPrimFuncConstants():
926926
return _ffi_api.ExtractPrimFuncConstants() # type: ignore
927927

928928

929+
def LowerAutoCopy():
930+
"""Automatically do memory optimizations for auto copy blocks
931+
932+
Returns
933+
-------
934+
fpass : tvm.transform.Pass
935+
The result pass
936+
"""
937+
return _ffi_api.LowerAutoCopy() # type: ignore
938+
939+
929940
def RenormalizeSplitPattern():
930941
"""Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
931942

src/driver/driver_api.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
213213
pass_list.push_back(tir::transform::UnifyThreadBinding());
214214
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
215215
pass_list.push_back(tir::transform::CompactBufferAllocation());
216+
pass_list.push_back(tir::transform::LowerAutoCopy());
216217
pass_list.push_back(tir::transform::LowerMatchBuffer());
217218
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
218219
pass_list.push_back(tir::transform::LowerOpaqueBlock());

src/meta_schedule/feature_extractor/per_store_feature.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ Sequential PassListForPerStoreFeature() {
309309
tir::transform::ConvertBlocksToOpaque(),
310310
tir::transform::UnifyThreadBinding(),
311311
tir::transform::CompactBufferAllocation(),
312+
tir::transform::LowerAutoCopy(),
312313
tir::transform::LowerMatchBuffer(),
313314
tir::transform::Simplify(),
314315
});

src/meta_schedule/postproc/verify_gpu_code.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class VerifyGPUCodeNode : public PostprocNode {
164164
pass_list.push_back(tir::transform::UnifyThreadBinding());
165165
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
166166
pass_list.push_back(tir::transform::CompactBufferAllocation());
167+
pass_list.push_back(tir::transform::LowerAutoCopy());
167168
pass_list.push_back(tir::transform::LowerMatchBuffer());
168169
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
169170
pass_list.push_back(tir::transform::LowerOpaqueBlock());

src/tir/schedule/concrete_schedule.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,30 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
631631
return CreateRV<BlockRV>(result);
632632
}
633633

634+
/******** Schedule: Data movement ********/
635+
636+
BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv,
637+
int read_buffer_index, const String& storage_scope) {
638+
StmtSRef result{nullptr};
639+
TVM_TIR_SCHEDULE_BEGIN();
640+
result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index,
641+
storage_scope);
642+
TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_);
643+
this->state_->DebugVerify();
644+
return CreateRV<BlockRV>(result);
645+
}
646+
647+
BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv,
648+
int write_buffer_index, const String& storage_scope) {
649+
StmtSRef result{nullptr};
650+
TVM_TIR_SCHEDULE_BEGIN();
651+
result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index,
652+
storage_scope);
653+
TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_);
654+
this->state_->DebugVerify();
655+
return CreateRV<BlockRV>(result);
656+
}
657+
634658
/******** Schedule: Compute location ********/
635659

636660
void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,

src/tir/schedule/concrete_schedule.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ class ConcreteScheduleNode : public ScheduleNode {
126126
int cse_thresh) override;
127127
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
128128
BufferIndexType buffer_index_type) override;
129+
/******** Schedule: Data movement ********/
130+
BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
131+
const String& storage_scope) override;
132+
BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index,
133+
const String& storage_scope) override;
129134
/******** Schedule: Compute location ********/
130135
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
131136
int index = -1) override;

0 commit comments

Comments
 (0)