@@ -49,13 +49,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
4949
5050TVM_REGISTER_NODE_TYPE (HybridOpNode);
5151
52- int HybridOpNode::num_outputs () const { return static_cast <int >(outputs .size ()); }
52+ int HybridOpNode::num_outputs () const { return static_cast <int >(symbolic_outputs .size ()); }
5353
5454Array<IterVar> HybridOpNode::root_iter_vars () const { return this ->axis ; }
5555
56- DataType HybridOpNode::output_dtype (size_t i) const { return outputs [i]->dtype ; }
56+ DataType HybridOpNode::output_dtype (size_t i) const { return symbolic_outputs [i]->dtype ; }
5757
58- Array<PrimExpr> HybridOpNode::output_shape (size_t i) const { return outputs [i]->shape ; }
58+ Array<PrimExpr> HybridOpNode::output_shape (size_t i) const { return symbolic_outputs [i]->shape ; }
5959
6060HybridOp::HybridOp (std::string name, std::string tag, Map<String, ObjectRef> attrs,
6161 Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) {
@@ -67,7 +67,7 @@ HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> att
6767 n->tag = std::move (tag);
6868 n->attrs = std::move (attrs);
6969 n->inputs = std::move (inputs);
70- n->outputs = std::move (outputs);
70+ n->symbolic_outputs = std::move (outputs);
7171 n->axis = te::GatherLoopVars (body);
7272 n->body = std::move (body);
7373 data_ = std::move (n);
@@ -104,6 +104,7 @@ Operation HybridOpNode::ReplaceInputs(const Operation& self,
104104 ICHECK_EQ (self.operator ->(), this );
105105 auto n = make_object<HybridOpNode>(*this );
106106 n->body = te::ReplaceTensor (this ->body , rmap);
107+ n->outputs = Array<Tensor>();
107108 for (size_t i = 0 ; i < n->inputs .size (); ++i) {
108109 Tensor t = n->inputs [i];
109110 if (rmap.count (t)) {
@@ -166,7 +167,7 @@ Stmt HybridOpNode::BuildProvide(const Stage& stage,
166167 Stmt ret = AttrStmt (make_zero (DataType::Int (32 )), tir::attr::extern_scope, 0 , this ->body );
167168 std::unordered_map<Tensor, Tensor> rmap;
168169 for (int i = 0 ; i < this ->num_outputs (); ++i) {
169- rmap[outputs [i]] = stage->op .output (i);
170+ rmap[symbolic_outputs [i]] = stage->op .output (i);
170171 }
171172 auto n = make_object<HybridOpNode>(*this );
172173 /* This is a story little bit complicated.
0 commit comments