Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
*/
LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);

/*!
* \brief Whether the node is element-wise.
* \return whether the node is element-wise.
*/
bool IsElemWise(const NodeRef& node);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not pass in NodeRef, which can be anything. Consider pass in Stmt and Array.



} // namespace ir
} // namespace tvm
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/schedule_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ Map<IterVar, Range> InferBound(Schedule sch);
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);

/*!
* \brief To automatically fuse the element-wise operations.
*
* \param s The schedule to be fused.
*/
void AutoFuseElemWise(Schedule sch);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change to AutoInlineElemWise


} // namespace schedule
} // namespace tvm
#endif // TVM_SCHEDULE_PASS_H_
2 changes: 1 addition & 1 deletion python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def compute_root(self):
parent : Stage
The parent stage
"""
_api_internal._StageComputeInline(self)
_api_internal._StageComputeRoot(self)

def reorder(self, *args):
"""reorder the arguments in the specified order.
Expand Down
5 changes: 5 additions & 0 deletions src/api/api_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@
namespace tvm {
namespace schedule {

TVM_REGISTER_API(_schedule_AutoFuseElemWise)
.set_body([](TVMArgs args, TVMRetValue* ret) {
AutoFuseElemWise(args[0]);
});

#define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API(_schedule_## PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
Expand Down
57 changes: 57 additions & 0 deletions src/pass/elem_wise_detector.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*!
* Copyright (c) 2016 by Contributors
* \file elem_wise_detector.cc
*/
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
#include <tvm/operation.h>

namespace tvm {
namespace ir {

class ElemWiseDetector : public IRVisitor {
public:
explicit ElemWiseDetector(Array<IterVar> axis) : axis_(axis) {}

void Visit(const NodeRef& e) final {
if (!is_elem_wise_)
return;
IRVisitor::Visit(e);
}

void Visit_(const Call* op) final {
Array<Expr> axis = op->args;
if (axis_.size() != axis.size()) {
is_elem_wise_ = false;
return;
}

for (size_t i = 0; i < axis_.size(); ++i) {
const Variable *v1 = axis_[i]->var.as<Variable>();
const Variable *v2 = axis[i].as<Variable>();
if (!(v1 && v2) || (v1 != v2)) {
is_elem_wise_ = false;
return;
}
}
IRVisitor::Visit_(op);
}

bool is_elem_wise_{true};

private:
Array<IterVar> axis_;
};


bool IsElemWise(const NodeRef& node) {
if (const ComputeOpNode* compute = node.as<ComputeOpNode>()) {
ElemWiseDetector v = ElemWiseDetector(compute->axis);
v.Visit(compute->body);
return v.is_elem_wise_;
}
return false;
}

} // namespace ir
} // namespace tvm
34 changes: 34 additions & 0 deletions src/schedule/fusion.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*!
* Copyright (c) 2016 by Contributors
* \file schedule.cc
*/
#include <tvm/schedule_pass.h>
#include <tvm/ir_pass.h>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename the file to be auto_inline_elemwise.cc


namespace tvm {
namespace schedule {

namespace {
inline bool is_stage_scheduled(const Stage& s) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add it as a member function of stage

return !(s->relations.empty() && s->attach_type == kNone);
}
}

void AutoFuseElemWise(Schedule sch) {
for (Stage s : sch->stages) {
if (!is_stage_scheduled(s) && ir::IsElemWise(s->op)) {
bool is_root = false;
for (auto r : sch->roots) {
if (r == s->op) {
is_root = true;
break;
}
}
if (!is_root)
s.compute_inline();
}
}
}

} // namespace schedule
} // namespace tvm
16 changes: 16 additions & 0 deletions tests/python/unittest/test_schedule_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,24 @@ def test_schedule2():
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)

def test_fusion():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')
C = tvm.placeholder((m, n), name='C')
T1 = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='T1')
T2 = tvm.compute((m, n), lambda i, j: T1(i, j) + C(i, j), name='T2')

s = tvm.Schedule(T2.op)
tvm.schedule.AutoFuseElemWise(s)
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)


if __name__ == "__main__":
test_schedule0()
test_schedule1()
test_schedule2()
test_fusion()