Skip to content

Commit 5f1da9c

Browse files
Giuseppe Rossinigiuseros
authored andcommitted
Integrate tir constant nodes in compilation pipeline
This PR integrates tir.allocate_const to the compilation pipeline to support --link-params. Change-Id: Ic8d0cb75d596299fcae7078b304598afbf0c5494 Co-authored-by: Giuseppe Rossini <[email protected]> Change-Id: Id98cc682bbfacfe75c4d8b260fd41658f1f196b2
1 parent 7458e2d commit 5f1da9c

23 files changed

+416
-183
lines changed

include/tvm/tir/function.h

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -151,42 +151,6 @@ class PrimFunc : public BaseFunc {
151151
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
152152
};
153153

154-
/*!
155-
* \brief Describes one parameter that should be linked into the generated module.
156-
*
157-
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
158-
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
159-
* use the information contained in this node to include the parameter data in the generated
160-
* module.
161-
*/
162-
class LinkedParamNode : public Object {
163-
public:
164-
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
165-
int64_t id;
166-
167-
/*! \brief Parameter data which should get linked into the final module. */
168-
::tvm::runtime::NDArray param;
169-
170-
void VisitAttrs(tvm::AttrVisitor* v) {
171-
v->Visit("id", &id);
172-
v->Visit("param", &param);
173-
}
174-
175-
static constexpr const char* _type_key = "tir.LinkedParam";
176-
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
177-
};
178-
179-
/*!
180-
* \brief Managed reference to LinkedParamNode.
181-
*/
182-
class LinkedParam : public ObjectRef {
183-
public:
184-
TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);
185-
186-
TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
187-
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
188-
};
189-
190154
/*!
191155
* \brief Specialize parameters of PrimFunc.
192156
* \param func The PrimFunc to be specialized.

include/tvm/tir/stmt.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,42 @@ class Allocate : public Stmt {
575575
TVM_DEFINE_OBJECT_REF_METHODS(Allocate, Stmt, AllocateNode);
576576
};
577577

578+
/*!
579+
* \brief Describes one parameter that should be linked into the generated module.
580+
*
581+
* When parameters are to be linked in with generated code (i.e. on target_host-compatible
582+
* backends), Relay attaches instances of this object to a global TIR function. Code-generators
583+
* use the information contained in this node to include the parameter data in the generated
584+
* module.
585+
*/
586+
class LinkedParamNode : public Object {
587+
public:
588+
/*! \brief Unique numeric identifier used by runtimes to lookup this parameter. */
589+
int64_t id;
590+
591+
/*! \brief Parameter data which should get linked into the final module. */
592+
::tvm::runtime::NDArray param;
593+
594+
void VisitAttrs(tvm::AttrVisitor* v) {
595+
v->Visit("id", &id);
596+
v->Visit("param", &param);
597+
}
598+
599+
static constexpr const char* _type_key = "tir.LinkedParam";
600+
TVM_DECLARE_FINAL_OBJECT_INFO(LinkedParamNode, Object);
601+
};
602+
603+
/*!
604+
* \brief Managed reference to LinkedParamNode.
605+
*/
606+
class LinkedParam : public ObjectRef {
607+
public:
608+
TVM_DLL LinkedParam(int64_t id, ::tvm::runtime::NDArray param);
609+
610+
TVM_DEFINE_OBJECT_REF_METHODS(LinkedParam, ObjectRef, LinkedParamNode);
611+
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
612+
};
613+
578614
/*!
579615
* \brief Allocate a buffer that can be used in body.
580616
*/

include/tvm/tir/transform.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
#define TVM_TIR_TRANSFORM_H_
2626

2727
#include <tvm/ir/transform.h>
28+
#include <tvm/relay/expr.h>
2829
#include <tvm/tir/expr.h>
2930
#include <tvm/tir/function.h>
3031

3132
#include <string>
33+
#include <vector>
3234

3335
namespace tvm {
3436
namespace tir {
@@ -431,6 +433,8 @@ TVM_DLL Pass LegalizePackedCalls();
431433
*/
432434
TVM_DLL Pass FlattenBuffer();
433435

436+
TVM_DLL Pass BindParams(const std::vector<const relay::ConstantNode*>& constants);
437+
434438
} // namespace transform
435439
} // namespace tir
436440
} // namespace tvm

python/tvm/te/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from .tensor import TensorSlice, Tensor
3232
from .tensor_intrin import decl_tensor_intrin
3333
from .tag import tag_scope
34-
from .operation import placeholder, compute, scan, extern, var, size_var
34+
from .operation import placeholder, compute, scan, extern, var, size_var, const
3535
from .operation import thread_axis, reduce_axis
3636
from .operation import create_prim_func
3737

python/tvm/te/operation.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,28 @@ def var(name="tindex", dtype="int32", span=None):
351351
return tvm.tir.Var(name, dtype, span)
352352

353353

354+
def const(name="tindex", dtype="int32", span=None):
355+
"""Create a new constant with specified name and dtype
356+
357+
Parameters
358+
----------
359+
name : str
360+
The name
361+
362+
dtype : str
363+
The data type
364+
365+
span : Optional[Span]
366+
The location of this variable in the source.
367+
368+
Returns
369+
-------
370+
var : Var
371+
The result symbolic variable.
372+
"""
373+
return tvm.tir.Const(name, dtype, span)
374+
375+
354376
def size_var(name="size", dtype="int32", span=None):
355377
"""Create a new variable represents a tensor shape size, which is non-negative.
356378

src/relay/backend/aot_executor_codegen.cc

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -540,22 +540,20 @@ class AOTExecutorCodegen : public ExprVisitor {
540540

541541
void VisitExpr_(const ConstantNode* op) override {
542542
Expr expr = GetRef<Expr>(op);
543-
size_t index = params_.size();
544-
std::string name = "p" + std::to_string(index);
545543
StorageInfo& sinfo = storage_device_map_[expr];
546-
param_storage_ids_[name] = sinfo->storage_ids[0];
547-
params_[name] = op->data;
548-
params_by_expr_.Set(expr, name);
544+
std::stringstream ss;
545+
ss << "constant_" << constant_map_.size();
546+
547+
tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype))));
548+
constant_map_[constant.operator->()] = op;
549549

550550
// If the Constant node is an output node we need to copy the content of the parameter to the
551551
// output A Var node can only produce a single output
552552
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
553553
if (output_iter != return_sid_.end()) {
554554
int output_index = std::distance(return_sid_.begin(), output_iter);
555-
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
556-
{tir::StringImm(params_by_expr_[expr])});
557-
CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, false,
558-
sinfo->storage_sizes_in_bytes[0]);
555+
CopyToOutput(main_signature_[input_vars_.size() + output_index], constant,
556+
/* pack_input */ false, sinfo->storage_sizes_in_bytes[0]);
559557
}
560558
}
561559

@@ -632,6 +630,20 @@ class AOTExecutorCodegen : public ExprVisitor {
632630
}
633631
}
634632

633+
for (auto kv : constant_map_) {
634+
auto buffer_var = GetRef<tir::Var>(kv.first);
635+
auto dtype = DataType(kv.second->data->dtype);
636+
637+
int ndim = kv.second->data->ndim;
638+
Array<PrimExpr> extents;
639+
640+
for (int i = 0; i < ndim; i++) {
641+
int shape = kv.second->data->shape[i];
642+
extents.push_back(tir::make_const(DataType::Int(32), shape));
643+
}
644+
body = tir::AllocateConst(buffer_var, kv.second->data, dtype, extents, body);
645+
}
646+
635647
// Define the attributes
636648
body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_type, 1, body);
637649
body = tir::AttrStmt(PrimExpr(), tvm::tir::attr::device_id, 0, body);
@@ -680,6 +692,7 @@ class AOTExecutorCodegen : public ExprVisitor {
680692
Map<Expr, String> params_by_expr_;
681693
/*! \brief mapping between parameter names ("p0", "p1", etc..) and storage identifiers*/
682694
std::unordered_map<std::string, int64_t> param_storage_ids_;
695+
std::unordered_map<const tir::VarNode*, const ConstantNode*> constant_map_;
683696

684697
/*! \brief plan memory of device result */
685698
StorageMap storage_device_map_;
@@ -783,6 +796,7 @@ class AOTExecutorCodegen : public ExprVisitor {
783796
} else {
784797
ret.lowered_funcs.Set(target_host_str, mod_run);
785798
}
799+
786800
ret.function_metadata = std::move(function_metadata_);
787801
ret.metadata = runtime::Metadata(input_vars_.size(), return_sid_.size(),
788802
runtime::kTvmExecutorAot, mod_name);

src/relay/backend/build_module.cc

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,13 @@ class RelayBuildModule : public runtime::ModuleNode {
392392
}
393393

394394
// Fuse the operations if it is needed.
395-
relay_module = transform::FuseOps()(relay_module);
395+
if (targets.size() == 1) {
396+
const auto& it = targets.begin();
397+
With<Target> tctx((*it).second);
398+
relay_module = transform::FuseOps()(relay_module);
399+
} else {
400+
relay_module = transform::FuseOps()(relay_module);
401+
}
396402

397403
// Do layout rewrite for auto-scheduler.
398404
if (backend::IsAutoSchedulerEnabled() && targets.size() == 1) {
@@ -540,28 +546,6 @@ class RelayBuildModule : public runtime::ModuleNode {
540546

541547
auto lowered_funcs = executor_codegen_->GetIRModule();
542548

543-
// Generate a placeholder function that attaches linked params as its arguments.
544-
if (target_host->GetAttr<Bool>("link-params").value_or(Bool(false))) {
545-
CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen.";
546-
auto param_ids = executor_codegen_->GetParamIds();
547-
auto link_params = Map<String, tir::LinkedParam>();
548-
for (auto param : ret_.params) {
549-
link_params.Set(param.first, tir::LinkedParam(param_ids[param.first], param.second));
550-
}
551-
552-
Map<String, ObjectRef> dict;
553-
dict.Set(tvm::tir::attr::kLinkedParams, link_params);
554-
dict.Set(tvm::attr::kGlobalSymbol, String(::tvm::runtime::symbol::tvm_lookup_linked_param));
555-
DictAttrs attrs{dict};
556-
auto prim = tir::PrimFunc(Array<tir::Var>(), tir::SeqStmt(Array<tir::Stmt>()), VoidType(),
557-
Map<tir::Var, tir::Buffer>(), attrs);
558-
if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) {
559-
lowered_funcs.Set(target_host->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
560-
}
561-
lowered_funcs[target_host->str()]->Add(
562-
GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim);
563-
}
564-
565549
// When there is no lowered_funcs due to reasons such as optimization.
566550
if (lowered_funcs.size() == 0) {
567551
if (target_host.defined() && target_host->kind->name == "llvm") {

src/relay/backend/compile_engine.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <tvm/te/operation.h>
3636
#include <tvm/te/schedule.h>
3737
#include <tvm/te/schedule_pass.h>
38+
#include <tvm/tir/transform.h>
3839
#include <tvm/topi/tags.h>
3940

4041
#include <functional>
@@ -221,10 +222,17 @@ class CompileEngineImpl : public CompileEngineNode {
221222
for (te::Tensor arg : cfunc->outputs) {
222223
all_args.push_back(arg);
223224
}
225+
std::vector<const ConstantNode*> all_consts;
226+
for (auto kv : cfunc->constant_tensors) {
227+
all_args.push_back(kv.second);
228+
all_consts.push_back(kv.first);
229+
}
230+
224231
// lower the function
225232
std::unordered_map<te::Tensor, tir::Buffer> binds;
226233
auto func_name = cfunc->prim_fn_var->name_hint;
227234
cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds));
235+
cfunc->funcs->Update(tir::transform::BindParams(all_consts)(cfunc->funcs));
228236
value->cached_func = cfunc;
229237

230238
return value;

src/relay/backend/te_compiler.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <tvm/te/operation.h>
3434
#include <tvm/te/schedule.h>
3535
#include <tvm/te/schedule_pass.h>
36+
#include <tvm/tir/transform.h>
3637
#include <tvm/topi/tags.h>
3738

3839
#include <functional>
@@ -222,10 +223,16 @@ class TECompilerImpl : public TECompilerNode {
222223
for (te::Tensor arg : cfunc->outputs) {
223224
all_args.push_back(arg);
224225
}
226+
std::vector<const ConstantNode*> all_consts;
227+
for (auto kv : cfunc->constant_tensors) {
228+
all_args.push_back(kv.second);
229+
all_consts.push_back(kv.first);
230+
}
225231

226232
std::unordered_map<te::Tensor, tir::Buffer> binds;
227233
auto func_name = cfunc->prim_fn_var->name_hint;
228234
cfunc->funcs->Update(tvm::LowerSchedule(cfunc->schedule, all_args, func_name, binds));
235+
cfunc->funcs->Update(tir::transform::BindParams(all_consts)(cfunc->funcs));
229236
value->cached_func = cfunc;
230237
return value;
231238
}

0 commit comments

Comments
 (0)