Skip to content

Commit ccecc58

Browse files
cblmemovinx13junrushaoSiyuan FengMasterJH5574
committed
Introducing MemHammer
Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Hongyi Jin <[email protected]>
1 parent 6fa88e3 commit ccecc58

File tree

21 files changed

+3687
-0
lines changed

21 files changed

+3687
-0
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,11 @@ class ScheduleNode : public runtime::Object {
466466
*/
467467
virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
468468
BufferIndexType buffer_index_type) = 0;
469+
/******** Schedule: Data movement ********/
470+
virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
471+
const String& storage_scope) = 0;
472+
virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index,
473+
const String& storage_scope) = 0;
469474
/******** Schedule: Compute location ********/
470475
/*!
471476
* \brief Move a producer block under the specific loop, and regenerate the

include/tvm/tir/stmt.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,6 +1609,37 @@ constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layo
16091609
*/
16101610
constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";
16111611

1612+
/*!
1613+
* \brief Mark that the block need to add predicate for block var bounds during lowering
1614+
*/
1615+
constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";
1616+
1617+
/*! \brief Mark that tensor core is enabled in the PrimExpr */
1618+
constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";
1619+
1620+
/*!
1621+
* \brief Mark a block as generated by cache_read or cache_write block.
1622+
* 0 means cache_read; 1 means cache_write.
1623+
* \sa meta_schedule_cache_type_read
1624+
* \sa meta_schedule_cache_type_write
1625+
*/
1626+
constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";
1627+
1628+
/*! \sa meta_schedule_cache_type */
1629+
constexpr const int meta_schedule_cache_type_read = 0;
1630+
1631+
/*! \sa meta_schedule_cache_type */
1632+
constexpr const int meta_schedule_cache_type_write = 1;
1633+
1634+
/*! \brief Mark auto copy for memhammer */
1635+
constexpr const char* auto_copy = "auto_copy";
1636+
1637+
/*! \brief Mark local stage constraint on data copy */
1638+
constexpr const char* local_stage = "local_stage";
1639+
1640+
/*! \brief Mark vectorization length constraint on block */
1641+
constexpr const char* vector_bytes = "vector_bytes";
1642+
16121643
/*!
16131644
* \brief Mark that a block is executed by a warp. This implies the extend of threadIdx.x is
16141645
* warp size.

include/tvm/tir/transform.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,12 @@ TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);
647647
*/
648648
TVM_DLL Pass ExtractPrimFuncConstants();
649649

650+
/*!
651+
* \brief Automatically do memory optimizations for auto copy blocks
652+
* \return The pass.
653+
*/
654+
TVM_DLL Pass LowerAutoCopy();
655+
650656
/*!
651657
* \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
652658
* \return The pass.

python/tvm/tir/schedule/schedule.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,6 +1666,30 @@ def after_reindex(
16661666
self, block, buffer_index, buffer_index_type_enum
16671667
)
16681668

1669+
########## Schedule: Data movement ##########
1670+
1671+
def read_at(
1672+
self,
1673+
loop: LoopRV,
1674+
block: BlockRV,
1675+
read_buffer_index: int,
1676+
storage_scope: str,
1677+
) -> BlockRV:
1678+
return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member
1679+
self, loop, block, read_buffer_index, storage_scope
1680+
)
1681+
1682+
def write_at(
1683+
self,
1684+
loop: LoopRV,
1685+
block: BlockRV,
1686+
write_buffer_index: int,
1687+
storage_scope: str,
1688+
) -> BlockRV:
1689+
return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member
1690+
self, loop, block, write_buffer_index, storage_scope
1691+
)
1692+
16691693
########## Schedule: Compute location ##########
16701694

16711695
@type_checked

python/tvm/tir/transform/transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,17 @@ def ExtractPrimFuncConstants():
926926
return _ffi_api.ExtractPrimFuncConstants() # type: ignore
927927

928928

929+
def LowerAutoCopy():
930+
"""Automatically do memory optimizations for auto copy blocks
931+
932+
Returns
933+
-------
934+
fpass : tvm.transform.Pass
935+
The result pass
936+
"""
937+
return _ffi_api.LowerAutoCopy() # type: ignore
938+
939+
929940
def RenormalizeSplitPattern():
930941
"""Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
931942

src/driver/driver_api.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
214214
pass_list.push_back(tir::transform::UnifyThreadBinding());
215215
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
216216
pass_list.push_back(tir::transform::CompactBufferAllocation());
217+
pass_list.push_back(tir::transform::LowerAutoCopy());
217218
pass_list.push_back(tir::transform::LowerMatchBuffer());
218219
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
219220
pass_list.push_back(tir::transform::LowerOpaqueBlock());

src/meta_schedule/feature_extractor/per_store_feature.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ Sequential PassListForPerStoreFeature() {
309309
tir::transform::ConvertBlocksToOpaque(),
310310
tir::transform::UnifyThreadBinding(),
311311
tir::transform::CompactBufferAllocation(),
312+
tir::transform::LowerAutoCopy(),
312313
tir::transform::LowerMatchBuffer(),
313314
tir::transform::Simplify(),
314315
});

src/meta_schedule/postproc/verify_gpu_code.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class VerifyGPUCodeNode : public PostprocNode {
164164
pass_list.push_back(tir::transform::UnifyThreadBinding());
165165
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
166166
pass_list.push_back(tir::transform::CompactBufferAllocation());
167+
pass_list.push_back(tir::transform::LowerAutoCopy());
167168
pass_list.push_back(tir::transform::LowerMatchBuffer());
168169
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
169170
pass_list.push_back(tir::transform::LowerOpaqueBlock());

src/tir/schedule/concrete_schedule.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,30 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
629629
return CreateRV<BlockRV>(result);
630630
}
631631

632+
/******** Schedule: Data movement ********/
633+
634+
BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv,
635+
int read_buffer_index, const String& storage_scope) {
636+
StmtSRef result{nullptr};
637+
TVM_TIR_SCHEDULE_BEGIN();
638+
result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index,
639+
storage_scope);
640+
TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_);
641+
this->state_->DebugVerify();
642+
return CreateRV<BlockRV>(result);
643+
}
644+
645+
BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv,
646+
int write_buffer_index, const String& storage_scope) {
647+
StmtSRef result{nullptr};
648+
TVM_TIR_SCHEDULE_BEGIN();
649+
result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index,
650+
storage_scope);
651+
TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_);
652+
this->state_->DebugVerify();
653+
return CreateRV<BlockRV>(result);
654+
}
655+
632656
/******** Schedule: Compute location ********/
633657

634658
void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,

src/tir/schedule/concrete_schedule.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,11 @@ class ConcreteScheduleNode : public ScheduleNode {
126126
int cse_thresh) override;
127127
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
128128
BufferIndexType buffer_index_type) override;
129+
/******** Schedule: Data movement ********/
130+
BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
131+
const String& storage_scope) override;
132+
BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index,
133+
const String& storage_scope) override;
129134
/******** Schedule: Compute location ********/
130135
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
131136
int index = -1) override;

0 commit comments

Comments
 (0)