Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,34 @@ class ScheduleNode : public runtime::Object {
virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/*!
* \brief Annotate a loop with a key value pair
* \param loop_rv The loop to be annotated
* \param ann_key The annotation key
* \param ann_val The annotation value, a string or a ExprRV
*/
virtual void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) = 0;
/*!
* \brief Annotate a block with a key value pair
* \param block_rv The block to be annotated
* \param ann_key The annotation key
* \param ann_val The annotation value, a string or a ExprRV
*/
virtual void Annotate(const BlockRV& block_rv, const String& ann_key,
const ObjectRef& ann_val) = 0;
/*!
* \brief Unannotate a loop's annotation with key ann_key
* \param loop_rv The loop to be unannotated
* \param ann_key The annotation key
*/
virtual void Unannotate(const LoopRV& loop_rv, const String& ann_key) = 0;
/*!
* \brief Unannotate a block's annotation with key ann_key
* \param block_rv The block to be unannotated
* \param ann_key The annotation key
*/
virtual void Unannotate(const BlockRV& block_rv, const String& ann_key) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import synr
import tvm.tir
from tvm.runtime import Object
from tvm.runtime import Object, String
from tvm.ir import Span, Range
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind

Expand Down Expand Up @@ -486,7 +486,7 @@ def create_loop_info(
self.annotations: Mapping[str, Object] = {}
if annotations is not None:
self.annotations = {
key: tvm.tir.StringImm(val) if isinstance(val, str) else val
key: String(val) if isinstance(val, str) else val
for key, val in annotations.items()
}

Expand Down
5 changes: 2 additions & 3 deletions python/tvm/script/tir/special_stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tvm.ir.expr import PrimExpr, Range

import tvm.tir
from tvm.runtime import Object
from tvm.runtime import Object, String
from tvm import te
from tvm.target import Target
from tvm.ir import Span
Expand Down Expand Up @@ -430,8 +430,7 @@ def block_attr(attrs: Mapping[str, Object], span: Span = None):
span,
)
attrs = {
key: tvm.tir.StringImm(val) if isinstance(val, str) else val
for key, val in attrs.items()
key: String(val) if isinstance(val, str) else val for key, val in attrs.items()
}
block_scope.annotations = attrs

Expand Down
125 changes: 123 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object
from tvm.tir import Block, For, IntImm, PrimFunc
from tvm.runtime import Object, String
from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc

from . import _ffi_api
from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
Expand Down Expand Up @@ -1735,6 +1735,127 @@ def after_set_scope(

########## Schedule: Annotation ##########

@type_checked
def annotate(
self,
block_or_loop: Union[BlockRV, LoopRV],
ann_key: str,
ann_val: Union[str, int, float, ExprRV],
) -> None:
"""Annotate a block/loop with a key value pair

Parameters
----------
block_or_loop: Union[BlockRV, LoopRV]
The block/loop to be annotated
ann_key : str
The annotation key
ann_val : Union[str, int, float, ExprRV]
The annotation value

Examples
--------

Before annotate, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_annotate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do annotate:

.. code-block:: python

sch = tir.Schedule(before_annotate)
sch.annotate(sch.get_block("B"), "ann_key", "ann_value")
print(sch.mod["main"].script())

After applying annotate, the IR becomes:

.. code-block:: python

@T.prim_func
def after_annotate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.block_attr({"ann_key", "ann_value"})
B[vi, vj] = A[vi, vj] * 2.0

"""
if isinstance(ann_val, str):
ann_val = String(ann_val)
elif isinstance(ann_val, int):
ann_val = IntImm("int32", ann_val)
elif isinstance(ann_val, float):
ann_val = FloatImm("float32", ann_val)
_ffi_api.ScheduleAnnotate( # type: ignore # pylint: disable=no-member
self, block_or_loop, ann_key, ann_val
)

@type_checked
def unannotate(self, block_or_loop: Union[BlockRV, LoopRV], ann_key: str) -> None:
"""Unannotate a block/loop's annotation with key ann_key

Parameters
----------
block_or_loop: Union[BlockRV, LoopRV]
The block/loop to be unannotated
ann_key : str
The annotation key

Examples
--------

Before unannotate, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_unannotate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.block_attr({"ann_key", "ann_value"})
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do annotate:

.. code-block:: python

sch = tir.Schedule(before_unannotate)
sch.unannotate(sch.get_block("B"), "ann_key")
print(sch.mod["main"].script())

After applying unannotate, the IR becomes:

.. code-block:: python

@T.prim_func
def after_unannotate(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
B = T.match_buffer(b, (128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0

"""
_ffi_api.ScheduleUnannotate( # type: ignore # pylint: disable=no-member
self, block_or_loop, ann_key
)

########## Schedule: Misc ##########

@type_checked
Expand Down
47 changes: 47 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,53 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {

/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/

ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_val) {
if (ann_val.as<StringObj>()) {
return ann_val;
}
if (const auto* expr = ann_val.as<PrimExprNode>()) {
ICHECK(!ann_val->IsInstance<StringImmNode>())
<< "TypeError: runtime::String is expected, but gets StringImm";
return this->Get(GetRef<PrimExpr>(expr));
}
LOG(FATAL)
<< "TypeError: Only strings, integers, floats, ExprRVs and Arrays are supported for now, but "
<< "gets: " << ann_val->GetTypeKey();
throw;
}

void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key,
const ObjectRef& ann_val) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Annotate(state_, this->GetSRef(loop_rv), ann_key, this->CheckAndGetAnnotationValue(ann_val));
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
}

void ConcreteScheduleNode::Unannotate(const LoopRV& loop_rv, const String& ann_key) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
}

void ConcreteScheduleNode::Annotate(const BlockRV& block_rv, const String& ann_key,
const ObjectRef& ann_val) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Annotate(state_, this->GetSRef(block_rv), ann_key,
this->CheckAndGetAnnotationValue(ann_val));
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("annotate", this->error_render_level_);
}

void ConcreteScheduleNode::Unannotate(const BlockRV& loop_rv, const String& ann_key) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Unannotate(state_, this->GetSRef(loop_rv), ann_key);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("unannotate", this->error_render_level_);
}

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
12 changes: 12 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,11 @@ class ConcreteScheduleNode : public ScheduleNode {
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
void Annotate(const BlockRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
void Unannotate(const BlockRV& loop_rv, const String& ann_key) override;

/******** Schedule: Misc ********/
void EnterPostproc() override {}

Expand Down Expand Up @@ -162,6 +167,13 @@ class ConcreteScheduleNode : public ScheduleNode {
inline Array<ExprRV> CreateRV(const std::vector<int64_t>& value);
/*! \brief Remove a random variable from the symbol table */
inline void RemoveFromSymbolTable(const ObjectRef& rv);
/*!
* \brief Check the annotation value is valid and look up the random variable. Raises an exception
* if the type of the annotation value is not allowed.
* \param The annotation value.
* \return The annotation value with random variables substituted with their values.
*/
ObjectRef CheckAndGetAnnotationValue(const ObjectRef& ann_val);
};

// implementations
Expand Down
17 changes: 17 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,23 @@ TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer

/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/*!
* \brief Annotate a block/loop with a key value pair
* \param self The state of the schedule
* \param sref The block/loop sref to be annotated
* \param ann_key The annotation key
* \param ann_val The annotation value
*/
TVM_DLL void Annotate(ScheduleState self, const StmtSRef& sref, const String& ann_key,
const ObjectRef& ann_val);
/*!
* \brief Unannotate a block/loop's annotation with key ann_key
* \param self The state of the schedule
* \param sref The block/loop to be unannotated
* \param ann_key The annotation key
*/
TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& ann_key);

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
Loading