Skip to content

Commit 6c88c9d

Browse files
committed
[TIR] tir.const extraction
This commit implements an amendment to the tir.constant RFC with centralized storage of constant data within the IRModule. Please note that data and irmod_storage_idx are not mutually exclusive; furthermore, irmod_storage_idx is valid only immediately after the prim func is added to the module or after an update within the module. If the prim func is outside the module scope then the index becomes meaningless. irmod_storage_idx is also not used in the calculation of the hash function of the tir.constant node. Change-Id: I40742ed580468b0252ea3fec02184cba65e20871
1 parent 1f5d783 commit 6c88c9d

File tree

13 files changed

+212
-81
lines changed

13 files changed

+212
-81
lines changed

include/tvm/ir/module.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <tvm/runtime/container/array.h>
3333
#include <tvm/runtime/container/map.h>
3434
#include <tvm/runtime/container/string.h>
35+
#include <tvm/tir/function.h>
3536

3637
#include <string>
3738
#include <unordered_map>
@@ -343,6 +344,9 @@ class IRModuleNode : public Object {
343344
*/
344345
std::unordered_set<String> import_set_;
345346
friend class IRModule;
347+
348+
public:
349+
void ExtractPrimFuncConstants(tir::PrimFunc func);
346350
};
347351

348352
/*!
@@ -351,6 +355,8 @@ class IRModuleNode : public Object {
351355
*/
352356
class IRModule : public ObjectRef {
353357
public:
358+
static constexpr const char* _constants_attrs_key = "Constants";
359+
354360
/*!
355361
* \brief constructor
356362
* \param functions Functions in the module.

include/tvm/tir/stmt.h

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -628,8 +628,14 @@ class AllocateConstNode : public StmtNode {
628628
public:
629629
/*! \brief The buffer variable. */
630630
Var buffer_var;
631-
/*! \brief The data associated to the constant. */
632-
::tvm::runtime::NDArray data;
631+
/*! \brief The optional data associated to the constant.
632+
*/
633+
Optional<runtime::NDArray> data;
634+
/*! \brief If the PrimFunc containing the Stmt is added to IRModule,
635+
this is an optional index to indicate the index within
636+
"Constants" attribute, that is a Array<NDArray> of IRModule.
637+
*/
638+
Optional<Integer> irmod_storage_idx;
633639
/*! \brief The type of the buffer. */
634640
DataType dtype;
635641
/*! \brief The extents of the buffer. */
@@ -677,14 +683,17 @@ class AllocateConstNode : public StmtNode {
677683
};
678684

679685
/*!
680-
* \brief Managed reference to AllocateNode.
681-
* \sa AllocateNode
686+
* \brief Managed reference to AllocateConstNode.
687+
* \sa AllocateConstNode
682688
*/
683689
class AllocateConst : public Stmt {
684690
public:
685-
TVM_DLL AllocateConst(Var buffer_var, ::tvm::runtime::NDArray data, DataType dtype,
686-
Array<PrimExpr> extents, Stmt body, Span span = Span());
687-
691+
/* The constructor to create a IRNode with constant data
692+
* depending on the type of ObjectRef, it will either
693+
* create AllocateConstNode with irmod_storage_idx or data
694+
*/
695+
TVM_DLL AllocateConst(Var buffer_var, DataType dtype, Array<PrimExpr> extents,
696+
ObjectRef data_or_idx, Stmt body, Span span = Span());
688697
TVM_DEFINE_OBJECT_REF_METHODS(AllocateConst, Stmt, AllocateConstNode);
689698
};
690699

python/tvm/script/tir/scope_handler.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
# pylint: disable=redefined-builtin, unused-argument, invalid-name, relative-beyond-top-level
1919
from typing import Tuple, Any, Callable, Optional, List, Union, Mapping
2020

21+
import numpy as np
2122
import synr
2223
import tvm.tir
2324
from tvm.runtime import Object
2425
from tvm.ir import Span, Range
2526
from tvm.tir import Stmt, PrimExpr, IterVar, Var, Buffer, BufferRegion, ForKind
26-
import numpy as np
2727

2828
from .node import BufferSlice
2929
from .utils import buffer_slice_to_region
@@ -159,16 +159,15 @@ def setup_buffer_var(
159159

160160
@register
161161
class AllocateConst(WithScopeHandler):
162-
"""With scope handler tir.allocate(data, extents, dtype, condition)"""
162+
"""With scope handler tir.allocate_const(data, extents, dtype, condition)"""
163163

164164
def __init__(self):
165165
def allocate_const(raw_data, dtype, shape, span=None):
166166
list_data = []
167167
for i in raw_data:
168168
list_data.append(i.value)
169169
nd_data = tvm.nd.array(np.asarray(list_data, dtype=dtype))
170-
171-
n = tvm.tir.AllocateConst(self.buffer_var, nd_data, dtype, shape, self.body, span=span)
170+
n = tvm.tir.AllocateConst(self.buffer_var, dtype, shape, nd_data, self.body, span=span)
172171
return n
173172

174173
super().__init__(allocate_const, concise_scope=True, def_symbol=True)
@@ -182,15 +181,17 @@ def enter_scope(
182181
span: synr.ast.Span,
183182
):
184183
# define buffer vars in symbol table
185-
if isinstance(node, ast.With):
184+
if isinstance(node, synr.ast.With):
186185
vars = WithScopeHandler.get_optional_vars(node, context)
187186
if len(vars) != 1:
188-
context.report_error("Unexpected number of vars", node.span)
187+
context.report_error(f"Unexpected number of vars: 1 vs. {len(vars)}", node.span)
189188
name = vars[0].id.name
190189
var_span = vars[0].id.span
191-
elif isinstance(node, ast.Assign):
192-
name = node.lhs.id.name
193-
var_span = node.lhs.id.span
190+
elif isinstance(node, synr.ast.Assign):
191+
if len(node.lhs) != 1:
192+
context.report_error(f"Unexpected number of vars: 1 vs. {len(node.lhs)}", node.span)
193+
name = node.lhs[0].id.name
194+
var_span = node.lhs[0].id.span
194195
else:
195196
raise Exception("Internal Bug")
196197

@@ -214,11 +215,7 @@ def launch_thread(env_var, extent, span):
214215
attr_key = "virtual_thread" if thread_id == "vthread" else "thread_extent"
215216
return tvm.tir.AttrStmt(
216217
IterVar(
217-
(0, extent),
218-
env_var,
219-
getattr(IterVar, "ThreadIndex"),
220-
thread_id,
221-
span=span,
218+
(0, extent), env_var, getattr(IterVar, "ThreadIndex"), thread_id, span=span,
222219
),
223220
attr_key,
224221
extent,
@@ -545,9 +542,7 @@ class Serial(ForScopeHandler):
545542

546543
def __init__(self):
547544
def serial(
548-
begin: PrimExpr,
549-
end: PrimExpr,
550-
annotations: Optional[Mapping[str, Object]] = None,
545+
begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None,
551546
):
552547
self.create_loop_info(begin, end, ForKind.SERIAL, annotations=annotations)
553548

@@ -560,9 +555,7 @@ class Parallel(ForScopeHandler):
560555

561556
def __init__(self):
562557
def parallel(
563-
begin: PrimExpr,
564-
end: PrimExpr,
565-
annotations: Optional[Mapping[str, Object]] = None,
558+
begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None,
566559
):
567560
self.create_loop_info(begin, end, ForKind.PARALLEL, annotations=annotations)
568561

@@ -575,9 +568,7 @@ class Vectorized(ForScopeHandler):
575568

576569
def __init__(self):
577570
def vectorized(
578-
begin: PrimExpr,
579-
end: PrimExpr,
580-
annotations: Optional[Mapping[str, Object]] = None,
571+
begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None,
581572
):
582573
self.create_loop_info(begin, end, ForKind.VECTORIZED, annotations=annotations)
583574

@@ -590,9 +581,7 @@ class Unroll(ForScopeHandler):
590581

591582
def __init__(self):
592583
def unroll(
593-
begin: PrimExpr,
594-
end: PrimExpr,
595-
annotations: Optional[Mapping[str, Object]] = None,
584+
begin: PrimExpr, end: PrimExpr, annotations: Optional[Mapping[str, Object]] = None,
596585
):
597586
self.create_loop_info(begin, end, ForKind.UNROLLED, annotations=annotations)
598587

src/ir/module.cc

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,20 @@
3434
#include <tvm/relay/analysis.h>
3535
#include <tvm/relay/expr_functor.h>
3636
#include <tvm/relay/transform.h>
37+
#include <tvm/tir/stmt.h>
38+
#include <tvm/tir/stmt_functor.h>
3739

3840
#include <fstream>
3941
#include <sstream>
4042
#include <unordered_set>
4143

4244
namespace tvm {
43-
45+
constexpr const char* IRModule::_constants_attrs_key;
4446
IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
4547
tvm::Map<GlobalTypeVar, TypeData> type_definitions,
4648
std::unordered_set<String> import_set, parser::SourceMap source_map,
4749
DictAttrs attrs) {
4850
auto n = make_object<IRModuleNode>();
49-
n->functions = std::move(functions);
5051
n->type_definitions = std::move(type_definitions);
5152
n->global_type_var_map_ = {};
5253
n->global_var_map_ = {};
@@ -55,11 +56,10 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
5556
n->source_map = source_map;
5657
n->attrs = std::move(attrs);
5758

58-
for (const auto& kv : n->functions) {
59-
// set global var map
60-
ICHECK(n->global_var_map_.count(kv.first->name_hint) == 0)
61-
<< "Duplicate global function name " << kv.first->name_hint;
62-
n->global_var_map_.Set(kv.first->name_hint, kv.first);
59+
if (functions.defined()) {
60+
for (const auto& kv : functions) {
61+
n->Add(kv.first, kv.second);
62+
}
6363
}
6464

6565
for (const auto& kv : n->type_definitions) {
@@ -201,6 +201,10 @@ void IRModuleNode::Add(const GlobalVar& var, const BaseFunc& f, bool update) {
201201
WarnIfMalformed(GetRef<IRModule>(this), GetRef<relay::Function>(ptr));
202202
}
203203

204+
if (f->IsInstance<tir::PrimFuncNode>()) {
205+
ExtractPrimFuncConstants(Downcast<tir::PrimFunc>(f));
206+
}
207+
204208
AddUnchecked(var, checked_func);
205209
}
206210

@@ -218,6 +222,63 @@ void IRModuleNode::AddUnchecked(const GlobalVar& var, const BaseFunc& func) {
218222
global_var_map_.Set(var->name_hint, var);
219223
}
220224

225+
// Replaces constant data to index into mod's "Constants" attrs array.
226+
void IRModuleNode::ExtractPrimFuncConstants(tir::PrimFunc func) {
227+
using ConstArrayType = Array<runtime::NDArray>;
228+
class Applicator : public tir::StmtExprVisitor {
229+
protected:
230+
// returns index of the a in constant_array_, if not found - appends
231+
// TODO(@d-smirnov): make real content comparision with already existing NDArrays
232+
// instead of reference comparision
233+
size_t deDup(const runtime::NDArray& a) {
234+
auto it = std::find(constant_array_.begin(), constant_array_.end(), a);
235+
if (it != constant_array_.end()) {
236+
return it - constant_array_.begin();
237+
}
238+
constant_array_.push_back(std::move(a));
239+
return constant_array_.size() - 1;
240+
}
241+
242+
public:
243+
ConstArrayType Apply(tir::Stmt body, const ConstArrayType& constant_array) {
244+
constant_array_ = constant_array;
245+
this->VisitStmt(body);
246+
return constant_array_;
247+
}
248+
249+
void VisitStmt_(const tir::AllocateConstNode* acn) override {
250+
tir::AllocateConstNode* node = const_cast<tir::AllocateConstNode*>(acn);
251+
// Check whether the data already defined within the module's attrs
252+
// and replace it with array index;
253+
ICHECK(node->data) << "data field should be defined";
254+
if (node->data) {
255+
node->irmod_storage_idx = Optional<Integer>(Integer(deDup(node->data.value())));
256+
}
257+
tir::StmtExprVisitor::VisitStmt_(acn);
258+
}
259+
260+
private:
261+
ConstArrayType constant_array_;
262+
};
263+
264+
std::pair<const char*, const ObjectRef> default_value = {IRModule::_constants_attrs_key,
265+
Array<runtime::NDArray>()};
266+
if (attrs.defined()) {
267+
if (!attrs->dict.count(IRModule::_constants_attrs_key))
268+
attrs.CopyOnWrite()->dict.Set(default_value.first, default_value.second);
269+
} else {
270+
Map<String, ObjectRef> dict = {default_value};
271+
attrs = DictAttrs(dict);
272+
}
273+
274+
ConstArrayType constant_array_ =
275+
Downcast<ConstArrayType>(attrs->dict[IRModule::_constants_attrs_key]);
276+
277+
const ConstArrayType constant_list =
278+
Applicator().Apply(func.CopyOnWrite()->body, constant_array_);
279+
attrs.CopyOnWrite()->dict.Set(IRModule::_constants_attrs_key, constant_list);
280+
}
281+
221282
void IRModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
222283
// We hash the global type var name to use as a globally unique prefix for tags.
223284
// The hash will be used as the most significant byte of the tag, with the index of

src/printer/tvmscript_printer.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -899,24 +899,26 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) {
899899

900900
Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
901901
std::stringstream ss;
902+
ICHECK(alloc->data) << "Should be presented";
903+
const auto& data = alloc->data.value();
902904

903905
if (alloc->dtype.is_int()) {
904906
if (alloc->dtype.bits() == 8) {
905-
NDArrayToTIR<int8_t>(alloc->data, ss);
907+
NDArrayToTIR<int8_t>(data, ss);
906908
} else if (alloc->dtype.bits() == 16) {
907-
NDArrayToTIR<int16_t>(alloc->data, ss);
909+
NDArrayToTIR<int16_t>(data, ss);
908910
} else if (alloc->dtype.bits() == 32) {
909-
NDArrayToTIR<int32_t>(alloc->data, ss);
911+
NDArrayToTIR<int32_t>(data, ss);
910912
} else {
911913
LOG(FATAL) << "DataType not supported";
912914
}
913915
} else if (alloc->dtype.is_float()) {
914916
if (alloc->dtype.bits() == 16) {
915-
NDArrayToTIR<int16_t>(alloc->data, ss);
917+
NDArrayToTIR<int16_t>(data, ss);
916918
} else if (alloc->dtype.bits() == 32) {
917-
NDArrayToTIR<float>(alloc->data, ss);
919+
NDArrayToTIR<float>(data, ss);
918920
} else if (alloc->dtype.bits() == 64) {
919-
NDArrayToTIR<double>(alloc->data, ss);
921+
NDArrayToTIR<double>(data, ss);
920922
} else {
921923
LOG(FATAL) << "DataType not supported";
922924
}

src/relay/backend/aot_executor_codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
631631
int shape = kv.second->data->shape[i];
632632
extents.push_back(tir::make_const(DataType::Int(32), shape));
633633
}
634-
body = tir::AllocateConst(buffer_var, kv.second->data, dtype, extents, body);
634+
body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body);
635635
}
636636

637637
// Define the attributes

src/target/llvm/codegen_llvm.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1434,7 +1434,8 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
14341434
}
14351435

14361436
void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) {
1437-
auto array = NDArrayToLLVMArray(ctx_, op->data);
1437+
auto data = op->data.value();
1438+
auto array = NDArrayToLLVMArray(ctx_, data);
14381439
std::string symbol_name = op->buffer_var->name_hint;
14391440
llvm::GlobalVariable* param_symbol = new llvm::GlobalVariable(
14401441
*module_, array->getType(), true, llvm::GlobalValue::InternalLinkage, array, symbol_name);

src/target/source/codegen_c.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,8 @@ void CodeGenC::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs,
705705
void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
706706
std::string symbol_name = op->buffer_var->name_hint;
707707
int64_t num_elements = 1;
708-
for (int64_t dim : op->data.Shape()) {
708+
const auto& data = op->data.value();
709+
for (int64_t dim : data.Shape()) {
709710
num_elements *= dim;
710711
}
711712

@@ -715,11 +716,11 @@ void CodeGenC::VisitStmt_(const AllocateConstNode* op) {
715716
<< "#endif\n"
716717
<< "static const ";
717718

718-
PrintType(op->data.DataType(), decl_stream);
719+
PrintType(data.DataType(), decl_stream);
719720

720721
// Allocate the global static variable
721722
decl_stream << " " << symbol_name << "[" << num_elements << "] = {\n";
722-
NDArrayDataToC(op->data, 4, decl_stream);
723+
NDArrayDataToC(data, 4, decl_stream);
723724

724725
decl_stream << "};\n"
725726
<< "#ifdef __cplusplus\n"

0 commit comments

Comments
 (0)