Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,73 @@ class Allocate : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
};

/*!
* \brief Allocate a buffer that can be used in body.
*/
class AllocateConstNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Var buffer_var;
/*! \brief The data associated to the constant. */
::tvm::runtime::NDArray data;
/*! \brief The type of the buffer. */
DataType dtype;
/*! \brief The extents of the buffer. */
Array<PrimExpr> extents;
/*! \brief The body to be executed. */
Stmt body;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer_var", &buffer_var);
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("body", &body);
v->Visit("span", &span);
}

bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const {
return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) &&
equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(buffer_var);
hash_reduce(dtype);
hash_reduce(extents);
hash_reduce(body);
hash_reduce(data);
}

/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \return The result.
*/
int32_t constant_allocation_size() const { return constant_allocation_size(extents); }
/*!
* \brief If the buffer size is constant, return the size.
* Otherwise return 0.
* \param extents The extents of the buffer.
* \return The result.
*/
TVM_DLL static int32_t constant_allocation_size(const Array<PrimExpr>& extents);

static constexpr const char* _type_key = "tir.AllocateConst";
TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstNode, StmtNode);
};

/*!
* \brief Managed reference to AllocateNode.
* \sa AllocateNode
*/
class AllocateConst : public Stmt {
public:
TVM_DLL AllocateConst(Var buffer_var, ::tvm::runtime::NDArray data, DataType dtype,
Array<PrimExpr> extents, Stmt body, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
};

/*!
* \brief The container of seq statement.
* Represent a sequence of statements.
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const WhileNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateConstNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
Expand All @@ -113,6 +114,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(ForNode);
IR_STMT_FUNCTOR_DISPATCH(WhileNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateNode);
IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode);
IR_STMT_FUNCTOR_DISPATCH(StoreNode);
IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode);
IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode);
Expand Down Expand Up @@ -155,6 +157,7 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const WhileNode* op) override;
void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const AllocateConstNode* op) override;
void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const BufferRealizeNode* op) override;
Expand Down Expand Up @@ -255,6 +258,7 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const ForNode* op) override;
Stmt VisitStmt_(const WhileNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const AllocateConstNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const BufferRealizeNode* op) override;
Expand Down
47 changes: 47 additions & 0 deletions python/tvm/script/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm.runtime import Object
from tvm.ir import Span, Range
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
import numpy as np

from .context_maintainer import ContextMaintainer
from .utils import (
Expand Down Expand Up @@ -147,6 +148,52 @@ def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None):
context.update_symbol(name, self.buffer_var, node)


@register
class AllocateConst(WithScopeHandler):
"""With scope handler tir.allocate(data, extents, dtype, condition)"""

def __init__(self):
def allocate_const(raw_data, dtype, shape, span=None):
list_data = []
for i in raw_data:
list_data.append(i.value)
nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))

n = tvm.tir.AllocateConst(self.buffer_var, nd_data, dtype, shape, self.body, span=span)
return n

super().__init__(allocate_const, concise_scope=True, def_symbol=True)
self.buffer_var = None

def enter_scope(
self,
node: synr.ast.Node,
context: ContextMaintainer,
arg_list: List[Any],
span: synr.ast.Span,
):
# define buffer vars in symbol table
if isinstance(node, ast.With):
vars = WithScopeHandler.get_optional_vars(node, context)
if len(vars) != 1:
context.report_error("Unexpected number of vars", node.span)
name = vars[0].id.name
var_span = vars[0].id.span
elif isinstance(node, ast.Assign):
name = node.lhs.id.name
var_span = node.lhs.id.span
else:
raise Exception("Internal Bug")

def setup_buffer_var(data, dtype, shape, span: Span = None):
"""Setup buffer var for a given type."""
buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
context.update_symbol(name, self.buffer_var, node)


@register
class LaunchThread(WithScopeHandler):
"""With scope handler tir.launch_thread(env_var, extent)"""
Expand Down
10 changes: 9 additions & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@
from .expr import Call, CallEffectKind, Let, IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import (
BufferStore,
BufferRealize,
Store,
ProducerStore,
Allocate,
AllocateConst,
AttrStmt,
)
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize
Expand Down
34 changes: 34 additions & 0 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,40 @@ def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
)


@tvm._ffi.register_object("tir.AllocateConst")
class AllocateConst(Stmt):
"""Allocate constant node.

Parameters
----------
buffer_var : Var
The buffer variable.

data : NDarray
The data associated with the constant

dtype : str
The data type of the buffer.

extents : list of Expr
The extents of the allocate

condition : PrimExpr
The condition.

body : Stmt
The body statement.

span : Optional[Span]
The location of this itervar in the source code.
"""

def __init__(self, buffer_var, dtype, extents, condition, body, span=None):
self.__init_handle_by_constructor__(
_ffi_api.AllocateConst, buffer_var, dtype, extents, condition, body, span
)


@tvm._ffi.register_object("tir.AttrStmt")
class AttrStmt(Stmt):
"""AttrStmt node.
Expand Down
2 changes: 2 additions & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const BufferStoreNode* op) override;
Doc VisitStmt_(const BufferRealizeNode* op) override;
Doc VisitStmt_(const AllocateNode* op) override;
Doc VisitStmt_(const AllocateConstNode* op) override;
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const EvaluateNode* op) override;
Expand Down Expand Up @@ -359,6 +360,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
static Doc PrintConstScalar(DataType dtype, const T& data);
Doc GetUniqueName(std::string prefix);
Doc AllocVar(const Var& var);
Doc AllocConst(const AllocateConst& var);
Doc AllocBuf(const Buffer& buffer);
/*!
* \brief special method to render vectors of docs with a separator
Expand Down
13 changes: 13 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,19 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) {
Doc doc;
doc << "constant(" << Print(op->buffer_var) << ", " << PrintDType(op->dtype) << ", "
<< Print(op->extents) << ")";

if (op->body->IsInstance<SeqStmtNode>()) {
doc << PrintBody(op->body);
} else {
doc << ";" << Doc::NewLine() << Print(op->body);
}
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << PrintBody(op->then_case);
Expand Down
63 changes: 63 additions & 0 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const BufferStoreNode* op) override;
Doc VisitStmt_(const BufferRealizeNode* op) override;
Doc VisitStmt_(const AllocateNode* op) override;
Doc VisitStmt_(const AllocateConstNode* op) override;
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Doc VisitStmt_(const ForNode* op) override;
Expand Down Expand Up @@ -247,6 +248,26 @@ class TVMScriptPrinter : public StmtFunctor<Doc(const Stmt&)>,
}
return doc;
}

/*!
* \brief special method to print NDArray in TIR
* \param arr the NDArray to be printed
* \param os the output stream where the NDArray will be printed to
*/
template <typename T>
void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) {
int ndim = arr->ndim;
int tot_dim = 1;
for (int i = 0; i < ndim; i++) {
tot_dim *= arr->shape[i];
}
T* data_ptr = reinterpret_cast<T*>(arr->data);
os << "[";
for (int i = 0; i < tot_dim; i++) {
os << data_ptr[i] << ", ";
}
os << "]";
}
};

Doc TVMScriptPrinter::GetUniqueName(std::string prefix) {
Expand Down Expand Up @@ -685,6 +706,48 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
return Doc();
}

Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
std::stringstream ss;

if (alloc->dtype.is_int()) {
if (alloc->dtype.bits() == 8) {
NDArrayToTIR<int8_t>(alloc->data, ss);
} else if (alloc->dtype.bits() == 16) {
NDArrayToTIR<int16_t>(alloc->data, ss);
} else if (alloc->dtype.bits() == 32) {
NDArrayToTIR<int32_t>(alloc->data, ss);
} else {
LOG(FATAL) << "DataType not supported";
}
} else if (alloc->dtype.is_float()) {
if (alloc->dtype.bits() == 16) {
NDArrayToTIR<int16_t>(alloc->data, ss);
} else if (alloc->dtype.bits() == 32) {
NDArrayToTIR<float>(alloc->data, ss);
} else if (alloc->dtype.bits() == 64) {
NDArrayToTIR<double>(alloc->data, ss);
} else {
LOG(FATAL) << "DataType not supported";
}
} else {
LOG(FATAL) << "DataType not supported";
}
auto ndarray_str = ss.str();

Doc doc;
var_not_in_headers.insert(alloc->buffer_var.get());
if (current_num_ != num_child_ - 1) {
doc << "with tir.allocate_const(" << ndarray_str << ", " << PrintDType(alloc->dtype) << ", "
<< Print(alloc->extents) << ")";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(alloc->body));
} else {
doc << Print(alloc->buffer_var) << " = tir.allocate_const(" << ndarray_str << ", "
<< PrintDType(alloc->dtype) << ", " << Print(alloc->extents);
doc << ")" << Doc::NewLine() << PrintBody(alloc->body);
}
return doc;
}

Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << ":";
Expand Down
Loading