
Commit 73cf51b

[TE][Fix] Comparison of the output tensor (#9829)
* [TE][Fix] Comparison of the output tensor
* fix hybrid op issue
* fix tensor replacement in schedule ops
* fix compute inline
Parent commit: b445d66

11 files changed: 50 additions, 21 deletions

include/tvm/te/operation.h

Lines changed: 6 additions & 2 deletions
@@ -59,8 +59,11 @@ class TVM_DLL OperationNode : public Object {
   std::string name;
   /*! \brief optional tag of the operation */
   std::string tag;
-  /*! \brief additional attributes of the operation*/
+  /*! \brief additional attributes of the operation */
   Map<String, ObjectRef> attrs;
+  /*! \brief output tensors */
+  Array<Tensor> outputs;
+
   // virtual destructor.
   virtual ~OperationNode() {}
   /*! \return number of outputs */
@@ -473,7 +476,7 @@ class HybridOpNode : public OperationNode {
   /*! \brief The input tensors */
   Array<Tensor> inputs;
   /*! \brief Symbolic placeholder representation of outputs */
-  Array<Tensor> outputs;
+  Array<Tensor> symbolic_outputs;
   /*! \brief The axis of iterations */
   Array<IterVar> axis;
   /*! \brief the statement that generates the computation. This is
@@ -509,6 +512,7 @@ class HybridOpNode : public OperationNode {
     v->Visit("attrs", &attrs);
     v->Visit("inputs", &inputs);
     v->Visit("outputs", &outputs);
+    v->Visit("symbolic_outputs", &symbolic_outputs);
     v->Visit("axis", &axis);
     v->Visit("body", &body);
   }

python/tvm/te/hybrid/parser.py

Lines changed: 1 addition & 1 deletion
@@ -312,7 +312,7 @@ def visit_Assign(self, node):
                     "You should bind a pure name to the tensors",
                 )
                 self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
-                rmap[rhs.outputs[i].op] = rhs.output(i)
+                rmap[rhs.symbolic_outputs[i].op] = rhs.output(i)
             return utils.replace_io(rhs.body, rmap)
 
         _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")

python/tvm/te/tensor.py

Lines changed: 3 additions & 2 deletions
@@ -86,11 +86,12 @@ def __eq__(self, other):
             if isinstance(other, _expr.ExprOp):
                 return _expr.EqualOp(self, other)
             return False
+        if self.same_as(other):
+            return True
         if self.ndim == 0 and other.ndim == 0:
             raise ValueError(
                 "Equal == comparison among rank-0 tensor is ambiguous, "
-                "use Tensor.equal for content expression equvalence, "
-                "use Tensor.same_as for exact reference comparison"
+                "use Tensor.equal for content expression equvalence."
             )
         return _ffi_api.TensorEqual(self, other)
 

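For illustration, a minimal Python sketch of the user-visible effect of this change (assuming a TVM build that includes this patch; te.placeholder is the standard TE API, and the rank-0 placeholders here are only a convenient way to reach the ndim == 0 branch):

import tvm
from tvm import te

# Rank-0 (scalar) tensors.
x = te.placeholder((), name="x")
y = te.placeholder((), name="y")

# Comparing a tensor with itself now short-circuits through same_as()
# and returns True instead of raising the rank-0 ambiguity error.
assert x == x

# Two distinct rank-0 tensors remain ambiguous and still raise ValueError.
try:
    x == y
except ValueError:
    pass
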
src/te/operation/extern_op.cc

Lines changed: 1 addition & 0 deletions
@@ -90,6 +90,7 @@ Operation ExternOpNode::ReplaceInputs(const Operation& self,
   ICHECK_EQ(self.operator->(), this);
   auto n = make_object<ExternOpNode>(*this);
   n->body = ReplaceTensor(this->body, rmap);
+  n->outputs = Array<Tensor>();
   for (size_t i = 0; i < n->inputs.size(); ++i) {
     Tensor t = n->inputs[i];
     if (rmap.count(t)) {

src/te/operation/hybrid_op.cc

Lines changed: 6 additions & 5 deletions
@@ -49,13 +49,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(HybridOpNode);
 
-int HybridOpNode::num_outputs() const { return static_cast<int>(outputs.size()); }
+int HybridOpNode::num_outputs() const { return static_cast<int>(symbolic_outputs.size()); }
 
 Array<IterVar> HybridOpNode::root_iter_vars() const { return this->axis; }
 
-DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; }
+DataType HybridOpNode::output_dtype(size_t i) const { return symbolic_outputs[i]->dtype; }
 
-Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; }
+Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return symbolic_outputs[i]->shape; }
 
 HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                    Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) {
@@ -67,7 +67,7 @@ HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> att
   n->tag = std::move(tag);
   n->attrs = std::move(attrs);
   n->inputs = std::move(inputs);
-  n->outputs = std::move(outputs);
+  n->symbolic_outputs = std::move(outputs);
   n->axis = te::GatherLoopVars(body);
   n->body = std::move(body);
   data_ = std::move(n);
@@ -104,6 +104,7 @@ Operation HybridOpNode::ReplaceInputs(const Operation& self,
   ICHECK_EQ(self.operator->(), this);
   auto n = make_object<HybridOpNode>(*this);
   n->body = te::ReplaceTensor(this->body, rmap);
+  n->outputs = Array<Tensor>();
   for (size_t i = 0; i < n->inputs.size(); ++i) {
     Tensor t = n->inputs[i];
     if (rmap.count(t)) {
@@ -166,7 +167,7 @@ Stmt HybridOpNode::BuildProvide(const Stage& stage,
   Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
   std::unordered_map<Tensor, Tensor> rmap;
   for (int i = 0; i < this->num_outputs(); ++i) {
-    rmap[outputs[i]] = stage->op.output(i);
+    rmap[symbolic_outputs[i]] = stage->op.output(i);
   }
   auto n = make_object<HybridOpNode>(*this);
   /* This is a story little bit complicated.

src/te/operation/scan_op.cc

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 TVM_REGISTER_NODE_TYPE(ScanOpNode);
 
 int ScanOpNode::num_outputs() const { return static_cast<int>(update.size()); }
+
 Array<IterVar> ScanOpNode::root_iter_vars() const {
   Array<IterVar> ret{scan_axis};
   for (IterVar iv : spatial_axis_) {
@@ -143,6 +144,7 @@ Operation ScanOpNode::ReplaceInputs(const Operation& self,
                                     const std::unordered_map<Tensor, Tensor>& rmap) const {
   ICHECK_EQ(self.operator->(), this);
   auto n = make_object<ScanOpNode>(*this);
+  n->outputs = Array<Tensor>();
   for (size_t i = 0; i < n->init.size(); ++i) {
     if (rmap.count(n->init[i])) {
       n->init.Set(i, rmap.at(n->init[i]));

src/te/operation/tensor_compute_op.cc

Lines changed: 1 addition & 0 deletions
@@ -85,6 +85,7 @@ Operation TensorComputeOpNode::ReplaceInputs(const Operation& self,
                                              const std::unordered_map<Tensor, Tensor>& rmap) const {
   ICHECK_EQ(self.operator->(), this);
   auto n = make_object<TensorComputeOpNode>(*this);
+  n->outputs = Array<Tensor>();
   auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
   intrin->body = ReplaceTensor(this->intrin->body, rmap);
   if (intrin->reduce_init.defined()) {

src/te/schedule/schedule_dataflow_rewrite.cc

Lines changed: 1 addition & 1 deletion
@@ -616,7 +616,7 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
       const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
       ICHECK(hybrid);
       Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
-                              hybrid->outputs, new_hybrid_body[i]);
+                              hybrid->symbolic_outputs, new_hybrid_body[i]);
       op = op->ReplaceInputs(op, repl);
       for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
         repl[s->op.output(idx)] = op.output(idx);

src/te/schedule/schedule_ops.cc

Lines changed: 9 additions & 3 deletions
@@ -298,9 +298,15 @@ class SchedulePostProc : public StmtExprMutator {
  private:
   void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(),
                   Operation repl_op = Operation()) {
-    replace_buffer_[src] = dst;
-    replace_realize_[src] = repl_realize;
-    replace_op_[src->op.get()] = repl_op;
+    if (!src.same_as(dst)) {
+      replace_buffer_[src] = dst;
+    }
+    if (!src.same_as(repl_realize)) {
+      replace_realize_[src] = repl_realize;
+    }
+    if (!src->op.same_as(repl_op)) {
+      replace_op_[src->op.get()] = repl_op;
+    }
   }
   // The thread extent scope.
   std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;

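To make the intent of the guarded AddReplace above concrete, here is a hypothetical Python analogue (add_replace and replace_buffer are illustrative names, not TVM API): since Operation::output() now returns cached tensors, src and dst can be the very same object, and such identity mappings are simply skipped.

def add_replace(replace_buffer, src, dst):
    # Only record a replacement when src and dst are genuinely different
    # objects; an identity entry (src -> src) adds nothing useful.
    if src is not dst:
        replace_buffer[src] = dst
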
src/te/tensor.cc

Lines changed: 16 additions & 7 deletions
@@ -78,13 +78,22 @@ String TensorNode::GetNameHint() const {
   return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index));
 }
 
-Tensor Operation::output(size_t i) const {
-  auto node = make_object<TensorNode>();
-  node->op = *this;
-  node->value_index = i;
-  node->dtype = (*this)->output_dtype(i);
-  node->shape = (*this)->output_shape(i);
-  return Tensor(node);
+Tensor Operation::output(size_t n) const {
+  // cache the output tensors if empty
+  if ((*this)->outputs.empty()) {
+    auto* ptr = static_cast<OperationNode*>(get_mutable());
+    size_t num = static_cast<size_t>((*this)->num_outputs());
+    for (size_t i = 0; i < num; ++i) {
+      auto node = make_object<TensorNode>();
+      node->op = *this;
+      node->value_index = i;
+      node->dtype = (*this)->output_dtype(i);
+      node->shape = (*this)->output_shape(i);
+      ptr->outputs.push_back(Tensor(node));
+    }
+  }
+  ICHECK_LT(n, (*this)->outputs.size());
+  return (*this)->outputs[n];
 }
 
 Tensor::Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {

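A short Python sketch of the caching behavior introduced above (assuming a build with this patch; te.placeholder, te.compute, and Operation.output are existing TE APIs):

import tvm
from tvm import te

A = te.placeholder((16,), name="A")
B = te.compute((16,), lambda i: A[i] + 1, name="B")

# Operation.output(i) now returns the cached tensor instead of constructing
# a fresh TensorNode on every call, so repeated calls give the same object.
t0 = B.op.output(0)
t1 = B.op.output(0)
assert t0.same_as(t1)

# The tensor returned by te.compute is that same cached output.
assert B.same_as(B.op.output(0))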