Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 9 additions & 31 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2408,27 +2408,15 @@ class BoxesTouched : public IRGraphVisitor {
}
}

class CountVars : public IRVisitor {
using IRVisitor::visit;

void visit(const Variable *var) override {
count++;
}

public:
int count = 0;
CountVars() = default;
};

// We get better simplification if we directly substitute mins
// and maxes in, but this can also cause combinatorial code
// explosion. Here we manage this by only substituting in
// reasonably-sized expressions. We determine the size by
// counting the number of var nodes.
bool is_small_enough_to_substitute(const Expr &e) {
CountVars c;
e.accept(&c);
return c.count < 10;
int count = 0;
visit_with(e, [&](auto *, const Variable *) { count++; });
return count < 10;
}

void push_var(const string &name) {
Expand Down Expand Up @@ -2744,23 +2732,13 @@ class BoxesTouched : public IRGraphVisitor {
}

vector<const Variable *> find_free_vars(const Expr &e) {
class FindFreeVars : public IRVisitor {
using IRVisitor::visit;
void visit(const Variable *op) override {
if (scope.contains(op->name)) {
result.push_back(op);
}
vector<const Variable *> result;
visit_with(e, [&](auto *, const Variable *op) {
if (scope.contains(op->name)) {
result.push_back(op);
}

public:
const Scope<Interval> &scope;
vector<const Variable *> result;
FindFreeVars(const Scope<Interval> &s)
: scope(s) {
}
} finder(scope);
e.accept(&finder);
return finder.result;
});
return result;
}

void visit(const IfThenElse *op) override {
Expand Down
46 changes: 18 additions & 28 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,40 +34,30 @@ bool var_name_match(const string &candidate, const string &var) {
return (candidate == var) || Internal::ends_with(candidate, "." + var);
}

class DependsOnBoundsInference : public IRVisitor {
using IRVisitor::visit;

void visit(const Variable *var) override {
if (ends_with(var->name, ".max") ||
ends_with(var->name, ".min")) {
result = true;
}
}

void visit(const Call *op) override {
if (op->name == Call::buffer_get_min ||
op->name == Call::buffer_get_max) {
result = true;
} else {
IRVisitor::visit(op);
}
}

public:
bool result = false;
DependsOnBoundsInference() = default;
};

bool depends_on_bounds_inference(const Expr &e) {
DependsOnBoundsInference d;
e.accept(&d);
return d.result;
bool result = false;
visit_with(
e,
[&](auto *, const Variable *var) {
if (ends_with(var->name, ".max") ||
ends_with(var->name, ".min")) {
result = true;
} //
},
[&](auto *self, const Call *op) {
if (op->name == Call::buffer_get_min ||
op->name == Call::buffer_get_max) {
result = true;
} else {
self->visit_base(op);
} //
});
return result;
}

/** Compute the bounds of the value of some variable defined by an
* inner let stmt or for loop. E.g. for the stmt:
*
*
* for x from 0 to 10:
* let y = x + 2;
*
Expand Down
64 changes: 25 additions & 39 deletions src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1250,50 +1250,36 @@ const Call *Function::is_wrapper() const {
}
}

namespace {

// Replace all calls to functions listed in 'substitutions' with their wrappers.
class SubstituteCalls : public IRMutator {
using IRMutator::visit;

const map<FunctionPtr, FunctionPtr> &substitutions;

Expr visit(const Call *c) override {
Expr expr = IRMutator::visit(c);
c = expr.as<Call>();
internal_assert(c);

if ((c->call_type == Call::Halide) &&
c->func.defined() &&
substitutions.count(c->func)) {
auto it = substitutions.find(c->func);
internal_assert(it != substitutions.end())
<< "Function not in environment: " << c->func->name << "\n";
FunctionPtr subs = it->second;
debug(4) << "...Replace call to Func \"" << c->name << "\" with "
<< "\"" << subs->name << "\"\n";
expr = Call::make(c->type, subs->name, c->args, c->call_type,
subs, c->value_index,
c->image, c->param);
}
return expr;
}

public:
SubstituteCalls(const map<FunctionPtr, FunctionPtr> &substitutions)
: substitutions(substitutions) {
}
};

} // anonymous namespace

Function &Function::substitute_calls(const map<FunctionPtr, FunctionPtr> &substitutions) {
debug(4) << "Substituting calls in " << name() << "\n";
if (substitutions.empty()) {
return *this;
}
SubstituteCalls subs_calls(substitutions);
contents->mutate(&subs_calls);

// Replace all calls to functions listed in 'substitutions' with their wrappers.
auto m = LambdaMutator{
[&](auto *self, const Call *c) {
Expr expr = self->visit_base(c);
c = expr.as<Call>();
internal_assert(c);

if ((c->call_type == Call::Halide) &&
c->func.defined() &&
substitutions.count(c->func)) {
auto it = substitutions.find(c->func);
internal_assert(it != substitutions.end())
<< "Function not in environment: " << c->func->name << "\n";
FunctionPtr subs = it->second;
debug(4) << "...Replace call to Func \"" << c->name << "\" with "
<< "\"" << subs->name << "\"\n";
expr = Call::make(c->type, subs->name, c->args, c->call_type,
subs, c->value_index,
c->image, c->param);
}
return expr;
}};

contents->mutate(&m);
return *this;
}

Expand Down
10 changes: 10 additions & 0 deletions src/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ namespace Internal {

class Function;

// The lambda-based IRVisitors and IRMutators use this helper to dispatch
// to multiple lambdas via function overloading (of `operator()`).
template<typename... Ts>
struct LambdaOverloads : Ts... {
using Ts::operator()...;
explicit LambdaOverloads(Ts... ts)
: Ts(std::move(ts))... {
}
};

/** The actual IR nodes begin here. Remember that all the Expr
* nodes also have a public "type" property */

Expand Down
Loading
Loading