diff --git a/src/AsyncProducers.cpp b/src/AsyncProducers.cpp index bb5e4279d367..9c71f8a4d5fe 100644 --- a/src/AsyncProducers.cpp +++ b/src/AsyncProducers.cpp @@ -146,26 +146,12 @@ class GenerateProducerBody : public NoOpCollapsingMutator { // to schedule those Funcs as async too. Check for any consume nodes // where the producer has gone to the consumer side of the fork // node. - class FindBadConsumeNodes : public IRVisitor { - const std::set &producers_dropped; - using IRVisitor::visit; - - void visit(const ProducerConsumer *op) override { - if (!op->is_producer && producers_dropped.count(op->name)) { - found = op->name; - } + visit_with(body, [&](auto *self, const ProducerConsumer *op) { + if (!op->is_producer && producers_dropped.count(op->name)) { + bad_producer_nesting_error(op->name, func); } - - public: - string found; - FindBadConsumeNodes(const std::set &p) - : producers_dropped(p) { - } - } finder(producers_dropped); - body.accept(&finder); - if (!finder.found.empty()) { - bad_producer_nesting_error(finder.found, func); - } + self->visit_base(op); + }); while (!sema.empty()) { Expr release = Call::make(Int(32), "halide_semaphore_release", {sema.back(), 1}, Call::Extern); diff --git a/src/Closure.cpp b/src/Closure.cpp index 5c5125a9b291..42a9870c41fe 100644 --- a/src/Closure.cpp +++ b/src/Closure.cpp @@ -139,19 +139,16 @@ Stmt Closure::unpack_from_struct(const Expr &e, const Stmt &s) const { Expr packed = pack_into_struct(); // Make a prototype of the packed struct - class ReplaceCallArgsWithZero : public IRMutator { - public: - using IRMutator::mutate; - Expr mutate(const Expr &e) override { - if (!e.as()) { - return make_zero(e.type()); - } else { - return IRMutator::mutate(e); - } - } - } replacer; + Expr prototype = + mutate_with(packed, + [](auto *self, const Expr &e) { + if (!e.as()) { + return make_zero(e.type()); + } else { + return self->mutate_base(e); + } + }); string prototype_name = unique_name("closure_prototype"); - Expr prototype = 
replacer.mutate(packed); Expr prototype_var = Variable::make(Handle(), prototype_name); const Call *c = packed.as(); diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index d8e77c439707..d31e97e30427 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -774,26 +774,19 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::add_kernel(const Stmt &s, } } - class FindShared : public IRVisitor { - using IRVisitor::visit; - void visit(const Allocate *op) override { - if (op->memory_type == MemoryType::GPUShared) { - internal_assert(alloc == nullptr) - << "Found multiple shared allocations in metal kernel\n"; - alloc = op; - } + const Allocate *shared_alloc = nullptr; + shared_name = "__shared"; + visit_with(s, [&](auto *self, const Allocate *op) { + if (op->memory_type == MemoryType::GPUShared) { + internal_assert(shared_alloc == nullptr) + << "Found multiple shared allocations in metal kernel\n"; + shared_alloc = op; + shared_name = op->name; } + // Recurse just to make sure there aren't multiple nested shared allocs + self->visit_base(op); + }); - public: - const Allocate *alloc = nullptr; - } find_shared; - s.accept(&find_shared); - - if (find_shared.alloc) { - shared_name = find_shared.alloc->name; - } else { - shared_name = "__shared"; - } // Note that int4 below is an int32x4, not an int4_t. The type // is chosen to be large to maximize alignment. 
stream << ",\n" diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index 07a1fd4bc279..1c945efb7cc1 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -1060,26 +1060,19 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::add_kernel(Stmt s, } } - class FindShared : public IRVisitor { - using IRVisitor::visit; - void visit(const Allocate *op) override { - if (op->memory_type == MemoryType::GPUShared) { - internal_assert(alloc == nullptr) - << "Found multiple shared allocations in opencl kernel\n"; - alloc = op; - } + const Allocate *shared_alloc = nullptr; + shared_name = "__shared"; + visit_with(s, [&](auto *self, const Allocate *op) { + if (op->memory_type == MemoryType::GPUShared) { + internal_assert(shared_alloc == nullptr) + << "Found multiple shared allocations in opencl kernel\n"; + shared_alloc = op; + shared_name = op->name; } + // Recurse just to make sure there aren't multiple nested shared allocs + self->visit_base(op); + }); - public: - const Allocate *alloc = nullptr; - } find_shared; - s.accept(&find_shared); - - if (find_shared.alloc) { - shared_name = find_shared.alloc->name; - } else { - shared_name = "__shared"; - } // Note that int16 below is an int32x16, not an int16_t. The type // is chosen to be large to maximize alignment. 
stream << ",\n" diff --git a/src/IRMutator.h b/src/IRMutator.h index 95a396533d12..c170b37eb42b 100644 --- a/src/IRMutator.h +++ b/src/IRMutator.h @@ -335,7 +335,11 @@ auto mutate_with(const T &ir, Lambdas &&...lambdas) { std::is_invocable_v) { return LambdaMutatorGeneric{std::forward(lambdas)...}.mutate(ir); } else { - return LambdaMutator{std::forward(lambdas)...}.mutate(ir); + LambdaMutator mutator{std::forward(lambdas)...}; + constexpr bool all_take_two_args = + (std::is_invocable_v && ...); + static_assert(all_take_two_args); + return mutator.mutate(ir); } } diff --git a/src/IRVisitor.h b/src/IRVisitor.h index 7a64a6b16d51..473f825335ec 100644 --- a/src/IRVisitor.h +++ b/src/IRVisitor.h @@ -252,10 +252,18 @@ struct LambdaVisitor final : IRVisitor { } }; -template -void visit_with(const T &ir, Lambdas &&...lambdas) { +template +void visit_with(const IRNode *ir, Lambdas &&...lambdas) { LambdaVisitor visitor{std::forward(lambdas)...}; - ir.accept(&visitor); + constexpr bool all_take_two_args = + (std::is_invocable_v && ...); + static_assert(all_take_two_args); + ir->accept(&visitor); +} + +template +void visit_with(const IRHandle &ir, Lambdas &&...lambdas) { + visit_with(ir.get(), std::forward(lambdas)...); } /** A base class for algorithms that walk recursively over the IR diff --git a/src/InjectHostDevBufferCopies.cpp b/src/InjectHostDevBufferCopies.cpp index f7751552b16b..344a55c8e6d2 100644 --- a/src/InjectHostDevBufferCopies.cpp +++ b/src/InjectHostDevBufferCopies.cpp @@ -381,26 +381,16 @@ class InjectBufferCopiesForSingleBuffer : public IRMutator { return do_copies(op); } - // Check if a stmt has any for loops (and hence possible device - // transitions). 
- class HasLoops : public IRVisitor { - using IRVisitor::visit; - void visit(const For *op) override { - result = true; - } - - public: - bool result = false; - }; - Stmt visit(const Block *op) override { // If both sides of the block have no loops (and hence no // device transitions), treat it as a single leaf. This stops // host dirties from getting in between blocks of store stmts // that could be interleaved. - HasLoops loops; - op->accept(&loops); - if (loops.result) { + bool has_loops = false; + visit_with(op, [&](auto *, const For *op) { + has_loops = true; + }); + if (has_loops) { return IRMutator::visit(op); } else { return do_copies(op); diff --git a/src/LICM.cpp b/src/LICM.cpp index c73fcf5424ab..6d85c96b0531 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -278,16 +278,11 @@ class LICM : public IRMutator { } // Track the set of variables used by the inner loop - class CollectVars : public IRVisitor { - using IRVisitor::visit; - void visit(const Variable *op) override { - vars.insert(op->name); - } - - public: - set vars; - } vars; - new_stmt.accept(&vars); + set vars; + LambdaVisitor collect_var([&](auto *, const Variable *op) { + vars.insert(op->name); + }); + new_stmt.accept(&collect_var); // Now consider substituting back in each use const Call *call = dummy_call.as(); @@ -300,10 +295,10 @@ class LICM : public IRMutator { continue; } Expr e = call->args[i]; - if (cost(e, vars.vars) <= 1) { + if (cost(e, vars) <= 1) { // Just subs it back in - computing it is as cheap // as loading it. - e.accept(&vars); + e.accept(&collect_var); new_stmt = substitute(names[i], e, new_stmt); names[i].clear(); exprs[i] = Expr(); diff --git a/src/Lower.cpp b/src/Lower.cpp index c6db1adfa33c..34be927a2d4c 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -583,23 +583,19 @@ void lower_impl(const vector &output_funcs, // We're about to drop the environment and outputs vector, which // contain the only strong refs to Functions that may still be // pointed to by the IR. 
So make those refs strong. - class StrengthenRefs : public IRMutator { - using IRMutator::visit; - Expr visit(const Call *c) override { - Expr expr = IRMutator::visit(c); - c = expr.as(); - internal_assert(c); - if (c->func.defined()) { - FunctionPtr ptr = c->func; - ptr.strengthen(); - expr = Call::make(c->type, c->name, c->args, c->call_type, - ptr, c->value_index, - c->image, c->param); - } - return expr; + s = mutate_with(s, [&](auto *self, const Call *c) { + Expr expr = self->visit_base(c); + c = expr.as(); + internal_assert(c); + if (c->func.defined()) { + FunctionPtr ptr = c->func; + ptr.strengthen(); + expr = Call::make(c->type, c->name, c->args, c->call_type, + ptr, c->value_index, + c->image, c->param); } - }; - s = StrengthenRefs().mutate(s); + return expr; + }); LoweredFunc main_func(pipeline_name, public_args, s, linkage_type); diff --git a/src/Parameter.cpp b/src/Parameter.cpp index 722929185e67..635b924f1482 100644 --- a/src/Parameter.cpp +++ b/src/Parameter.cpp @@ -274,49 +274,27 @@ namespace { using namespace Halide::Internal; Expr remove_self_references(const Parameter &p, const Expr &e) { - class RemoveSelfReferences : public IRMutator { - using IRMutator::visit; - - Expr visit(const Variable *var) override { - if (var->param.same_as(p)) { - internal_assert(starts_with(var->name, p.name() + ".")); - return Variable::make(var->type, var->name); - } else { - internal_assert(!starts_with(var->name, p.name() + ".")); - } - return var; - } - - public: - const Parameter &p; - RemoveSelfReferences(const Parameter &p) - : p(p) { + return mutate_with(e, [&](auto *self, const Variable *var) -> Expr { + if (var->param.same_as(p)) { + internal_assert(starts_with(var->name, p.name() + ".")); + return Variable::make(var->type, var->name); + } else { + internal_assert(!starts_with(var->name, p.name() + ".")); } - } mutator{p}; - return mutator.mutate(e); + return var; + }); } Expr restore_self_references(const Parameter &p, const Expr &e) { - class 
RestoreSelfReferences : public IRMutator { - using IRMutator::visit; - - Expr visit(const Variable *var) override { - if (!var->image.defined() && - !var->param.defined() && - !var->reduction_domain.defined() && - Internal::starts_with(var->name, p.name() + ".")) { - return Internal::Variable::make(var->type, var->name, p); - } - return var; - } - - public: - const Parameter &p; - RestoreSelfReferences(const Parameter &p) - : p(p) { + return mutate_with(e, [&](auto *self, const Variable *var) -> Expr { + if (!var->image.defined() && + !var->param.defined() && + !var->reduction_domain.defined() && + Internal::starts_with(var->name, p.name() + ".")) { + return Internal::Variable::make(var->type, var->name, p); } - } mutator{p}; - return mutator.mutate(e); + return var; + }); } } // namespace diff --git a/src/PartitionLoops.cpp b/src/PartitionLoops.cpp index cc9bda77d586..1b88ff829fe4 100644 --- a/src/PartitionLoops.cpp +++ b/src/PartitionLoops.cpp @@ -1167,19 +1167,15 @@ Stmt partition_loops(Stmt s) { s = LowerLikelyIfInnermost().mutate(s); // Walk inwards to the first loop before doing any more work. 
- class Mutator : public IRMutator { - using IRMutator::visit; - Stmt visit(const For *op) override { - Stmt s = op; - s = MarkClampedRampsAsLikely().mutate(s); - s = ExpandSelects().mutate(s); - s = PartitionLoops().mutate(s); - s = RenormalizeGPULoops().mutate(s); - s = CollapseSelects().mutate(s); - return s; - } - } mutator; - s = mutator.mutate(s); + s = mutate_with(s, [](auto *self, const For *op) { + Stmt s = op; + s = MarkClampedRampsAsLikely().mutate(s); + s = ExpandSelects().mutate(s); + s = PartitionLoops().mutate(s); + s = RenormalizeGPULoops().mutate(s); + s = CollapseSelects().mutate(s); + return s; + }); s = remove_likelies(s); return s; diff --git a/src/Simplify.cpp b/src/Simplify.cpp index 48742753ec7b..09b61b95265c 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -325,30 +325,16 @@ template T substitute_facts_impl(const T &t, const std::set &truths, const std::set &falsehoods) { - class Substitutor : public IRMutator { - const std::set &truths, &falsehoods; - - public: - using IRMutator::mutate; - Expr mutate(const Expr &e) override { - if (!e.type().is_bool()) { - return IRMutator::mutate(e); - } else if (truths.count(e)) { + return mutate_with(t, [&](auto *self, const Expr &e) { + if (e.type().is_bool()) { + if (truths.count(e)) { return make_one(e.type()); } else if (falsehoods.count(e)) { return make_zero(e.type()); - } else { - return IRMutator::mutate(e); } } - - Substitutor(const std::set &t, - const std::set &f) - : truths(t), falsehoods(f) { - } - } substitutor(truths, falsehoods); - - return substitutor.mutate(t); + return self->mutate_base(e); + }); } } // namespace diff --git a/src/SkipStages.cpp b/src/SkipStages.cpp index 63640492de36..ec90a720718a 100644 --- a/src/SkipStages.cpp +++ b/src/SkipStages.cpp @@ -378,21 +378,15 @@ class SkipStages : public IRMutator { // Is an Expr safe to lift into a .used or .loaded condition. 
bool may_lift(const Expr &e) { - class MayLift : public IRVisitor { - using IRVisitor::visit; - void visit(const Call *op) override { - if (!op->is_pure() && op->call_type != Call::Halide) { - result = false; - } else { - IRVisitor::visit(op); - } + bool result = true; + visit_with(e, [&](auto *self, const Call *op) { + if (!op->is_pure() && op->call_type != Call::Halide) { + result = false; + } else { + self->visit_base(op); } - - public: - bool result = true; - } v; - e.accept(&v); - return v.result; + }); + return result; } // Come up with an upper bound for the truth value of an expression with the @@ -411,30 +405,17 @@ class SkipStages : public IRMutator { // Come up with an upper bound for the truth value of an expression with any // calls to the given func eliminated. Expr relax_over_calls(const Expr &e, const std::string &func) { - class ReplaceCalls : public IRMutator { - const std::string &func; - - using IRMutator::visit; - - Expr visit(const Call *op) override { - if (op->call_type == Call::Halide && op->name == func) { - return cast(op->type, var); - } - return IRMutator::visit(op); - } - - public: - const std::string var_name; - const Expr var; - - ReplaceCalls(const std::string &func) - : func(func), - var_name(unique_name('t')), - var(Variable::make(Int(32), var_name)) { - } - } replacer(func); - - return relax_over_var(replacer.mutate(e), replacer.var_name); + const std::string var_name = unique_name('t'); + const Expr var = Variable::make(Int(32), var_name); + Expr relaxed = + mutate_with(e, + [&](auto *self, const Call *op) { + if (op->call_type == Call::Halide && op->name == func) { + return cast(op->type, var); + } + return self->visit_base(op); + }); + return relax_over_var(relaxed, var_name); } Expr visit(const Select *op) override { diff --git a/src/StripAsserts.cpp b/src/StripAsserts.cpp index 9d9c667f4db1..e8f101fcfc0c 100644 --- a/src/StripAsserts.cpp +++ b/src/StripAsserts.cpp @@ -10,39 +10,34 @@ namespace Internal { namespace { bool 
may_discard(const Expr &e) { - class MayDiscard : public IRVisitor { - using IRVisitor::visit; - - void visit(const Call *op) override { - // Extern calls that are side-effecty in the sense that you can't - // move them around in the IR, but we're free to discard because - // they're just getters. - static const std::set discardable{ - Call::buffer_get_dimensions, - Call::buffer_get_min, - Call::buffer_get_extent, - Call::buffer_get_stride, - Call::buffer_get_max, - Call::buffer_get_host, - Call::buffer_get_device, - Call::buffer_get_device_interface, - Call::buffer_get_shape, - Call::buffer_get_host_dirty, - Call::buffer_get_device_dirty, - Call::buffer_get_type}; - - if (!(op->is_pure() || - discardable.count(op->name))) { - result = false; - } + bool result = true; + // Extern calls that are side-effecty in the sense that you can't + // move them around in the IR, but we're free to discard because + // they're just getters. + static const std::set discardable{ + Call::buffer_get_dimensions, + Call::buffer_get_min, + Call::buffer_get_extent, + Call::buffer_get_stride, + Call::buffer_get_max, + Call::buffer_get_host, + Call::buffer_get_device, + Call::buffer_get_device_interface, + Call::buffer_get_shape, + Call::buffer_get_host_dirty, + Call::buffer_get_device_dirty, + Call::buffer_get_type}; + + visit_with(e, [&](auto *self, const Call *op) { + if (!(op->is_pure() || + discardable.count(op->name))) { + result = false; + } else { + self->visit_base(op); } + }); - public: - bool result = true; - } d; - e.accept(&d); - - return d.result; + return result; } class StripAsserts : public IRMutator {