Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp Hannk IR #6379

Merged
merged 22 commits into from
Nov 12, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Move the reverse-mutator code into RemoveDeadOps
steven-johnson committed Nov 12, 2021

Verified

This commit was signed with the committer’s verified signature.
jalaziz Jameel Al-Aziz
commit 5bd38e3999f646063a939acd8123d8c7ef562259
7 changes: 1 addition & 6 deletions apps/hannk/interpreter/ops.cpp
Original file line number Diff line number Diff line change
@@ -1923,19 +1923,14 @@ OpPtr OpMutator::visit(std::unique_ptr<OpGroup> op) {

std::vector<OpPtr> ops_new;
ops_new.reserve(old_op_count);
// Would a separate loop for Reverse be better?
for (int i = 0; i < old_op_count; i++) {
const int idx = (direction_ == Reverse) ? (old_op_count - i - 1) : i;
OpPtr sub_op_old = op->take_op(idx);
OpPtr sub_op_old = op->take_op(i);
assert(sub_op_old != nullptr);
OpPtr sub_op_new = mutate(std::move(sub_op_old));
if (sub_op_new != nullptr) {
ops_new.push_back(std::move(sub_op_new));
}
}
if (direction_ == Reverse) {
std::reverse(ops_new.begin(), ops_new.end());
}
// TODO: we don't bother trying to optimize for an unchanged op here. Is it worthwhile?
// TODO: verify that inputs and outputs are still correct. Or recalculate from scratch?
return make_op<OpGroup>(inputs, outputs, std::move(ops_new));
11 changes: 0 additions & 11 deletions apps/hannk/interpreter/ops.h
Original file line number Diff line number Diff line change
@@ -642,15 +642,6 @@ class OpVisitor {

class OpMutator {
public:
enum Direction {
Forward,
Reverse
};

OpMutator() = default;
explicit OpMutator(Direction d)
: direction_(d) {
}
virtual ~OpMutator() = default;

// Convenience function for calling Op::mutate;
@@ -660,8 +651,6 @@ class OpMutator {
}

protected:
const Direction direction_ = Forward;

// Only the classes in the list are allowed to call visit() (to implement mutate_impl())
friend class BinaryOp;
friend class ConcatenationOp;
33 changes: 26 additions & 7 deletions apps/hannk/interpreter/transforms.cpp
Original file line number Diff line number Diff line change
@@ -67,8 +67,7 @@ class RemoveDeadOps : public OpMutator {

public:
// Go in reverse order so removing a dead op enables earlier ops to be seen as dead.
explicit RemoveDeadOps(const Op *root)
: OpMutator(OpMutator::Reverse) {
explicit RemoveDeadOps(const Op *root) {
// Build a set so that we don't have to worry about the root op mutating
for (int i = 0; i < root->output_count(); i++) {
root_outputs_.insert(root->output(i).get());
@@ -83,26 +82,46 @@ class RemoveDeadOps : public OpMutator {
OpPtr visit_leaf(OpPtr op) override {
if (is_dead(op.get())) {
removed_++;
// std::cout << "REMOVE LEAF: "; op->dump(std::cout);
return nullptr;
} else {
return op;
}
}

OpPtr visit(std::unique_ptr<OpGroup> op) override {
OpPtr new_op = OpMutator::visit(std::move(op));
// Don't call super; roll our own code to go in reverse order.
// TODO: is this even the right thing to do? May be better to
// to a graph traverse andkeep removing ops until none get removed.

std::vector<TensorPtr> inputs = op->inputs();
std::vector<TensorPtr> outputs = op->outputs();

const int old_op_count = op->op_count();

std::vector<OpPtr> ops_new;
ops_new.reserve(old_op_count);
for (int i = 0; i < old_op_count; i++) {
const int idx = (old_op_count - i - 1);
OpPtr sub_op_old = op->take_op(idx);
assert(sub_op_old != nullptr);
OpPtr sub_op_new = mutate(std::move(sub_op_old));
if (sub_op_new != nullptr) {
ops_new.push_back(std::move(sub_op_new));
}
}
std::reverse(ops_new.begin(), ops_new.end());
auto new_op_group = make_op<OpGroup>(inputs, outputs, std::move(ops_new));

// If the OpGroup is empty after mutation, remove it as well.
const OpGroup *new_op_group = cast_op<OpGroup>(new_op.get());
if (new_op_group && new_op_group->op_count() == 0) {
removed_++;
// std::cout << "REMOVE GROUP: "; op->dump(std::cout);
return nullptr;
}

return new_op;
return new_op_group;
}

private:
};

} // namespace