Skip to content

Commit d4f2d3d

Browse files
committed
fix compute inline
1 parent 0897cf2 commit d4f2d3d

File tree

7 files changed

+7
-1
lines changed

7 files changed

+7
-1
lines changed

include/tvm/te/operation.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ class HybridOpNode : public OperationNode {
511511
v->Visit("tag", &tag);
512512
v->Visit("attrs", &attrs);
513513
v->Visit("inputs", &inputs);
514+
v->Visit("outputs", &outputs);
514515
v->Visit("symbolic_outputs", &symbolic_outputs);
515516
v->Visit("axis", &axis);
516517
v->Visit("body", &body);

python/tvm/te/hybrid/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def visit_Assign(self, node):
312312
"You should bind a pure name to the tensors",
313313
)
314314
self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
315-
rmap[rhs.outputs[i].op] = rhs.output(i)
315+
rmap[rhs.symbolic_outputs[i].op] = rhs.output(i)
316316
return utils.replace_io(rhs.body, rmap)
317317

318318
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")

src/te/operation/extern_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ Operation ExternOpNode::ReplaceInputs(const Operation& self,
9090
ICHECK_EQ(self.operator->(), this);
9191
auto n = make_object<ExternOpNode>(*this);
9292
n->body = ReplaceTensor(this->body, rmap);
93+
n->outputs = Array<Tensor>();
9394
for (size_t i = 0; i < n->inputs.size(); ++i) {
9495
Tensor t = n->inputs[i];
9596
if (rmap.count(t)) {

src/te/operation/hybrid_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ Operation HybridOpNode::ReplaceInputs(const Operation& self,
104104
ICHECK_EQ(self.operator->(), this);
105105
auto n = make_object<HybridOpNode>(*this);
106106
n->body = te::ReplaceTensor(this->body, rmap);
107+
n->outputs = Array<Tensor>();
107108
for (size_t i = 0; i < n->inputs.size(); ++i) {
108109
Tensor t = n->inputs[i];
109110
if (rmap.count(t)) {

src/te/operation/scan_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ Operation ScanOpNode::ReplaceInputs(const Operation& self,
144144
const std::unordered_map<Tensor, Tensor>& rmap) const {
145145
ICHECK_EQ(self.operator->(), this);
146146
auto n = make_object<ScanOpNode>(*this);
147+
n->outputs = Array<Tensor>();
147148
for (size_t i = 0; i < n->init.size(); ++i) {
148149
if (rmap.count(n->init[i])) {
149150
n->init.Set(i, rmap.at(n->init[i]));

src/te/operation/tensor_compute_op.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ Operation TensorComputeOpNode::ReplaceInputs(const Operation& self,
8585
const std::unordered_map<Tensor, Tensor>& rmap) const {
8686
ICHECK_EQ(self.operator->(), this);
8787
auto n = make_object<TensorComputeOpNode>(*this);
88+
n->outputs = Array<Tensor>();
8889
auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
8990
intrin->body = ReplaceTensor(this->intrin->body, rmap);
9091
if (intrin->reduce_init.defined()) {

src/te/tensor.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ String TensorNode::GetNameHint() const {
7979
}
8080

8181
Tensor Operation::output(size_t n) const {
82+
// cache the output tensors if empty
8283
if ((*this)->outputs.empty()) {
8384
auto* ptr = static_cast<OperationNode*>(get_mutable());
8485
size_t num = static_cast<size_t>((*this)->num_outputs());

0 commit comments

Comments
 (0)