Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions src/AsyncProducers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,26 +146,12 @@ class GenerateProducerBody : public NoOpCollapsingMutator {
// to schedule those Funcs as async too. Check for any consume nodes
// where the producer has gone to the consumer side of the fork
// node.
class FindBadConsumeNodes : public IRVisitor {
const std::set<string> &producers_dropped;
using IRVisitor::visit;

void visit(const ProducerConsumer *op) override {
if (!op->is_producer && producers_dropped.count(op->name)) {
found = op->name;
}
visit_with(body, [&](auto *self, const ProducerConsumer *op) {
if (!op->is_producer && producers_dropped.count(op->name)) {
bad_producer_nesting_error(op->name, func);
}

public:
string found;
FindBadConsumeNodes(const std::set<string> &p)
: producers_dropped(p) {
}
} finder(producers_dropped);
body.accept(&finder);
if (!finder.found.empty()) {
bad_producer_nesting_error(finder.found, func);
}
self->visit_base(op);
});

while (!sema.empty()) {
Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern);
Expand Down
21 changes: 9 additions & 12 deletions src/Closure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,16 @@ Stmt Closure::unpack_from_struct(const Expr &e, const Stmt &s) const {
Expr packed = pack_into_struct();

// Make a prototype of the packed struct
class ReplaceCallArgsWithZero : public IRMutator {
public:
using IRMutator::mutate;
Expr mutate(const Expr &e) override {
if (!e.as<Call>()) {
return make_zero(e.type());
} else {
return IRMutator::mutate(e);
}
}
} replacer;
Expr prototype =
mutate_with(packed,
[](auto *self, const Expr &e) {
if (!e.as<Call>()) {
return make_zero(e.type());
} else {
return self->mutate_base(e);
}
});
string prototype_name = unique_name("closure_prototype");
Expr prototype = replacer.mutate(packed);
Expr prototype_var = Variable::make(Handle(), prototype_name);

const Call *c = packed.as<Call>();
Expand Down
29 changes: 11 additions & 18 deletions src/CodeGen_Metal_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -774,26 +774,19 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::add_kernel(const Stmt &s,
}
}

class FindShared : public IRVisitor {
using IRVisitor::visit;
void visit(const Allocate *op) override {
if (op->memory_type == MemoryType::GPUShared) {
internal_assert(alloc == nullptr)
<< "Found multiple shared allocations in metal kernel\n";
alloc = op;
}
const Allocate *shared_alloc = nullptr;
shared_name = "__shared";
visit_with(s, [&](auto *self, const Allocate *op) {
if (op->memory_type == MemoryType::GPUShared) {
internal_assert(shared_alloc == nullptr)
<< "Found multiple shared allocations in metal kernel\n";
shared_alloc = op;
shared_name = op->name;
}
// Recurse just to make sure there aren't multiple nested shared allocs
self->visit_base(op);
});

public:
const Allocate *alloc = nullptr;
} find_shared;
s.accept(&find_shared);

if (find_shared.alloc) {
shared_name = find_shared.alloc->name;
} else {
shared_name = "__shared";
}
// Note that int4 below is an int32x4, not an int4_t. The type
// is chosen to be large to maximize alignment.
stream << ",\n"
Expand Down
29 changes: 11 additions & 18 deletions src/CodeGen_OpenCL_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1060,26 +1060,19 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::add_kernel(Stmt s,
}
}

class FindShared : public IRVisitor {
using IRVisitor::visit;
void visit(const Allocate *op) override {
if (op->memory_type == MemoryType::GPUShared) {
internal_assert(alloc == nullptr)
<< "Found multiple shared allocations in opencl kernel\n";
alloc = op;
}
// Locate the (at most one) GPUShared allocation in the kernel body.
// shared_name falls back to "__shared" when no such allocation exists.
const Allocate *shared_alloc = nullptr;
shared_name = "__shared";
visit_with(s, [&](auto *self, const Allocate *op) {
    if (op->memory_type == MemoryType::GPUShared) {
        internal_assert(shared_alloc == nullptr)
            << "Found multiple shared allocations in opencl kernel\n";
        shared_alloc = op;
        shared_name = op->name;
    }
    // Recurse just to make sure there aren't multiple nested shared allocs
    self->visit_base(op);
});

public:
const Allocate *alloc = nullptr;
} find_shared;
s.accept(&find_shared);

if (find_shared.alloc) {
shared_name = find_shared.alloc->name;
} else {
shared_name = "__shared";
}
// Note that int16 below is an int32x16, not an int16_t. The type
// is chosen to be large to maximize alignment.
stream << ",\n"
Expand Down
6 changes: 5 additions & 1 deletion src/IRMutator.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,11 @@ auto mutate_with(const T &ir, Lambdas &&...lambdas) {
std::is_invocable_v<Overloads, Generic *, const Stmt &>) {
return LambdaMutatorGeneric{std::forward<Lambdas>(lambdas)...}.mutate(ir);
} else {
return LambdaMutator{std::forward<Lambdas>(lambdas)...}.mutate(ir);
LambdaMutator mutator{std::forward<Lambdas>(lambdas)...};
constexpr bool all_take_two_args =
(std::is_invocable_v<Lambdas, decltype(&mutator), decltype(nullptr)> && ...);
static_assert(all_take_two_args);
return mutator.mutate(ir);
}
}

Expand Down
14 changes: 11 additions & 3 deletions src/IRVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,18 @@ struct LambdaVisitor final : IRVisitor {
}
};

template<typename T, typename... Lambdas>
void visit_with(const T &ir, Lambdas &&...lambdas) {
template<typename... Lambdas>
void visit_with(const IRNode *ir, Lambdas &&...lambdas) {
LambdaVisitor visitor{std::forward<Lambdas>(lambdas)...};
ir.accept(&visitor);
constexpr bool all_take_two_args =
(std::is_invocable_v<Lambdas, decltype(&visitor), decltype(nullptr)> && ...);
static_assert(all_take_two_args);
ir->accept(&visitor);
}

template<typename... Lambdas>
void visit_with(const IRHandle &ir, Lambdas &&...lambdas) {
visit_with(ir.get(), std::forward<Lambdas>(lambdas)...);
}

/** A base class for algorithms that walk recursively over the IR
Expand Down
20 changes: 5 additions & 15 deletions src/InjectHostDevBufferCopies.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,26 +381,16 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator {
return do_copies(op);
}

// Check if a stmt has any for loops (and hence possible device
// transitions).
class HasLoops : public IRVisitor {
using IRVisitor::visit;
void visit(const For *op) override {
result = true;
}

public:
bool result = false;
};

Stmt visit(const Block *op) override {
// If both sides of the block have no loops (and hence no
// device transitions), treat it as a single leaf. This stops
// host dirties from getting in between blocks of store stmts
// that could be interleaved.
HasLoops loops;
op->accept(&loops);
if (loops.result) {
// Record whether any For node is reached while walking this block.
// The callback does not call visit_base on the For it receives.
// Its node parameter is left unnamed: it is unused, and naming it `op`
// would shadow the enclosing method's `op`.
bool has_loops = false;
visit_with(op, [&](auto *, const For *) {
    has_loops = true;
});
if (has_loops) {
return IRMutator::visit(op);
} else {
return do_copies(op);
Expand Down
19 changes: 7 additions & 12 deletions src/LICM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,16 +278,11 @@ class LICM : public IRMutator {
}

// Track the set of variables used by the inner loop
class CollectVars : public IRVisitor {
using IRVisitor::visit;
void visit(const Variable *op) override {
vars.insert(op->name);
}

public:
set<string> vars;
} vars;
new_stmt.accept(&vars);
set<string> vars;
LambdaVisitor collect_var([&](auto *, const Variable *op) {
vars.insert(op->name);
});
new_stmt.accept(&collect_var);

// Now consider substituting back in each use
const Call *call = dummy_call.as<Call>();
Expand All @@ -300,10 +295,10 @@ class LICM : public IRMutator {
continue;
}
Expr e = call->args[i];
if (cost(e, vars.vars) <= 1) {
if (cost(e, vars) <= 1) {
// Just subs it back in - computing it is as cheap
// as loading it.
e.accept(&vars);
e.accept(&collect_var);
new_stmt = substitute(names[i], e, new_stmt);
names[i].clear();
exprs[i] = Expr();
Expand Down
28 changes: 12 additions & 16 deletions src/Lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,23 +583,19 @@ void lower_impl(const vector<Function> &output_funcs,
// We're about to drop the environment and outputs vector, which
// contain the only strong refs to Functions that may still be
// pointed to by the IR. So make those refs strong.
class StrengthenRefs : public IRMutator {
using IRMutator::visit;
Expr visit(const Call *c) override {
Expr expr = IRMutator::visit(c);
c = expr.as<Call>();
internal_assert(c);
if (c->func.defined()) {
FunctionPtr ptr = c->func;
ptr.strengthen();
expr = Call::make(c->type, c->name, c->args, c->call_type,
ptr, c->value_index,
c->image, c->param);
}
return expr;
s = mutate_with(s, [&](auto *self, const Call *c) {
Expr expr = self->visit_base(c);
c = expr.as<Call>();
internal_assert(c);
if (c->func.defined()) {
FunctionPtr ptr = c->func;
ptr.strengthen();
expr = Call::make(c->type, c->name, c->args, c->call_type,
ptr, c->value_index,
c->image, c->param);
}
};
s = StrengthenRefs().mutate(s);
return expr;
});

LoweredFunc main_func(pipeline_name, public_args, s, linkage_type);

Expand Down
54 changes: 16 additions & 38 deletions src/Parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,49 +274,27 @@ namespace {
using namespace Halide::Internal;

Expr remove_self_references(const Parameter &p, const Expr &e) {
class RemoveSelfReferences : public IRMutator {
using IRMutator::visit;

Expr visit(const Variable *var) override {
if (var->param.same_as(p)) {
internal_assert(starts_with(var->name, p.name() + "."));
return Variable::make(var->type, var->name);
} else {
internal_assert(!starts_with(var->name, p.name() + "."));
}
return var;
}

public:
const Parameter &p;
RemoveSelfReferences(const Parameter &p)
: p(p) {
return mutate_with(e, [&](auto *self, const Variable *var) -> Expr {
if (var->param.same_as(p)) {
internal_assert(starts_with(var->name, p.name() + "."));
return Variable::make(var->type, var->name);
} else {
internal_assert(!starts_with(var->name, p.name() + "."));
}
} mutator{p};
return mutator.mutate(e);
return var;
});
}

Expr restore_self_references(const Parameter &p, const Expr &e) {
class RestoreSelfReferences : public IRMutator {
using IRMutator::visit;

Expr visit(const Variable *var) override {
if (!var->image.defined() &&
!var->param.defined() &&
!var->reduction_domain.defined() &&
Internal::starts_with(var->name, p.name() + ".")) {
return Internal::Variable::make(var->type, var->name, p);
}
return var;
}

public:
const Parameter &p;
RestoreSelfReferences(const Parameter &p)
: p(p) {
return mutate_with(e, [&](auto *self, const Variable *var) -> Expr {
if (!var->image.defined() &&
!var->param.defined() &&
!var->reduction_domain.defined() &&
Internal::starts_with(var->name, p.name() + ".")) {
return Internal::Variable::make(var->type, var->name, p);
}
} mutator{p};
return mutator.mutate(e);
return var;
});
}

} // namespace
Expand Down
22 changes: 9 additions & 13 deletions src/PartitionLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1167,19 +1167,15 @@ Stmt partition_loops(Stmt s) {
s = LowerLikelyIfInnermost().mutate(s);

// Walk inwards to the first loop before doing any more work.
class Mutator : public IRMutator {
using IRMutator::visit;
Stmt visit(const For *op) override {
Stmt s = op;
s = MarkClampedRampsAsLikely().mutate(s);
s = ExpandSelects().mutate(s);
s = PartitionLoops().mutate(s);
s = RenormalizeGPULoops().mutate(s);
s = CollapseSelects().mutate(s);
return s;
}
} mutator;
s = mutator.mutate(s);
s = mutate_with(s, [](auto *self, const For *op) {
Stmt s = op;
s = MarkClampedRampsAsLikely().mutate(s);
s = ExpandSelects().mutate(s);
s = PartitionLoops().mutate(s);
s = RenormalizeGPULoops().mutate(s);
s = CollapseSelects().mutate(s);
return s;
});

s = remove_likelies(s);
return s;
Expand Down
Loading
Loading