From bfc986d596de0eba2e5521e7ffd5e4992372b80b Mon Sep 17 00:00:00 2001 From: AndrewZhaoLuo Date: Tue, 8 Mar 2022 16:08:54 -0800 Subject: [PATCH] Revert "[TE][Fix] Comparison of the output tensor (#9829)" This reverts commit 73cf51b8246f1f0b5c0bff6fa206d154e053199b. --- include/tvm/te/operation.h | 8 ++----- python/tvm/te/hybrid/parser.py | 2 +- python/tvm/te/tensor.py | 5 ++--- src/te/operation/extern_op.cc | 1 - src/te/operation/hybrid_op.cc | 11 +++++----- src/te/operation/scan_op.cc | 2 -- src/te/operation/tensor_compute_op.cc | 1 - src/te/schedule/schedule_dataflow_rewrite.cc | 2 +- src/te/schedule/schedule_ops.cc | 12 +++------- src/te/tensor.cc | 23 ++++++-------------- tests/python/unittest/test_te_tensor.py | 4 ---- 11 files changed, 21 insertions(+), 50 deletions(-) diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 99c86f0d58f7..e91a0930f37b 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -59,11 +59,8 @@ class TVM_DLL OperationNode : public Object { std::string name; /*! \brief optional tag of the operation */ std::string tag; - /*! \brief additional attributes of the operation */ + /*! \brief additional attributes of the operation*/ Map attrs; - /*! \brief output tensors */ - Array outputs; - // virtual destructor. virtual ~OperationNode() {} /*! \return number of outputs */ @@ -477,7 +474,7 @@ class HybridOpNode : public OperationNode { /*! \brief The input tensors */ Array inputs; /*! \brief Symbolic placeholder representation of outputs */ - Array symbolic_outputs; + Array outputs; /*! \brief The axis of iterations */ Array axis; /*! \brief the statement that generates the computation. This is @@ -513,7 +510,6 @@ class HybridOpNode : public OperationNode { v->Visit("attrs", &attrs); v->Visit("inputs", &inputs); v->Visit("outputs", &outputs); - v->Visit("symbolic_outputs", &symbolic_outputs); v->Visit("axis", &axis); v->Visit("body", &body); } diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 655748783dd1..442aeb6f1027 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -312,7 +312,7 @@ def visit_Assign(self, node): "You should bind a pure name to the tensors", ) self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i)) - rmap[rhs.symbolic_outputs[i].op] = rhs.output(i) + rmap[rhs.outputs[i].op] = rhs.output(i) return utils.replace_io(rhs.body, rmap) _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index b5609375b8aa..fc85d830c91a 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -86,12 +86,11 @@ def __eq__(self, other): if isinstance(other, _expr.ExprOp): return _expr.EqualOp(self, other) return False - if self.same_as(other): - return True if self.ndim == 0 and other.ndim == 0: raise ValueError( "Equal == comparison among rank-0 tensor is ambiguous, " - "use Tensor.equal for content expression equvalence." + "use Tensor.equal for content expression equvalence, " + "use Tensor.same_as for exact reference comparison" ) return _ffi_api.TensorEqual(self, other) diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 84869ccd7775..b602efcfc28b 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -90,7 +90,6 @@ Operation ExternOpNode::ReplaceInputs(const Operation& self, ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = ReplaceTensor(this->body, rmap); - n->outputs = Array(); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; if (rmap.count(t)) { diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 5bb645822996..5d2412abb3d2 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -49,13 +49,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(HybridOpNode); -int HybridOpNode::num_outputs() const { return static_cast(symbolic_outputs.size()); } +int HybridOpNode::num_outputs() const { return static_cast(outputs.size()); } Array HybridOpNode::root_iter_vars() const { return this->axis; } -DataType HybridOpNode::output_dtype(size_t i) const { return symbolic_outputs[i]->dtype; } +DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; } -Array HybridOpNode::output_shape(size_t i) const { return symbolic_outputs[i]->shape; } +Array HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; } HybridOp::HybridOp(std::string name, std::string tag, Map attrs, Array inputs, Array outputs, Stmt body) { @@ -67,7 +67,7 @@ HybridOp::HybridOp(std::string name, std::string tag, Map att n->tag = std::move(tag); n->attrs = std::move(attrs); n->inputs = std::move(inputs); - n->symbolic_outputs = std::move(outputs); + n->outputs = std::move(outputs); n->axis = te::GatherLoopVars(body); n->body = std::move(body); data_ = std::move(n); @@ -104,7 +104,6 @@ Operation HybridOpNode::ReplaceInputs(const Operation& self, ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); n->body = te::ReplaceTensor(this->body, rmap); - n->outputs = Array(); for (size_t i = 0; i < n->inputs.size(); ++i) { Tensor t = n->inputs[i]; if (rmap.count(t)) { @@ -167,7 +166,7 @@ Stmt HybridOpNode::BuildProvide(const Stage& stage, Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body); std::unordered_map rmap; for (int i = 0; i < this->num_outputs(); ++i) { - rmap[symbolic_outputs[i]] = stage->op.output(i); + rmap[outputs[i]] = stage->op.output(i); } auto n = make_object(*this); /* This is a story little bit complicated. diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index a29cc6601795..39689bd9654a 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -40,7 +40,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(ScanOpNode); int ScanOpNode::num_outputs() const { return static_cast(update.size()); } - Array ScanOpNode::root_iter_vars() const { Array ret{scan_axis}; for (IterVar iv : spatial_axis_) { @@ -144,7 +143,6 @@ Operation ScanOpNode::ReplaceInputs(const Operation& self, const std::unordered_map& rmap) const { ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); - n->outputs = Array(); for (size_t i = 0; i < n->init.size(); ++i) { if (rmap.count(n->init[i])) { n->init.Set(i, rmap.at(n->init[i])); diff --git a/src/te/operation/tensor_compute_op.cc b/src/te/operation/tensor_compute_op.cc index 432df75f03b9..262e5a2b97f4 100644 --- a/src/te/operation/tensor_compute_op.cc +++ b/src/te/operation/tensor_compute_op.cc @@ -85,7 +85,6 @@ Operation TensorComputeOpNode::ReplaceInputs(const Operation& self, const std::unordered_map& rmap) const { ICHECK_EQ(self.operator->(), this); auto n = make_object(*this); - n->outputs = Array(); auto intrin = make_object(*(this->intrin.operator->())); intrin->body = ReplaceTensor(this->intrin->body, rmap); if (intrin->reduce_init.defined()) { diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index e09cdfe146e1..fae826b926e3 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -616,7 +616,7 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) { const HybridOpNode* hybrid = sch->stages[i]->op.as(); ICHECK(hybrid); Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs, - hybrid->symbolic_outputs, new_hybrid_body[i]); + hybrid->outputs, new_hybrid_body[i]); op = op->ReplaceInputs(op, repl); for (int idx = 0; idx < s->op->num_outputs(); ++idx) { repl[s->op.output(idx)] = op.output(idx); diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 47ef4af1c4c4..75736d0333da 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -339,15 +339,9 @@ class SchedulePostProc : public StmtExprMutator { private: void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(), Operation repl_op = Operation()) { - if (!src.same_as(dst)) { - replace_buffer_[src] = dst; - } - if (!src.same_as(repl_realize)) { - replace_realize_[src] = repl_realize; - } - if (!src->op.same_as(repl_op)) { - replace_op_[src->op.get()] = repl_op; - } + replace_buffer_[src] = dst; + replace_realize_[src] = repl_realize; + replace_op_[src->op.get()] = repl_op; } // The thread extent scope. std::unordered_map thread_extent_scope_; diff --git a/src/te/tensor.cc b/src/te/tensor.cc index 1f43714ea107..1d75761216f1 100644 --- a/src/te/tensor.cc +++ b/src/te/tensor.cc @@ -78,22 +78,13 @@ String TensorNode::GetNameHint() const { return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index)); } -Tensor Operation::output(size_t n) const { - // cache the output tensors if empty - if ((*this)->outputs.empty()) { - auto* ptr = static_cast(get_mutable()); - size_t num = static_cast((*this)->num_outputs()); - for (size_t i = 0; i < num; ++i) { - auto node = make_object(); - node->op = *this; - node->value_index = i; - node->dtype = (*this)->output_dtype(i); - node->shape = (*this)->output_shape(i); - ptr->outputs.push_back(Tensor(node)); - } - } - ICHECK_LT(n, (*this)->outputs.size()); - return (*this)->outputs[n]; +Tensor Operation::output(size_t i) const { + auto node = make_object(); + node->op = *this; + node->value_index = i; + node->dtype = (*this)->output_dtype(i); + node->shape = (*this)->output_shape(i); + return Tensor(node); } Tensor::Tensor(Array shape, DataType dtype, Operation op, int value_index) { diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 6b09410af6c4..6958888e9bb6 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -37,8 +37,6 @@ def test_tensor(): assert T.op.output(0).__hash__() == T.__hash__() d = {T.op.output(0): 1} assert d[T] == 1 - assert T == T.op.output(0) - assert T.same_as(T.op.output(0)) assert T[0][0][0].astype("float16").dtype == "float16" @@ -51,8 +49,6 @@ def test_rank_zero(): print(T) print(T.op.body) assert tuple(T.shape) == () - assert T == T.op.output(0) - assert T.same_as(T.op.output(0)) def test_conv1d():