Skip to content

Commit 7458e2d

Browse files
Giuseppe Rossinigiuseros
authored andcommitted
[TIR] Introduce tir.allocate_const to TIR
This PR is adding non-scalar constant representation in TIR. This is used to express constants (i.e., parameters) in the TIR instead of bypassing the TIR as it's done until now. Change-Id: Id3afc4d7197260cb43ecde60f05ccbce3fc42430 Co-authored-by: Giuseppe Rossini <[email protected]> Change-Id: Id4a09a637c9c1fd7d49989c6c10f474a78569e18
1 parent 4b2ccde commit 7458e2d

File tree

11 files changed

+344
-1
lines changed

11 files changed

+344
-1
lines changed

include/tvm/tir/stmt.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,73 @@ class Allocate : public Stmt {
575575
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
576576
};
577577

578+
/*!
579+
* \brief Allocate a buffer that can be used in body.
580+
*/
581+
class AllocateConstNode : public StmtNode {
582+
public:
583+
/*! \brief The buffer variable. */
584+
Var buffer_var;
585+
/*! \brief The data associated to the constant. */
586+
::tvm::runtime::NDArray data;
587+
/*! \brief The type of the buffer. */
588+
DataType dtype;
589+
/*! \brief The extents of the buffer. */
590+
Array<PrimExpr> extents;
591+
/*! \brief The body to be executed. */
592+
Stmt body;
593+
594+
void VisitAttrs(AttrVisitor* v) {
595+
v->Visit("buffer_var", &buffer_var);
596+
v->Visit("dtype", &dtype);
597+
v->Visit("extents", &extents);
598+
v->Visit("body", &body);
599+
v->Visit("span", &span);
600+
}
601+
602+
bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
603+
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
604+
equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body);
605+
}
606+
607+
void SHashReduce(SHashReducer hash_reduce) const {
608+
hash_reduce.DefHash(buffer_var);
609+
hash_reduce(dtype);
610+
hash_reduce(extents);
611+
hash_reduce(body);
612+
hash_reduce(data);
613+
}
614+
615+
/*!
616+
* \brief If the buffer size is constant, return the size.
617+
* Otherwise return 0.
618+
* \return The result.
619+
*/
620+
int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
621+
/*!
622+
* \brief If the buffer size is constant, return the size.
623+
* Otherwise return 0.
624+
* \param extents The extents of the buffer.
625+
* \return The result.
626+
*/
627+
TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);
628+
629+
static constexpr const char* _type_key = "tir.AllocateConst";
630+
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode);
631+
};
632+
633+
/*!
634+
* \brief Managed reference to AllocateNode.
635+
* \sa AllocateNode
636+
*/
637+
class AllocateConst : public Stmt {
638+
public:
639+
TVM_DLL AllocateConst(Var buffer_var, ::tvm::runtime::NDArray data, DataType dtype,
640+
Array<PrimExpr> extents, Stmt body, Span span = Span());
641+
642+
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
643+
};
644+
578645
/*!
579646
* \brief The container of seq statement.
580647
* Represent a sequence of statements.

include/tvm/tir/stmt_functor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
8787
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
8888
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
8989
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
90+
virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9091
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9192
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
9293
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
@@ -113,6 +114,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
113114
IR_STMT_FUNCTOR_DISPATCH(ForNode);
114115
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
115116
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
117+
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
116118
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
117119
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
118120
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
@@ -155,6 +157,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
155157
void VisitStmt_(const ForNode* op) override;
156158
void VisitStmt_(const WhileNode* op) override;
157159
void VisitStmt_(const AllocateNode* op) override;
160+
void VisitStmt_(const AllocateConstNode* op) override;
158161
void VisitStmt_(const StoreNode* op) override;
159162
void VisitStmt_(const BufferStoreNode* op) override;
160163
void VisitStmt_(const BufferRealizeNode* op) override;
@@ -255,6 +258,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
255258
Stmt VisitStmt_(const ForNode* op) override;
256259
Stmt VisitStmt_(const WhileNode* op) override;
257260
Stmt VisitStmt_(const AllocateNode* op) override;
261+
Stmt VisitStmt_(const AllocateConstNode* op) override;
258262
Stmt VisitStmt_(const StoreNode* op) override;
259263
Stmt VisitStmt_(const BufferStoreNode* op) override;
260264
Stmt VisitStmt_(const BufferRealizeNode* op) override;

python/tvm/script/scope_handler.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tvm.runtime import Object
2525
from tvm.ir import Span, Range
2626
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
27+
import numpy as np
2728

2829
from .context_maintainer import ContextMaintainer
2930
from .utils import (
@@ -147,6 +148,52 @@ def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None):
147148
context.update_symbol(name, self.buffer_var, node)
148149

149150

151+
@register
152+
class AllocateConst(WithScopeHandler):
153+
"""With scope handler tir.allocate(data, extents, dtype, condition)"""
154+
155+
def __init__(self):
156+
def allocate_const(raw_data, dtype, shape, span=None):
157+
list_data = []
158+
for i in raw_data:
159+
list_data.append(i.value)
160+
nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
161+
162+
n = tvm.tir.AllocateConst(self.buffer_var, nd_data, dtype, shape, self.body, span=span)
163+
return n
164+
165+
super().__init__(allocate_const, concise_scope=True, def_symbol=True)
166+
self.buffer_var = None
167+
168+
def enter_scope(
169+
self,
170+
node: synr.ast.Node,
171+
context: ContextMaintainer,
172+
arg_list: List[Any],
173+
span: synr.ast.Span,
174+
):
175+
# define buffer vars in symbol table
176+
if isinstance(node, ast.With):
177+
vars = WithScopeHandler.get_optional_vars(node, context)
178+
if len(vars) != 1:
179+
context.report_error("Unexpected number of vars", node.span)
180+
name = vars[0].id.name
181+
var_span = vars[0].id.span
182+
elif isinstance(node, ast.Assign):
183+
name = node.lhs.id.name
184+
var_span = node.lhs.id.span
185+
else:
186+
raise Exception("Internal Bug")
187+
188+
def setup_buffer_var(data, dtype, shape, span: Span = None):
189+
"""Setup buffer var for a given type."""
190+
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
191+
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)
192+
193+
setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
194+
context.update_symbol(name, self.buffer_var, node)
195+
196+
150197
@register
151198
class LaunchThread(WithScopeHandler):
152199
"""With scope handler tir.launch_thread(env_var, extent)"""

python/tvm/tir/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,15 @@
2828
from .expr import Call, CallEffectKind, Let, IterVar, Any
2929

3030
from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
31-
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
31+
from .stmt import (
32+
BufferStore,
33+
BufferRealize,
34+
Store,
35+
ProducerStore,
36+
Allocate,
37+
AllocateConst,
38+
AttrStmt,
39+
)
3240
from .stmt import ProducerRealize, SeqStmt
3341
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
3442
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize

python/tvm/tir/stmt.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,40 @@ def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
328328
)
329329

330330

331+
@tvm._ffi.register_object("tir.AllocateConst")
332+
class AllocateConst(Stmt):
333+
"""Allocate constant node.
334+
335+
Parameters
336+
----------
337+
buffer_var : Var
338+
The buffer variable.
339+
340+
data : NDarray
341+
The data associated with the constant
342+
343+
dtype : str
344+
The data type of the buffer.
345+
346+
extents : list of Expr
347+
The extents of the allocate
348+
349+
condition : PrimExpr
350+
The condition.
351+
352+
body : Stmt
353+
The body statement.
354+
355+
span : Optional[Span]
356+
The location of this itervar in the source code.
357+
"""
358+
359+
def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
360+
self.__init_handle_by_constructor__(
361+
_ffi_api.AllocateConst, buffer_var, dtype, extents, condition, body, span
362+
)
363+
364+
331365
@tvm._ffi.register_object("tir.AttrStmt")
332366
class AttrStmt(Stmt):
333367
"""AttrStmt node.

src/printer/text_printer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
322322
Doc VisitStmt_(const BufferStoreNode* op) override;
323323
Doc VisitStmt_(const BufferRealizeNode* op) override;
324324
Doc VisitStmt_(const AllocateNode* op) override;
325+
Doc VisitStmt_(const AllocateConstNode* op) override;
325326
Doc VisitStmt_(const IfThenElseNode* op) override;
326327
Doc VisitStmt_(const SeqStmtNode* op) override;
327328
Doc VisitStmt_(const EvaluateNode* op) override;
@@ -359,6 +360,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
359360
static Doc PrintConstScalar(DataType dtype, const T& data);
360361
Doc GetUniqueName(std::string prefix);
361362
Doc AllocVar(const Var& var);
363+
Doc AllocConst(const AllocateConst& var);
362364
Doc AllocBuf(const Buffer& buffer);
363365
/*!
364366
* \brief special method to render vectors of docs with a separator

src/printer/tir_text_printer.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,19 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
460460
return doc;
461461
}
462462

463+
Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
464+
Doc doc;
465+
doc << "constant(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
466+
<< Print(op->extents) << ")";
467+
468+
if (op->body->IsInstance<SeqStmtNode>()) {
469+
doc << PrintBody(op->body);
470+
} else {
471+
doc << ";" << Doc::NewLine() << Print(op->body);
472+
}
473+
return doc;
474+
}
475+
463476
Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
464477
Doc doc;
465478
doc << "if " << Print(op->condition) << PrintBody(op->then_case);

src/printer/tvmscript_printer.cc

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
134134
Doc VisitStmt_(const BufferStoreNode* op) override;
135135
Doc VisitStmt_(const BufferRealizeNode* op) override;
136136
Doc VisitStmt_(const AllocateNode* op) override;
137+
Doc VisitStmt_(const AllocateConstNode* op) override;
137138
Doc VisitStmt_(const IfThenElseNode* op) override;
138139
Doc VisitStmt_(const SeqStmtNode* op) override;
139140
Doc VisitStmt_(const ForNode* op) override;
@@ -247,6 +248,26 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
247248
}
248249
return doc;
249250
}
251+
252+
/*!
253+
* \brief special method to print NDArray in TIR
254+
* \param arr the NDArray to be printed
255+
* \param os the output stream where the NDArray will be printed to
256+
*/
257+
template <typename T>
258+
void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) {
259+
int ndim = arr->ndim;
260+
int tot_dim = 1;
261+
for (int i = 0; i < ndim; i++) {
262+
tot_dim *= arr->shape[i];
263+
}
264+
T* data_ptr = reinterpret_cast<T*>(arr->data);
265+
os << "[";
266+
for (int i = 0; i < tot_dim; i++) {
267+
os << data_ptr[i] << ", ";
268+
}
269+
os << "]";
270+
}
250271
};
251272

252273
Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
@@ -685,6 +706,48 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
685706
return Doc();
686707
}
687708

709+
Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
710+
std::stringstream ss;
711+
712+
if (alloc->dtype.is_int()) {
713+
if (alloc->dtype.bits() == 8) {
714+
NDArrayToTIR<int8_t>(alloc->data, ss);
715+
} else if (alloc->dtype.bits() == 16) {
716+
NDArrayToTIR<int16_t>(alloc->data, ss);
717+
} else if (alloc->dtype.bits() == 32) {
718+
NDArrayToTIR<int32_t>(alloc->data, ss);
719+
} else {
720+
LOG(FATAL) << "DataType not supported";
721+
}
722+
} else if (alloc->dtype.is_float()) {
723+
if (alloc->dtype.bits() == 16) {
724+
NDArrayToTIR<int16_t>(alloc->data, ss);
725+
} else if (alloc->dtype.bits() == 32) {
726+
NDArrayToTIR<float>(alloc->data, ss);
727+
} else if (alloc->dtype.bits() == 64) {
728+
NDArrayToTIR<double>(alloc->data, ss);
729+
} else {
730+
LOG(FATAL) << "DataType not supported";
731+
}
732+
} else {
733+
LOG(FATAL) << "DataType not supported";
734+
}
735+
auto ndarray_str = ss.str();
736+
737+
Doc doc;
738+
var_not_in_headers.insert(alloc->buffer_var.get());
739+
if (current_num_ != num_child_ - 1) {
740+
doc << "with tir.allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) << ", "
741+
<< Print(alloc->extents) << ")";
742+
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body));
743+
} else {
744+
doc << Print(alloc->buffer_var) << " = tir.allocate_const(" << ndarray_str << ", "
745+
<< PrintDType(alloc->dtype) << ", " << Print(alloc->extents);
746+
doc << ")" << Doc::NewLine() << PrintBody(alloc->body);
747+
}
748+
return doc;
749+
}
750+
688751
Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
689752
Doc doc;
690753
doc << "if " << Print(op->condition) << ":";

0 commit comments

Comments
 (0)