Skip to content

Commit b80cd66

Browse files
Giuseppe Rossinigiuseros
authored andcommitted
[TIR] Introduce tir.allocate_const to TIR
This PR is adding non-scalar constant representation in TIR. This is used to express constants (i.e., parameters) in the TIR instead of bypassing the TIR as it's done until now. Change-Id: Id3afc4d7197260cb43ecde60f05ccbce3fc42430 Co-authored-by: Giuseppe Rossini <[email protected]> Change-Id: Id4a09a637c9c1fd7d49989c6c10f474a78569e18
1 parent c8f54f9 commit b80cd66

File tree

11 files changed

+345
-1
lines changed

11 files changed

+345
-1
lines changed

include/tvm/tir/stmt.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,73 @@ class Allocate : public Stmt {
570570
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
571571
};
572572

573+
/*!
574+
* \brief Allocate a buffer that can be used in body.
575+
*/
576+
class AllocateConstNode : public StmtNode {
577+
public:
578+
/*! \brief The buffer variable. */
579+
Var buffer_var;
580+
/*! \brief The data associated to the constant. */
581+
::tvm::runtime::NDArray data;
582+
/*! \brief The type of the buffer. */
583+
DataType dtype;
584+
/*! \brief The extents of the buffer. */
585+
Array<PrimExpr> extents;
586+
/*! \brief The body to be executed. */
587+
Stmt body;
588+
589+
void VisitAttrs(AttrVisitor* v) {
590+
v->Visit("buffer_var", &buffer_var);
591+
v->Visit("dtype", &dtype);
592+
v->Visit("extents", &extents);
593+
v->Visit("body", &body);
594+
v->Visit("span", &span);
595+
}
596+
597+
bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
598+
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
599+
equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body);
600+
}
601+
602+
void SHashReduce(SHashReducer hash_reduce) const {
603+
hash_reduce.DefHash(buffer_var);
604+
hash_reduce(dtype);
605+
hash_reduce(extents);
606+
hash_reduce(body);
607+
hash_reduce(data);
608+
}
609+
610+
/*!
611+
* \brief If the buffer size is constant, return the size.
612+
* Otherwise return 0.
613+
* \return The result.
614+
*/
615+
int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
616+
/*!
617+
* \brief If the buffer size is constant, return the size.
618+
* Otherwise return 0.
619+
* \param extents The extents of the buffer.
620+
* \return The result.
621+
*/
622+
TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
623+
624+
static constexpr const char* _type_key = "tir.AllocateConst";
625+
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode);
626+
};
627+
628+
/*!
629+
* \brief Managed reference to AllocateNode.
630+
* \sa AllocateNode
631+
*/
632+
class AllocateConst : public Stmt {
633+
public:
634+
TVM_DLL AllocateConst(Var buffer_var, ::tvm::runtime::NDArray data, DataType dtype, Array<PrimExpr> extents,
635+
Stmt body, Span span = Span());
636+
637+
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
638+
};
639+
573640
/*!
574641
* \brief The container of seq statement.
575642
* Represent a sequence of statements.

include/tvm/tir/stmt_functor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
8787
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
8888
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
8989
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
90+
virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9091
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9192
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9293
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -113,6 +114,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
113114
IR_STMT_FUNCTOR_DISPATCH(ForNode);
114115
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
115116
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
117+
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
116118
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
117119
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
118120
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
@@ -155,6 +157,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
155157
void VisitStmt_(const ForNode* op) override;
156158
void VisitStmt_(const WhileNode* op) override;
157159
void VisitStmt_(const AllocateNode* op) override;
160+
void VisitStmt_(const AllocateConstNode* op) override;
158161
void VisitStmt_(const StoreNode* op) override;
159162
void VisitStmt_(const BufferStoreNode* op) override;
160163
void VisitStmt_(const BufferRealizeNode* op) override;
@@ -255,6 +258,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
255258
Stmt VisitStmt_(const ForNode* op) override;
256259
Stmt VisitStmt_(const WhileNode* op) override;
257260
Stmt VisitStmt_(const AllocateNode* op) override;
261+
Stmt VisitStmt_(const AllocateConstNode* op) override;
258262
Stmt VisitStmt_(const StoreNode* op) override;
259263
Stmt VisitStmt_(const BufferStoreNode* op) override;
260264
Stmt VisitStmt_(const BufferRealizeNode* op) override;

python/tvm/script/scope_handler.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
from .registry import register
3636
from .node import BufferSlice
3737

38+
import numpy as np
39+
3840

3941
class ScopeHandler:
4042
"""Base class for all scope handlers"""
@@ -147,6 +149,52 @@ def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None):
147149
context.update_symbol(name, self.buffer_var, node)
148150

149151

152+
@register
153+
class AllocateConst(WithScopeHandler):
154+
"""With scope handler tir.allocate(data, extents, dtype, condition)"""
155+
156+
def __init__(self):
157+
def allocate_const(raw_data, dtype, shape, span=None):
158+
list_data = []
159+
for i in raw_data:
160+
list_data.append(i.value)
161+
nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
162+
163+
n = tvm.tir.AllocateConst(self.buffer_var, nd_data, dtype, shape, self.body, span=span)
164+
return n
165+
166+
super().__init__(allocate_const, concise_scope=True, def_symbol=True)
167+
self.buffer_var = None
168+
169+
def enter_scope(
170+
self,
171+
node: synr.ast.Node,
172+
context: ContextMaintainer,
173+
arg_list: List[Any],
174+
span: synr.ast.Span,
175+
):
176+
# define buffer vars in symbol table
177+
if isinstance(node, ast.With):
178+
vars = WithScopeHandler.get_optional_vars(node, context)
179+
if len(vars) != 1:
180+
context.report_error("Unexpected number of vars", node.span)
181+
name = vars[0].id.name
182+
var_span = vars[0].id.span
183+
elif isinstance(node, ast.Assign):
184+
name = node.lhs.id.name
185+
var_span = node.lhs.id.span
186+
else:
187+
raise Exception("Internal Bug")
188+
189+
def setup_buffer_var(data, dtype, shape, span: Span = None):
190+
"""Setup buffer var for a given type."""
191+
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
192+
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
193+
194+
setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
195+
context.update_symbol(name, self.buffer_var, node)
196+
197+
150198
@register
151199
class LaunchThread(WithScopeHandler):
152200
"""With scope handler tir.launch_thread(env_var, extent)"""

python/tvm/tir/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,15 @@
2828
from .expr import Call, CallEffectKind, Let, IterVar, Any
2929

3030
from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
31-
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
31+
from .stmt import (
32+
BufferStore,
33+
BufferRealize,
34+
Store,
35+
ProducerStore,
36+
Allocate,
37+
AllocateConst,
38+
AttrStmt,
39+
)
3240
from .stmt import ProducerRealize, SeqStmt
3341
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
3442
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize

python/tvm/tir/stmt.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,40 @@ def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
320320
)
321321

322322

323+
@tvm._ffi.register_object("tir.AllocateConst")
324+
class AllocateConst(Stmt):
325+
"""Allocate constant node.
326+
327+
Parameters
328+
----------
329+
buffer_var : Var
330+
The buffer variable.
331+
332+
data : NDarray
333+
The data associated with the constant
334+
335+
dtype : str
336+
The data type of the buffer.
337+
338+
extents : list of Expr
339+
The extents of the allocate
340+
341+
condition : PrimExpr
342+
The condition.
343+
344+
body : Stmt
345+
The body statement.
346+
347+
span : Optional[Span]
348+
The location of this itervar in the source code.
349+
"""
350+
351+
def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
352+
self.__init_handle_by_constructor__(
353+
_ffi_api.AllocateConst, buffer_var, dtype, extents, condition, body, span
354+
)
355+
356+
323357
@tvm._ffi.register_object("tir.AttrStmt")
324358
class AttrStmt(Stmt):
325359
"""AttrStmt node.

src/printer/text_printer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
314314
Doc VisitStmt_(const BufferStoreNode* op) override;
315315
Doc VisitStmt_(const BufferRealizeNode* op) override;
316316
Doc VisitStmt_(const AllocateNode* op) override;
317+
Doc VisitStmt_(const AllocateConstNode* op) override;
317318
Doc VisitStmt_(const IfThenElseNode* op) override;
318319
Doc VisitStmt_(const SeqStmtNode* op) override;
319320
Doc VisitStmt_(const EvaluateNode* op) override;
@@ -351,6 +352,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
351352
static Doc PrintConstScalar(DataType dtype, const T& data);
352353
Doc GetUniqueName(std::string prefix);
353354
Doc AllocVar(const Var& var);
355+
Doc AllocConst(const AllocateConst& var);
354356
Doc AllocBuf(const Buffer& buffer);
355357
/*!
356358
* \brief special method to render vectors of docs with a separator

src/printer/tir_text_printer.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,19 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
460460
return doc;
461461
}
462462

463+
Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
464+
Doc doc;
465+
doc << "constant(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
466+
<< Print(op->extents) << ")";
467+
468+
if (op->body->IsInstance<SeqStmtNode>()) {
469+
doc << PrintBody(op->body);
470+
} else {
471+
doc << ";" << Doc::NewLine() << Print(op->body);
472+
}
473+
return doc;
474+
}
475+
463476
Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
464477
Doc doc;
465478
doc << "if " << Print(op->condition) << PrintBody(op->then_case);

src/printer/tvmscript_printer.cc

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
133133
Doc VisitStmt_(const BufferStoreNode* op) override;
134134
Doc VisitStmt_(const BufferRealizeNode* op) override;
135135
Doc VisitStmt_(const AllocateNode* op) override;
136+
Doc VisitStmt_(const AllocateConstNode* op) override;
136137
Doc VisitStmt_(const IfThenElseNode* op) override;
137138
Doc VisitStmt_(const SeqStmtNode* op) override;
138139
Doc VisitStmt_(const ForNode* op) override;
@@ -246,6 +247,26 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
246247
}
247248
return doc;
248249
}
250+
251+
/*!
252+
* \brief special method to print NDArray in TIR
253+
* \param arr the NDArray to be printed
254+
* \param os the output stream where the NDArray will be printed to
255+
*/
256+
template <typename T>
257+
void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) {
258+
int ndim = arr->ndim;
259+
int tot_dim = 1;
260+
for (int i = 0; i < ndim; i++) {
261+
tot_dim *= arr->shape[i];
262+
}
263+
T* data_ptr = reinterpret_cast<T*>(arr->data);
264+
os << "[";
265+
for (int i = 0; i < tot_dim; i++) {
266+
os << data_ptr[i] << ", ";
267+
}
268+
os << "]";
269+
}
249270
};
250271

251272
Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
@@ -684,6 +705,48 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
684705
return Doc();
685706
}
686707

708+
Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
709+
std::stringstream ss;
710+
711+
if (alloc->dtype.is_int()) {
712+
if (alloc->dtype.bits() == 8) {
713+
NDArrayToTIR<int8_t>(alloc->data, ss);
714+
} else if (alloc->dtype.bits() == 16) {
715+
NDArrayToTIR<int16_t>(alloc->data, ss);
716+
} else if (alloc->dtype.bits() == 32) {
717+
NDArrayToTIR<int32_t>(alloc->data, ss);
718+
} else {
719+
LOG(FATAL) << "DataType not supported";
720+
}
721+
} else if (alloc->dtype.is_float()) {
722+
if (alloc->dtype.bits() == 16) {
723+
NDArrayToTIR<int16_t>(alloc->data, ss);
724+
} else if (alloc->dtype.bits() == 32) {
725+
NDArrayToTIR<float>(alloc->data, ss);
726+
} else if (alloc->dtype.bits() == 64) {
727+
NDArrayToTIR<double>(alloc->data, ss);
728+
} else {
729+
LOG(FATAL) << "DataType not supported";
730+
}
731+
} else {
732+
LOG(FATAL) << "DataType not supported";
733+
}
734+
auto ndarray_str = ss.str();
735+
736+
Doc doc;
737+
var_not_in_headers.insert(alloc->buffer_var.get());
738+
if (current_num_ != num_child_ - 1) {
739+
doc << "with tir.allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) << ", "
740+
<< Print(alloc->extents) << ")";
741+
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body));
742+
} else {
743+
doc << Print(alloc->buffer_var) << " = tir.allocate_const(" << ndarray_str << ", "
744+
<< PrintDType(alloc->dtype) << ", " << Print(alloc->extents);
745+
doc << ")" << Doc::NewLine() << PrintBody(alloc->body);
746+
}
747+
return doc;
748+
}
749+
687750
Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
688751
Doc doc;
689752
doc << "if " << Print(op->condition) << ":";

0 commit comments

Comments
 (0)