Skip to content

Commit c7970dd

Browse files
authored
[TensorIR] New schedule primitive set_dtype (#14316)
# Motivation

Currently, we lack a schedule primitive to change the data type of an allocated buffer (e.g. one created via `cache_read`/`cache_write`), and thus we cannot perform type conversion while loading data from global to shared memory. This PR adds a new schedule primitive `set_dtype` (exposed in this commit as `unsafe_set_dtype`) that follows the interface of `set_scope` and allows users to customize the allocated buffers' data type.

# Example

Before running `set_dtype`:

```python
@T.prim_func
def before_set_dtype(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float32")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = B[vi, vj] + 1.0
```

Then we perform the `set_dtype` schedule:

```python
sch = tir.Schedule(before_set_dtype)
sch.unsafe_set_dtype("B", buffer_index=0, dtype="float16")
print(sch.mod["main"].script())
```

and we get the transformed code:

```python
@T.prim_func
def after_set_dtype(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float16")
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
```

where data type conversions are inserted automatically.

# Other Usage

Combining `cache_read` with `set_dtype` lets us load data from the memory hierarchy while converting it to the desired type.
1 parent 0c2dd47 commit c7970dd

File tree

12 files changed

+385
-5
lines changed

12 files changed

+385
-5
lines changed

include/tvm/tir/schedule/schedule.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,13 +589,23 @@ class ScheduleNode : public runtime::Object {
589589
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
590590
int offset) = 0;
591591
/*!
592-
* \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a
592+
* \brief Set the storage scope of a buffer, where the buffer is specified by a block and a
593593
* write-index
594594
* \param block_rv The producer block of the buffer
595595
* \param buffer_index The index of the buffer in block's write region
596596
* \param storage_scope The storage scope to be set
597597
*/
598598
virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
599+
/*!
600+
* \brief Set the data type of a buffer, where the buffer is specified by a block and a
601+
* write-index
602+
* \note This schedule primitive is unsafe and may change correctness of program because of
603+
* type conversion, please use with caution.
604+
* \param block_rv The producer block of the buffer
605+
* \param buffer_index The index of the buffer in block's write region
606+
* \param dtype The data type to be set
607+
*/
608+
virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0;
599609
/******** Schedule: Blockize & Tensorize ********/
600610
/*!
601611
* \brief Convert the subtree rooted at a specific loop into a block.

python/tvm/tir/schedule/schedule.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2369,7 +2369,7 @@ def set_scope(
23692369
self, block: Union[BlockRV, str], buffer_index: Union[int, str, Buffer], storage_scope: str
23702370
) -> None:
23712371
"""Set the storage scope of a buffer, where the buffer is
2372-
specified by the a block and a write-index
2372+
specified by a block and a write-index.
23732373
23742374
Parameters
23752375
----------
@@ -2431,7 +2431,7 @@ def after_set_scope(
24312431
24322432
Note
24332433
----
2434-
Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
2434+
`set_scope` requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
24352435
"""
24362436
block = self._normalize_block_arg(block)
24372437
if not isinstance(buffer_index, int):
@@ -2442,6 +2442,81 @@ def after_set_scope(
24422442
self, block, buffer_index, storage_scope
24432443
)
24442444

2445+
@type_checked
2446+
def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: str) -> None:
2447+
"""Set the data type of a buffer, where the buffer is
2448+
specified by a block and a write-index.
2449+
2450+
This schedule primitive is unsafe and may change the correctness of the program because of
2451+
type conversion; please use with caution.
2452+
2453+
Parameters
2454+
----------
2455+
block : Union[BlockRV, str]
2456+
The producer block of the buffer
2457+
buffer_index : int
2458+
The index of the buffer in block's write region
2459+
dtype : str
2460+
The data type to be set
2461+
2462+
Examples
2463+
--------
2464+
2465+
Before unsafe_set_dtype, in TensorIR, the IR is:
2466+
2467+
.. code-block:: python
2468+
2469+
@T.prim_func
2470+
def before_set_dtype(
2471+
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
2472+
) -> None:
2473+
B = T.alloc_buffer((128, 128), dtype="float32")
2474+
2475+
for i, j in T.grid(128, 128):
2476+
with T.block("B"):
2477+
vi, vj = T.axis.remap("SS", [i, j])
2478+
B[vi, vj] = A[vi, vj] * 2.0
2479+
for i, j in T.grid(128, 128):
2480+
with T.block("C"):
2481+
vi, vj = T.axis.remap("SS", [i, j])
2482+
C[vi, vj] = B[vi, vj] + 1.0
2483+
2484+
Create the schedule and do unsafe_set_dtype:
2485+
2486+
.. code-block:: python
2487+
2488+
sch = tir.Schedule(before_set_dtype)
2489+
sch.unsafe_set_dtype("B", buffer_index=0, dtype="float16")
2490+
print(sch.mod["main"].script())
2491+
2492+
After applying unsafe_set_dtype, the IR becomes:
2493+
2494+
.. code-block:: python
2495+
2496+
@T.prim_func
2497+
def after_set_dtype(
2498+
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
2499+
) -> None:
2500+
B = T.alloc_buffer((128, 128), dtype="float16")
2501+
2502+
for i, j in T.grid(128, 128):
2503+
with T.block("B"):
2504+
vi, vj = T.axis.remap("SS", [i, j])
2505+
B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
2506+
for i, j in T.grid(128, 128):
2507+
with T.block("C"):
2508+
vi, vj = T.axis.remap("SS", [i, j])
2509+
C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
2510+
2511+
Note
2512+
----
2513+
`unsafe_set_dtype` requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
2514+
"""
2515+
block = self._normalize_block_arg(block)
2516+
_ffi_api.ScheduleUnsafeSetDType( # type: ignore # pylint: disable=no-member
2517+
self, block, buffer_index, dtype
2518+
)
2519+
24452520
########## Schedule: Blockize & Tensorize ##########
24462521

24472522
@type_checked

src/tir/schedule/concrete_schedule.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,14 @@ void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
727727
this->state_->DebugVerify();
728728
}
729729

730+
// Forwards to the free function tir::UnsafeSetDType on this schedule's state, resolving
// `block_rv` to its sref first, and then runs the state's DebugVerify().
// NOTE(review): TVM_TIR_SCHEDULE_BEGIN/END presumably wrap the call in error handling that is
// reported under the "set-dtype" tag at `error_render_level_` -- confirm against the macros.
void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index,
731+
const String& dtype) {
732+
TVM_TIR_SCHEDULE_BEGIN();
733+
tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype);
734+
TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_);
735+
this->state_->DebugVerify();
736+
}
737+
730738
/******** Schedule: Reduction ********/
731739

732740
BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {

src/tir/schedule/concrete_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ class ConcreteScheduleNode : public ScheduleNode {
146146
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
147147
int offset) override;
148148
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
149+
void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override;
149150
/******** Schedule: Blockize & Tensorize ********/
150151
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
151152
void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override;

src/tir/schedule/primitive.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,18 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu
479479
*/
480480
TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
481481
const String& storage_scope);
482+
/*!
483+
* \brief Set the data type of a buffer, where the buffer is specified by a block and a
484+
* write-index
485+
* \note This schedule primitive is unsafe and may change correctness of program because of
486+
* type conversion, please use with caution.
487+
* \param self The state of the schedule
488+
* \param block_sref The sref of the producer block of the buffer
489+
* \param buffer_index The index of the buffer in block's write region
490+
* \param dtype The data type to be set
491+
*/
492+
TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
493+
const String& dtype);
482494
/*!
483495
* \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
484496
* or write index

src/tir/schedule/primitive/block_annotate.cc

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19+
#include <tvm/tir/expr.h>
20+
1921
#include "../utils.h"
2022

2123
namespace tvm {
@@ -297,6 +299,93 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
297299
self->Replace(alloc_site_sref, new_block, block_reuse_map);
298300
}
299301

302+
/*!
303+
* \brief A helper mutator which recursively mutates the old buffer's data type, inserts data type
304+
* conversions, and collects the block sref reuse information for the following replacement.
305+
*/
306+
class DTypeMutator : private ReplaceBufferMutator {
307+
public:
308+
/*!
* \brief Create a copy of `old_buffer` with the requested dtype and rewrite `allocate_site`
* so that it uses the copy.
309+
* \param allocate_site The block where `old_buffer` was allocated.
310+
* \param old_buffer The old buffer
311+
* \param dtype The data type to be set
312+
* \param block_sref_reuse The block sref reuse map to be updated
313+
* \return The new block after the mutation
314+
*/
315+
static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, const DataType& dtype,
316+
Map<Block, Block>* block_sref_reuse) {
317+
Buffer new_buffer = WithDType(old_buffer, dtype);
318+
DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse);
319+
Stmt new_block = mutator.VisitStmt(allocate_site);
320+
return Downcast<Block>(new_block);
321+
}
322+
323+
private:
324+
// Records the original dtype (taken from `old_buffer`) and the requested dtype; the buffer
// pair and the reuse map are handed to the ReplaceBufferMutator base.
DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& dtype,
325+
Map<Block, Block>* block_sref_reuse)
326+
: ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse),
327+
src_dtype_(old_buffer->dtype),
328+
tgt_dtype_(dtype) {}
329+
330+
// If the match_buffer's source buffer has an entry in `buffer_var_map_`, retype the matched
// (target) buffer to the replacement's dtype, register it in the map under its own data var,
// and point the source region at the replacement buffer. Otherwise return it unchanged.
// NOTE(review): `buffer_var_map_` is presumably inherited from ReplaceBufferMutator -- confirm.
MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
331+
auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
332+
if (it != buffer_var_map_.end()) {
333+
Buffer new_target_buffer = WithDType(match_buffer->buffer, it->second->dtype);
334+
buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
335+
return MatchBufferRegion(new_target_buffer,
336+
BufferRegion(it->second, match_buffer->source->region));
337+
} else {
338+
return match_buffer;
339+
}
340+
}
341+
342+
// Stores into a mapped buffer are redirected to its replacement, and the stored value is
// wrapped in a Cast to the target dtype.
Stmt VisitStmt_(const BufferStoreNode* op) final {
343+
BufferStore node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
344+
auto it = buffer_var_map_.find(node->buffer->data.get());
345+
if (it != buffer_var_map_.end()) {
346+
node.CopyOnWrite()->buffer = it->second;
347+
node.CopyOnWrite()->value = Cast(tgt_dtype_, node->value);
348+
}
349+
return node;
350+
}
351+
352+
// Loads from a mapped buffer are redirected to its replacement, and the loaded value is
// wrapped in a Cast back to the original dtype.
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
353+
BufferLoad node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
354+
auto it = buffer_var_map_.find(node->buffer->data.get());
355+
if (it != buffer_var_map_.end()) {
356+
return Cast(src_dtype_, BufferLoad(it->second, node->indices));
357+
}
358+
return node;
359+
}
360+
361+
// src_dtype_: dtype of the old buffer; tgt_dtype_: dtype requested by the caller.
DataType src_dtype_, tgt_dtype_;
362+
};
363+
364+
// Changes the data type of the buffer at write-index `buffer_index` of the block referenced by
// `block_sref` to the type parsed from the string `dtype`, rewriting the buffer's allocation-site
// block via DTypeMutator. Note that the same-dtype early return in Step 1 happens before the
// allocation-site check in Step 2.
void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
365+
const String& dtype) {
366+
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
367+
Buffer buffer =
368+
GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite);
369+
DataType target_dtype(runtime::String2DLDataType(dtype));
370+
371+
// Step 1. If `dtype` equals the original data type, just return.
372+
if (buffer->dtype == target_dtype) {
373+
return;
374+
}
375+
376+
// Step 2. Get the allocation site of the target buffer.
377+
StmtSRef alloc_site_sref =
378+
NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer);
379+
const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref);
380+
381+
// Step 3. Recursively replace the old buffer with a new buffer that has the given dtype,
382+
// and insert data type conversions.
383+
Map<Block, Block> block_reuse_map;
384+
Block new_block =
385+
DTypeMutator::Mutate(GetRef<Block>(alloc_site), buffer, target_dtype, &block_reuse_map);
386+
self->Replace(alloc_site_sref, new_block, block_reuse_map);
387+
}
388+
300389
/******** InstructionKind Registration ********/
301390

302391
struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> {
@@ -356,8 +445,36 @@ struct SetScopeTraits : public UnpackedInstTraits<SetScopeTraits> {
356445
friend struct ::tvm::tir::UnpackedInstTraits;
357446
};
358447

448+
// Instruction traits for the "UnsafeSetDType" instruction kind: how to apply it to a schedule
// and how to print it as a Python `unsafe_set_dtype` call.
struct UnsafeSetDTypeTraits : public UnpackedInstTraits<UnsafeSetDTypeTraits> {
449+
static constexpr const char* kName = "UnsafeSetDType";
450+
static constexpr bool kIsPure = false;
451+
452+
private:
453+
// One input (the block), two attrs (buffer_index, dtype), no decisions; this matches the
// parameter lists of the two Unpacked* functions below.
static constexpr size_t kNumInputs = 1;
454+
static constexpr size_t kNumAttrs = 2;
455+
static constexpr size_t kNumDecisions = 0;
456+
457+
// Applies the instruction by calling Schedule::UnsafeSetDType with the unpacked arguments.
static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
458+
String dtype) {
459+
return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype);
460+
}
461+
462+
// Builds the Python call `unsafe_set_dtype(block=..., buffer_index=..., dtype=...)`;
// `outputs` is not used because the instruction produces no outputs.
static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
463+
String dtype) {
464+
PythonAPICall py("unsafe_set_dtype");
465+
py.Input("block", block_rv);
466+
py.Input("buffer_index", buffer_index);
467+
py.Input("dtype", dtype);
468+
return py.Str();
469+
}
470+
471+
template <typename>
472+
friend struct ::tvm::tir::UnpackedInstTraits;
473+
};
474+
359475
TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
360476
TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits);
477+
TVM_REGISTER_INST_KIND_TRAITS(UnsafeSetDTypeTraits);
361478

362479
} // namespace tir
363480
} // namespace tvm

src/tir/schedule/schedule.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
217217
.set_body_method<Schedule>(&ScheduleNode::StorageAlign);
218218
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
219219
.set_body_method<Schedule>(&ScheduleNode::SetScope);
220+
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType")
221+
.set_body_method<Schedule>(&ScheduleNode::UnsafeSetDType);
220222
/******** (FFI) Blockize & Tensorize ********/
221223
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
222224
.set_body_method<Schedule>(&ScheduleNode::Blockize);

src/tir/schedule/traced_schedule.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,17 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
504504
/*outputs=*/{}));
505505
}
506506

507+
// Performs the primitive through ConcreteScheduleNode, then appends an "UnsafeSetDType"
// instruction to the trace with the block as input and (buffer_index, dtype) as attrs.
void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index,
508+
const String& dtype) {
509+
ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype);
510+
// The instruction kind is looked up by name once and cached in a function-local static.
static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType");
511+
trace_->Append(/*inst=*/Instruction(
512+
/*kind=*/kind,
513+
/*inputs=*/{block_rv},
514+
/*attrs=*/{Integer(buffer_index), dtype},
515+
/*outputs=*/{}));
516+
}
517+
507518
/******** Schedule: Blockize & Tensorize ********/
508519

509520
BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) {

src/tir/schedule/traced_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
105105
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
106106
int offset) final;
107107
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final;
108+
void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final;
108109
/******** Schedule: Blockize & Tensorize ********/
109110
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
110111
void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final;

src/tir/schedule/transform.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ Buffer WithScope(const Buffer& buffer, const String& scope) {
4343
return Buffer(new_buffer);
4444
}
4545

46+
// Returns a copy of `buffer` whose dtype is `dtype`. The copy gets a fresh data Var with the
// same name hint, typed as a pointer to `dtype` in the original buffer's storage scope.
Buffer WithDType(const Buffer& buffer, const DataType& dtype) {
47+
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
48+
new_buffer->dtype = dtype;
49+
// NOTE(review): TVM_TYPE_AS presumably checks that the annotation is a PointerType -- confirm.
const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
50+
new_buffer->data =
51+
Var(buffer->data->name_hint, PointerType(PrimType(dtype), ptr_type->storage_scope));
52+
// NOTE(review): looks redundant, since the node was copy-constructed from `buffer` -- confirm.
new_buffer->name = buffer->name;
53+
return Buffer(new_buffer);
54+
}
55+
4656
Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
4757
const Buffer& target) {
4858
regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion {

0 commit comments

Comments
 (0)