From 755cc6b8fef1e00de15bebf52f246853b20865da Mon Sep 17 00:00:00 2001 From: "Colin Y. Li" Date: Fri, 31 Dec 2021 14:24:21 +0800 Subject: [PATCH 1/4] [TE][Fix] Comparison of the output tensor --- include/tvm/te/operation.h | 7 ++++--- python/tvm/te/tensor.py | 5 +++-- src/te/operation/scan_op.cc | 1 + src/te/tensor.cc | 22 +++++++++++++++------- tests/python/unittest/test_te_tensor.py | 4 ++++ 5 files changed, 27 insertions(+), 12 deletions(-) diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 13f39317dbe4..9afb11ac3123 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -59,8 +59,11 @@ class TVM_DLL OperationNode : public Object { std::string name; /*! \brief optional tag of the operation */ std::string tag; - /*! \brief additional attributes of the operation*/ + /*! \brief additional attributes of the operation */ Map attrs; + /*! \brief output tensors */ + Array outputs; + // virtual destructor. virtual ~OperationNode() {} /*! \return number of outputs */ @@ -472,8 +475,6 @@ class HybridOpNode : public OperationNode { public: /*! \brief The input tensors */ Array inputs; - /*! \brief Symbolic placeholder representation of outputs */ - Array outputs; /*! \brief The axis of iterations */ Array axis; /*! \brief the statement that generates the computation. This is diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index fc85d830c91a..b5609375b8aa 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -86,11 +86,12 @@ def __eq__(self, other): if isinstance(other, _expr.ExprOp): return _expr.EqualOp(self, other) return False + if self.same_as(other): + return True if self.ndim == 0 and other.ndim == 0: raise ValueError( "Equal == comparison among rank-0 tensor is ambiguous, " - "use Tensor.equal for content expression equvalence, " - "use Tensor.same_as for exact reference comparison" + "use Tensor.equal for content expression equvalence." ) return _ffi_api.TensorEqual(self, other) diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 39689bd9654a..ff809f5495f8 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -40,6 +40,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(ScanOpNode); int ScanOpNode::num_outputs() const { return static_cast(update.size()); } + Array ScanOpNode::root_iter_vars() const { Array ret{scan_axis}; for (IterVar iv : spatial_axis_) { diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 1d75761216f1..7a8bcb590b1e 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -78,13 +78,21 @@ String TensorNode::GetNameHint() const { return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); } -Tensor Operation::output(size_t i) const { - auto node = make_object(); - node->op = *this; - node->value_index = i; - node->dtype = (*this)->output_dtype(i); - node->shape = (*this)->output_shape(i); - return Tensor(node); +Tensor Operation::output(size_t n) const { + if ((*this)->outputs.empty()) { + auto* ptr = static_cast(get_mutable()); + size_t num = static_cast((*this)->num_outputs()); + for (size_t i = 0; i < num; ++i) { + auto node = make_object(); + node->op = *this; + node->value_index = i; + node->dtype = (*this)->output_dtype(i); + node->shape = (*this)->output_shape(i); + ptr->outputs.push_back(Tensor(node)); + } + } + ICHECK_LT(n, (*this)->outputs.size()); + return (*this)->outputs[n]; } Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_index) { diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 6958888e9bb6..6b09410af6c4 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -37,6 +37,8 @@ def test_tensor(): assert T.op.output(0).__hash__() == T.__hash__() d = {T.op.output(0): 1} assert d[T] == 1 + assert T == T.op.output(0) + assert T.same_as(T.op.output(0)) assert T[0][0][0].astype("float16").dtype == "float16" @@ -49,6 +51,8 @@ def test_rank_zero(): print(T) print(T.op.body) assert tuple(T.shape) == () + assert T == T.op.output(0) + assert T.same_as(T.op.output(0)) def test_conv1d(): From 9a2ca41f16b8a6767c1dcedb2ae78f3793a7240d Mon Sep 17 00:00:00 2001 From: "Colin Y. Li" Date: Thu, 17 Feb 2022 16:07:30 +0800 Subject: [PATCH 2/4] fix hybrid op issue --- include/tvm/te/operation.h | 4 +++- src/te/operation/hybrid_op.cc | 10 +++++----- src/te/schedule/schedule_dataflow_rewrite.cc | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 9afb11ac3123..06fd4ab94922 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -475,6 +475,8 @@ class HybridOpNode : public OperationNode { public: /*! \brief The input tensors */ Array inputs; + /*! \brief Symbolic placeholder representation of outputs */ + Array symbolic_outputs; /*! \brief The axis of iterations */ Array axis; /*! \brief the statement that generates the computation. This is @@ -509,7 +511,7 @@ class HybridOpNode : public OperationNode { v->Visit("tag", &tag); v->Visit("attrs", &attrs); v->Visit("inputs", &inputs); - v->Visit("outputs", &outputs); + v->Visit("symbolic_outputs", &symbolic_outputs); v->Visit("axis", &axis); v->Visit("body", &body); } diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 5d2412abb3d2..85edcd603c84 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -49,13 +49,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(HybridOpNode); -int HybridOpNode::num_outputs() const { return static_cast(outputs.size()); } +int HybridOpNode::num_outputs() const { return static_cast(symbolic_outputs.size()); } Array HybridOpNode::root_iter_vars() const { return this->axis; } -DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; } +DataType HybridOpNode::output_dtype(size_t i) const { return symbolic_outputs[i]->dtype; } -Array HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } +Array HybridOpNode::output_shape(size_t i) const { return symbolic_outputs[i]->shape; } HybridOp::HybridOp(std::string name, std::string tag, Map attrs, Array inputs, Array outputs, Stmt body) { @@ -67,7 +67,7 @@ HybridOp::HybridOp(std::string name, std::string tag, Map att n->tag = std::move(tag); n->attrs = std::move(attrs); n->inputs = std::move(inputs); - n->outputs = std::move(outputs); + n->symbolic_outputs = std::move(outputs); n->axis = te::GatherLoopVars(body); n->body = std::move(body); data_ = std::move(n); @@ -166,7 +166,7 @@ Stmt HybridOpNode::BuildProvide(const Stage& stage, Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); std::unordered_map rmap; for (int i = 0; i < this->num_outputs(); ++i) { - rmap[outputs[i]] = stage->op.output(i); + rmap[symbolic_outputs[i]] = stage->op.output(i); } auto n = make_object(*this); /* This is a story little bit complicated. diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index fae826b926e3..e09cdfe146e1 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -616,7 +616,7 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { const HybridOpNode* hybrid = sch->stages[i]->op.as(); ICHECK(hybrid); Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, - hybrid->outputs, new_hybrid_body[i]); + hybrid->symbolic_outputs, new_hybrid_body[i]); op = op->ReplaceInputs(op, repl); for (int idx = 0; idx < s->op->num_outputs(); ++idx) { repl[s->op.output(idx)] = op.output(idx); From 0897cf2acc261e6d469d0409f3157b4c04dbe12e Mon Sep 17 00:00:00 2001 From: "Colin Y. Li" Date: Fri, 18 Feb 2022 00:33:41 +0800 Subject: [PATCH 3/4] fix tensor replacement in schedule ops --- src/te/schedule/schedule_ops.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 1568df4670af..99e02ccaf943 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -298,9 +298,15 @@ class SchedulePostProc : public StmtExprMutator { private: void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(), Operation repl_op = Operation()) { - replace_buffer_[src] = dst; - replace_realize_[src] = repl_realize; - replace_op_[src->op.get()] = repl_op; + if (!src.same_as(dst)) { + replace_buffer_[src] = dst; + } + if (!src.same_as(repl_realize)) { + replace_realize_[src] = repl_realize; + } + if (!src->op.same_as(repl_op)) { + replace_op_[src->op.get()] = repl_op; + } } // The thread extent scope. std::unordered_map thread_extent_scope_; From d4f2d3d94179b8693fd6e0be0cc78075f35ef6d3 Mon Sep 17 00:00:00 2001 From: "Colin Y. Li" Date: Sat, 19 Feb 2022 00:48:42 +0800 Subject: [PATCH 4/4] fix compute inline --- include/tvm/te/operation.h | 1 + python/tvm/te/hybrid/parser.py | 2 +- src/te/operation/extern_op.cc | 1 + src/te/operation/hybrid_op.cc | 1 + src/te/operation/scan_op.cc | 1 + src/te/operation/tensor_compute_op.cc | 1 + src/te/tensor.cc | 1 + 7 files changed, 7 insertions(+), 1 deletion(-) diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 06fd4ab94922..89074d83e1d6 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -511,6 +511,7 @@ class HybridOpNode : public OperationNode { v->Visit("tag", &tag); v->Visit("attrs", &attrs); v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); v->Visit("symbolic_outputs", &symbolic_outputs); v->Visit("axis", &axis); v->Visit("body", &body); diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 442aeb6f1027..655748783dd1 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -312,7 +312,7 @@ def visit_Assign(self, node): "You should bind a pure name to the tensors", ) self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i)) - rmap[rhs.outputs[i].op] = rhs.output(i) + rmap[rhs.symbolic_outputs[i].op] = rhs.output(i) return utils.replace_io(rhs.body, rmap) _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index b602efcfc28b..84869ccd7775 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -90,6 +90,7 @@ Operation ExternOpNode::ReplaceInputs(const Operation& self, ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = ReplaceTensor(this->body, rmap); + n->outputs = Array(); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; if (rmap.count(t)) { diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 85edcd603c84..5bb645822996 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -104,6 +104,7 @@ Operation HybridOpNode::ReplaceInputs(const Operation& self, ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = te::ReplaceTensor(this->body, rmap); + n->outputs = Array(); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; if (rmap.count(t)) { diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index ff809f5495f8..a29cc6601795 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -144,6 +144,7 @@ Operation ScanOpNode::ReplaceInputs(const Operation& self, const std::unordered_map& rmap) const { ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); + n->outputs = Array(); for (size_t i = 0; i < n->init.size(); ++i) { if (rmap.count(n->init[i])) { n->init.Set(i, rmap.at(n->init[i])); diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 262e5a2b97f4..432df75f03b9 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -85,6 +85,7 @@ Operation TensorComputeOpNode::ReplaceInputs(const Operation& self, const std::unordered_map& rmap) const { ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); + n->outputs = Array(); auto intrin = make_object(*(this->intrin.operator->())); intrin->body = ReplaceTensor(this->intrin->body, rmap); if (intrin->reduce_init.defined()) { diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 7a8bcb590b1e..1f43714ea107 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -79,6 +79,7 @@ String TensorNode::GetNameHint() const { } Tensor Operation::output(size_t n) const { + // cache the output tensors if empty if ((*this)->outputs.empty()) { auto* ptr = static_cast(get_mutable()); size_t num = static_cast((*this)->num_outputs());