@@ -540,22 +540,20 @@ class AOTExecutorCodegen : public ExprVisitor {
540540
541541 void VisitExpr_ (const ConstantNode* op) override {
542542 Expr expr = GetRef<Expr>(op);
543- size_t index = params_.size ();
544- std::string name = " p" + std::to_string (index);
545543 StorageInfo& sinfo = storage_device_map_[expr];
546- param_storage_ids_[name] = sinfo->storage_ids [0 ];
547- params_[name] = op->data ;
548- params_by_expr_.Set (expr, name);
544+ std::stringstream ss;
545+ ss << " constant_" << constant_map_.size ();
546+
547+ tir::Var constant (ss.str (), PointerType (PrimType (DataType (op->data ->dtype ))));
548+ constant_map_[constant.operator ->()] = op;
549549
550550 // If the Constant node is an output node we need to copy the content of the parameter to the
551551 // output A Var node can only produce a single output
552552 auto output_iter = std::find (return_sid_.begin (), return_sid_.end (), sinfo->storage_ids [0 ]);
553553 if (output_iter != return_sid_.end ()) {
554554 int output_index = std::distance (return_sid_.begin (), output_iter);
555- auto param_handle = tvm::tir::Call (DataType::Handle (), tvm::tir::builtin::lookup_param (),
556- {tir::StringImm (params_by_expr_[expr])});
557- CopyToOutput (main_signature_[input_vars_.size () + output_index], param_handle, false ,
558- sinfo->storage_sizes_in_bytes [0 ]);
555+ CopyToOutput (main_signature_[input_vars_.size () + output_index], constant,
556+ /* pack_input */ false , sinfo->storage_sizes_in_bytes [0 ]);
559557 }
560558 }
561559
@@ -632,6 +630,20 @@ class AOTExecutorCodegen : public ExprVisitor {
632630 }
633631 }
634632
633+ for (auto kv : constant_map_) {
634+ auto buffer_var = GetRef<tir::Var>(kv.first );
635+ auto dtype = DataType (kv.second ->data ->dtype );
636+
637+ int ndim = kv.second ->data ->ndim ;
638+ Array<PrimExpr> extents;
639+
640+ for (int i = 0 ; i < ndim; i++) {
641+ int shape = kv.second ->data ->shape [i];
642+ extents.push_back (tir::make_const (DataType::Int (32 ), shape));
643+ }
644+ body = tir::AllocateConst (buffer_var, kv.second ->data , dtype, extents, body);
645+ }
646+
635647 // Define the attributes
636648 body = tir::AttrStmt (PrimExpr (), tvm::tir::attr::device_type, 1 , body);
637649 body = tir::AttrStmt (PrimExpr (), tvm::tir::attr::device_id, 0 , body);
@@ -680,6 +692,7 @@ class AOTExecutorCodegen : public ExprVisitor {
680692 Map<Expr, String> params_by_expr_;
681693 /* ! \brief mapping between parameter names ("p0", "p1", etc..) and storage identifiers*/
682694 std::unordered_map<std::string, int64_t > param_storage_ids_;
695+ std::unordered_map<const tir::VarNode*, const ConstantNode*> constant_map_;
683696
684697 /* ! \brief plan memory of device result */
685698 StorageMap storage_device_map_;
@@ -783,6 +796,7 @@ class AOTExecutorCodegen : public ExprVisitor {
783796 } else {
784797 ret.lowered_funcs .Set (target_host_str, mod_run);
785798 }
799+
786800 ret.function_metadata = std::move (function_metadata_);
787801 ret.metadata = runtime::Metadata (input_vars_.size (), return_sid_.size (),
788802 runtime::kTvmExecutorAot , mod_name);
0 commit comments