Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/tvm/te/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ class TVM_DLL OperationNode : public Object {
std::string name;
/*! \brief optional tag of the operation */
std::string tag;
/*! \brief additional attributes of the operation*/
/*! \brief additional attributes of the operation */
Map<String, ObjectRef> attrs;
/*! \brief output tensors */
Array<Tensor> outputs;

// virtual destructor.
virtual ~OperationNode() {}
/*! \return number of outputs */
Expand Down Expand Up @@ -473,7 +476,7 @@ class HybridOpNode : public OperationNode {
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of outputs */
Array<Tensor> outputs;
Array<Tensor> symbolic_outputs;
/*! \brief The axis of iterations */
Array<IterVar> axis;
/*! \brief the statement that generates the computation. This is
Expand Down Expand Up @@ -509,6 +512,7 @@ class HybridOpNode : public OperationNode {
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("symbolic_outputs", &symbolic_outputs);
v->Visit("axis", &axis);
v->Visit("body", &body);
}
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/te/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def visit_Assign(self, node):
"You should bind a pure name to the tensors",
)
self.add_symbol(node.targets[i].id, Symbol.GlobalBuffer, rhs.output(i))
rmap[rhs.outputs[i].op] = rhs.output(i)
rmap[rhs.symbolic_outputs[i].op] = rhs.output(i)
return utils.replace_io(rhs.body, rmap)

_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/te/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ def __eq__(self, other):
if isinstance(other, _expr.ExprOp):
return _expr.EqualOp(self, other)
return False
if self.same_as(other):
return True
if self.ndim == 0 and other.ndim == 0:
raise ValueError(
"Equal == comparison among rank-0 tensor is ambiguous, "
"use Tensor.equal for content expression equvalence, "
"use Tensor.same_as for exact reference comparison"
"use Tensor.equal for content expression equvalence."
)
return _ffi_api.TensorEqual(self, other)

Expand Down
1 change: 1 addition & 0 deletions src/te/operation/extern_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Operation ExternOpNode::ReplaceInputs(const Operation& self,
ICHECK_EQ(self.operator->(), this);
auto n = make_object<ExternOpNode>(*this);
n->body = ReplaceTensor(this->body, rmap);
n->outputs = Array<Tensor>();
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
if (rmap.count(t)) {
Expand Down
11 changes: 6 additions & 5 deletions src/te/operation/hybrid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

TVM_REGISTER_NODE_TYPE(HybridOpNode);

int HybridOpNode::num_outputs() const { return static_cast<int>(outputs.size()); }
int HybridOpNode::num_outputs() const { return static_cast<int>(symbolic_outputs.size()); }

Array<IterVar> HybridOpNode::root_iter_vars() const { return this->axis; }

DataType HybridOpNode::output_dtype(size_t i) const { return outputs[i]->dtype; }
DataType HybridOpNode::output_dtype(size_t i) const { return symbolic_outputs[i]->dtype; }

Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return outputs[i]->shape; }
Array<PrimExpr> HybridOpNode::output_shape(size_t i) const { return symbolic_outputs[i]->shape; }

HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> outputs, Stmt body) {
Expand All @@ -67,7 +67,7 @@ HybridOp::HybridOp(std::string name, std::string tag, Map<String, ObjectRef> att
n->tag = std::move(tag);
n->attrs = std::move(attrs);
n->inputs = std::move(inputs);
n->outputs = std::move(outputs);
n->symbolic_outputs = std::move(outputs);
n->axis = te::GatherLoopVars(body);
n->body = std::move(body);
data_ = std::move(n);
Expand Down Expand Up @@ -104,6 +104,7 @@ Operation HybridOpNode::ReplaceInputs(const Operation& self,
ICHECK_EQ(self.operator->(), this);
auto n = make_object<HybridOpNode>(*this);
n->body = te::ReplaceTensor(this->body, rmap);
n->outputs = Array<Tensor>();
for (size_t i = 0; i < n->inputs.size(); ++i) {
Tensor t = n->inputs[i];
if (rmap.count(t)) {
Expand Down Expand Up @@ -166,7 +167,7 @@ Stmt HybridOpNode::BuildProvide(const Stage& stage,
Stmt ret = AttrStmt(make_zero(DataType::Int(32)), tir::attr::extern_scope, 0, this->body);
std::unordered_map<Tensor, Tensor> rmap;
for (int i = 0; i < this->num_outputs(); ++i) {
rmap[outputs[i]] = stage->op.output(i);
rmap[symbolic_outputs[i]] = stage->op.output(i);
}
auto n = make_object<HybridOpNode>(*this);
/* This is a story little bit complicated.
Expand Down
2 changes: 2 additions & 0 deletions src/te/operation/scan_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ScanOpNode);

int ScanOpNode::num_outputs() const { return static_cast<int>(update.size()); }

Array<IterVar> ScanOpNode::root_iter_vars() const {
Array<IterVar> ret{scan_axis};
for (IterVar iv : spatial_axis_) {
Expand Down Expand Up @@ -143,6 +144,7 @@ Operation ScanOpNode::ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
ICHECK_EQ(self.operator->(), this);
auto n = make_object<ScanOpNode>(*this);
n->outputs = Array<Tensor>();
for (size_t i = 0; i < n->init.size(); ++i) {
if (rmap.count(n->init[i])) {
n->init.Set(i, rmap.at(n->init[i]));
Expand Down
1 change: 1 addition & 0 deletions src/te/operation/tensor_compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Operation TensorComputeOpNode::ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const {
ICHECK_EQ(self.operator->(), this);
auto n = make_object<TensorComputeOpNode>(*this);
n->outputs = Array<Tensor>();
auto intrin = make_object<TensorIntrinNode>(*(this->intrin.operator->()));
intrin->body = ReplaceTensor(this->intrin->body, rmap);
if (intrin->reduce_init.defined()) {
Expand Down
2 changes: 1 addition & 1 deletion src/te/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ void InjectInline(ScheduleNode* sch, bool feature_extraction_mode) {
const HybridOpNode* hybrid = sch->stages[i]->op.as<HybridOpNode>();
ICHECK(hybrid);
Operation op = HybridOp(hybrid->name, hybrid->tag, hybrid->attrs, hybrid->inputs,
hybrid->outputs, new_hybrid_body[i]);
hybrid->symbolic_outputs, new_hybrid_body[i]);
op = op->ReplaceInputs(op, repl);
for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
repl[s->op.output(idx)] = op.output(idx);
Expand Down
12 changes: 9 additions & 3 deletions src/te/schedule/schedule_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,15 @@ class SchedulePostProc : public StmtExprMutator {
private:
void AddReplace(Tensor src, Tensor dst, Tensor repl_realize = Tensor(),
Operation repl_op = Operation()) {
replace_buffer_[src] = dst;
replace_realize_[src] = repl_realize;
replace_op_[src->op.get()] = repl_op;
if (!src.same_as(dst)) {
replace_buffer_[src] = dst;
}
if (!src.same_as(repl_realize)) {
replace_realize_[src] = repl_realize;
}
if (!src->op.same_as(repl_op)) {
replace_op_[src->op.get()] = repl_op;
}
}
// The thread extent scope.
std::unordered_map<const Object*, PrimExpr> thread_extent_scope_;
Expand Down
23 changes: 16 additions & 7 deletions src/te/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,22 @@ String TensorNode::GetNameHint() const {
return op->num_outputs() == 1 ? op->name : (op->name + ".v" + std::to_string(value_index));
}

Tensor Operation::output(size_t i) const {
auto node = make_object<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
return Tensor(node);
Tensor Operation::output(size_t n) const {
// cache the output tensors if empty
if ((*this)->outputs.empty()) {
auto* ptr = static_cast<OperationNode*>(get_mutable());
size_t num = static_cast<size_t>((*this)->num_outputs());
for (size_t i = 0; i < num; ++i) {
auto node = make_object<TensorNode>();
node->op = *this;
node->value_index = i;
node->dtype = (*this)->output_dtype(i);
node->shape = (*this)->output_shape(i);
ptr->outputs.push_back(Tensor(node));
}
}
ICHECK_LT(n, (*this)->outputs.size());
return (*this)->outputs[n];
}

Tensor::Tensor(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index) {
Expand Down
4 changes: 4 additions & 0 deletions tests/python/unittest/test_te_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def test_tensor():
assert T.op.output(0).__hash__() == T.__hash__()
d = {T.op.output(0): 1}
assert d[T] == 1
assert T == T.op.output(0)
assert T.same_as(T.op.output(0))
assert T[0][0][0].astype("float16").dtype == "float16"


Expand All @@ -49,6 +51,8 @@ def test_rank_zero():
print(T)
print(T.op.body)
assert tuple(T.shape) == ()
assert T == T.op.output(0)
assert T.same_as(T.op.output(0))


def test_conv1d():
Expand Down