diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 670ccf11d177..e00d76ea6c9f 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -2408,27 +2408,15 @@ class BoxesTouched : public IRGraphVisitor { } } - class CountVars : public IRVisitor { - using IRVisitor::visit; - - void visit(const Variable *var) override { - count++; - } - - public: - int count = 0; - CountVars() = default; - }; - // We get better simplification if we directly substitute mins // and maxes in, but this can also cause combinatorial code // explosion. Here we manage this by only substituting in // reasonably-sized expressions. We determine the size by // counting the number of var nodes. bool is_small_enough_to_substitute(const Expr &e) { - CountVars c; - e.accept(&c); - return c.count < 10; + int count = 0; + visit_with(e, [&](auto *, const Variable *) { count++; }); + return count < 10; } void push_var(const string &name) { @@ -2744,23 +2732,13 @@ class BoxesTouched : public IRGraphVisitor { } vector find_free_vars(const Expr &e) { - class FindFreeVars : public IRVisitor { - using IRVisitor::visit; - void visit(const Variable *op) override { - if (scope.contains(op->name)) { - result.push_back(op); - } + vector result; + visit_with(e, [&](auto *, const Variable *op) { + if (scope.contains(op->name)) { + result.push_back(op); } - - public: - const Scope &scope; - vector result; - FindFreeVars(const Scope &s) - : scope(s) { - } - } finder(scope); - e.accept(&finder); - return finder.result; + }); + return result; } void visit(const IfThenElse *op) override { diff --git a/src/BoundsInference.cpp b/src/BoundsInference.cpp index 9ba1f1af2019..2b5f56336ae5 100644 --- a/src/BoundsInference.cpp +++ b/src/BoundsInference.cpp @@ -34,40 +34,30 @@ bool var_name_match(const string &candidate, const string &var) { return (candidate == var) || Internal::ends_with(candidate, "." + var); } -class DependsOnBoundsInference : public IRVisitor { - using IRVisitor::visit; - - void visit(const Variable *var) override { - if (ends_with(var->name, ".max") || - ends_with(var->name, ".min")) { - result = true; - } - } - - void visit(const Call *op) override { - if (op->name == Call::buffer_get_min || - op->name == Call::buffer_get_max) { - result = true; - } else { - IRVisitor::visit(op); - } - } - -public: - bool result = false; - DependsOnBoundsInference() = default; -}; - bool depends_on_bounds_inference(const Expr &e) { - DependsOnBoundsInference d; - e.accept(&d); - return d.result; + bool result = false; + visit_with( + e, + [&](auto *, const Variable *var) { + if (ends_with(var->name, ".max") || + ends_with(var->name, ".min")) { + result = true; + } // + }, + [&](auto *self, const Call *op) { + if (op->name == Call::buffer_get_min || + op->name == Call::buffer_get_max) { + result = true; + } else { + self->visit_base(op); + } // + }); + return result; } /** Compute the bounds of the value of some variable defined by an * inner let stmt or for loop. E.g. for the stmt: * - * * for x from 0 to 10: * let y = x + 2; * diff --git a/src/Function.cpp b/src/Function.cpp index f66b886cdafd..54fe96f785f2 100644 --- a/src/Function.cpp +++ b/src/Function.cpp @@ -1250,50 +1250,36 @@ const Call *Function::is_wrapper() const { } } -namespace { - -// Replace all calls to functions listed in 'substitutions' with their wrappers. -class SubstituteCalls : public IRMutator { - using IRMutator::visit; - - const map &substitutions; - - Expr visit(const Call *c) override { - Expr expr = IRMutator::visit(c); - c = expr.as(); - internal_assert(c); - - if ((c->call_type == Call::Halide) && - c->func.defined() && - substitutions.count(c->func)) { - auto it = substitutions.find(c->func); - internal_assert(it != substitutions.end()) - << "Function not in environment: " << c->func->name << "\n"; - FunctionPtr subs = it->second; - debug(4) << "...Replace call to Func \"" << c->name << "\" with " - << "\"" << subs->name << "\"\n"; - expr = Call::make(c->type, subs->name, c->args, c->call_type, - subs, c->value_index, - c->image, c->param); - } - return expr; - } - -public: - SubstituteCalls(const map &substitutions) - : substitutions(substitutions) { - } -}; - -} // anonymous namespace - Function &Function::substitute_calls(const map &substitutions) { debug(4) << "Substituting calls in " << name() << "\n"; if (substitutions.empty()) { return *this; } - SubstituteCalls subs_calls(substitutions); - contents->mutate(&subs_calls); + + // Replace all calls to functions listed in 'substitutions' with their wrappers. + auto m = LambdaMutator{ + [&](auto *self, const Call *c) { + Expr expr = self->visit_base(c); + c = expr.as(); + internal_assert(c); + + if ((c->call_type == Call::Halide) && + c->func.defined() && + substitutions.count(c->func)) { + auto it = substitutions.find(c->func); + internal_assert(it != substitutions.end()) + << "Function not in environment: " << c->func->name << "\n"; + FunctionPtr subs = it->second; + debug(4) << "...Replace call to Func \"" << c->name << "\" with " + << "\"" << subs->name << "\"\n"; + expr = Call::make(c->type, subs->name, c->args, c->call_type, + subs, c->value_index, + c->image, c->param); + } + return expr; + }}; + + contents->mutate(&m); return *this; } diff --git a/src/IR.h b/src/IR.h index bdf42a75f7b1..3dce56d504c6 100644 --- a/src/IR.h +++ b/src/IR.h @@ -23,6 +23,16 @@ namespace Internal { class Function; +// The lambda-based IRVisitors and IRMutators use this helper to dispatch +// to multiple lambdas via function overloading (of `operator()`). +template +struct LambdaOverloads : Ts... { + using Ts::operator()...; + explicit LambdaOverloads(Ts... ts) + : Ts(std::move(ts))... { + } +}; + /** The actual IR nodes begin here. Remember that all the Expr * nodes also have a public "type" property */ diff --git a/src/IRMutator.h b/src/IRMutator.h index 4eed97ec4fe7..95a396533d12 100644 --- a/src/IRMutator.h +++ b/src/IRMutator.h @@ -6,6 +6,8 @@ */ #include +#include +#include #include "IR.h" @@ -57,7 +59,6 @@ class IRMutator { virtual Expr visit(const StringImm *); virtual Expr visit(const Cast *); virtual Expr visit(const Reinterpret *); - virtual Expr visit(const Variable *); virtual Expr visit(const Add *); virtual Expr visit(const Sub *); virtual Expr visit(const Mul *); @@ -78,31 +79,31 @@ class IRMutator { virtual Expr visit(const Load *); virtual Expr visit(const Ramp *); virtual Expr visit(const Broadcast *); - virtual Expr visit(const Call *); virtual Expr visit(const Let *); - virtual Expr visit(const Shuffle *); - virtual Expr visit(const VectorReduce *); - virtual Stmt visit(const LetStmt *); virtual Stmt visit(const AssertStmt *); virtual Stmt visit(const ProducerConsumer *); - virtual Stmt visit(const For *); virtual Stmt visit(const Store *); virtual Stmt visit(const Provide *); virtual Stmt visit(const Allocate *); virtual Stmt visit(const Free *); virtual Stmt visit(const Realize *); virtual Stmt visit(const Block *); + virtual Stmt visit(const Fork *); virtual Stmt visit(const IfThenElse *); virtual Stmt visit(const Evaluate *); - virtual Stmt visit(const Prefetch *); + virtual Expr visit(const Call *); + virtual Expr visit(const Variable *); + virtual Stmt visit(const For *); virtual Stmt visit(const Acquire *); - virtual Stmt visit(const Fork *); - virtual Stmt visit(const Atomic *); + virtual Expr visit(const Shuffle *); + virtual Stmt visit(const Prefetch *); virtual Stmt visit(const HoistedStorage *); + virtual Stmt visit(const Atomic *); + virtual Expr visit(const VectorReduce *); }; -/** A mutator that caches and reapplies previously-done mutations, so +/** A mutator that caches and reapplies previously done mutations so * that it can handle graphs of IR that have not had CSE done to * them. */ class IRGraphMutator : public IRMutator { @@ -111,14 +112,233 @@ class IRGraphMutator : public IRMutator { std::map stmt_replacements; public: + using IRMutator::mutate; Stmt mutate(const Stmt &s) override; Expr mutate(const Expr &e) override; +}; - std::vector mutate(const std::vector &exprs) { - return IRMutator::mutate(exprs); +/** A lambda-based IR mutator that accepts multiple lambdas for different + * node types. */ +template +struct LambdaMutator final : IRMutator { + explicit LambdaMutator(Lambdas... lambdas) + : handlers(std::move(lambdas)...) { + } + + /** Public helper to call the base visitor from lambdas. */ + template + auto visit_base(const T *op) { + return IRMutator::visit(op); + } + +private: + LambdaOverloads handlers; + + template + auto visit_impl(const T *op) { + if constexpr (std::is_invocable_v) { + return handlers(this, op); + } else { + return this->visit_base(op); + } + } + +protected: + Expr visit(const IntImm *op) override { + return this->visit_impl(op); + } + Expr visit(const UIntImm *op) override { + return this->visit_impl(op); + } + Expr visit(const FloatImm *op) override { + return this->visit_impl(op); + } + Expr visit(const StringImm *op) override { + return this->visit_impl(op); + } + Expr visit(const Cast *op) override { + return this->visit_impl(op); + } + Expr visit(const Reinterpret *op) override { + return this->visit_impl(op); + } + Expr visit(const Add *op) override { + return this->visit_impl(op); + } + Expr visit(const Sub *op) override { + return this->visit_impl(op); + } + Expr visit(const Mul *op) override { + return this->visit_impl(op); + } + Expr visit(const Div *op) override { + return this->visit_impl(op); + } + Expr visit(const Mod *op) override { + return this->visit_impl(op); + } + Expr visit(const Min *op) override { + return this->visit_impl(op); + } + Expr visit(const Max *op) override { + return this->visit_impl(op); + } + Expr visit(const EQ *op) override { + return this->visit_impl(op); + } + Expr visit(const NE *op) override { + return this->visit_impl(op); + } + Expr visit(const LT *op) override { + return this->visit_impl(op); + } + Expr visit(const LE *op) override { + return this->visit_impl(op); + } + Expr visit(const GT *op) override { + return this->visit_impl(op); + } + Expr visit(const GE *op) override { + return this->visit_impl(op); + } + Expr visit(const And *op) override { + return this->visit_impl(op); + } + Expr visit(const Or *op) override { + return this->visit_impl(op); + } + Expr visit(const Not *op) override { + return this->visit_impl(op); + } + Expr visit(const Select *op) override { + return this->visit_impl(op); + } + Expr visit(const Load *op) override { + return this->visit_impl(op); + } + Expr visit(const Ramp *op) override { + return this->visit_impl(op); + } + Expr visit(const Broadcast *op) override { + return this->visit_impl(op); + } + Expr visit(const Let *op) override { + return this->visit_impl(op); + } + Stmt visit(const LetStmt *op) override { + return this->visit_impl(op); + } + Stmt visit(const AssertStmt *op) override { + return this->visit_impl(op); + } + Stmt visit(const ProducerConsumer *op) override { + return this->visit_impl(op); + } + Stmt visit(const Store *op) override { + return this->visit_impl(op); + } + Stmt visit(const Provide *op) override { + return this->visit_impl(op); + } + Stmt visit(const Allocate *op) override { + return this->visit_impl(op); + } + Stmt visit(const Free *op) override { + return this->visit_impl(op); + } + Stmt visit(const Realize *op) override { + return this->visit_impl(op); + } + Stmt visit(const Block *op) override { + return this->visit_impl(op); + } + Stmt visit(const Fork *op) override { + return this->visit_impl(op); + } + Stmt visit(const IfThenElse *op) override { + return this->visit_impl(op); + } + Stmt visit(const Evaluate *op) override { + return this->visit_impl(op); + } + Expr visit(const Call *op) override { + return this->visit_impl(op); + } + Expr visit(const Variable *op) override { + return this->visit_impl(op); + } + Stmt visit(const For *op) override { + return this->visit_impl(op); + } + Stmt visit(const Acquire *op) override { + return this->visit_impl(op); + } + Expr visit(const Shuffle *op) override { + return this->visit_impl(op); + } + Stmt visit(const Prefetch *op) override { + return this->visit_impl(op); + } + Stmt visit(const HoistedStorage *op) override { + return this->visit_impl(op); + } + Stmt visit(const Atomic *op) override { + return this->visit_impl(op); + } + Expr visit(const VectorReduce *op) override { + return this->visit_impl(op); } }; +/** A lambda-based IR mutator that accepts multiple lambdas for overloading + * the base mutate() method. */ +template +struct LambdaMutatorGeneric final : IRMutator { + explicit LambdaMutatorGeneric(Lambdas... lambdas) + : handlers(std::move(lambdas)...) { + } + + /** Public helper to call the base mutator from lambdas. */ + // Note: C++26 introduces variadic friends: https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2024/p2893r3.html + // So the mutate_base API could be replaced with: + // friend Lambdas...; + template + auto mutate_base(const T &op) { + return IRMutator::mutate(op); + } + + Expr mutate(const Expr &e) override { + if constexpr (std::is_invocable_v) { + return handlers(this, e); + } else { + return this->mutate_base(e); + } + } + + Stmt mutate(const Stmt &e) override { + if constexpr (std::is_invocable_v) { + return handlers(this, e); + } else { + return this->mutate_base(e); + } + } + +private: + LambdaOverloads handlers; +}; + +template +auto mutate_with(const T &ir, Lambdas &&...lambdas) { + using Overloads = LambdaOverloads; + using Generic = LambdaMutatorGeneric; + if constexpr (std::is_invocable_v || + std::is_invocable_v) { + return LambdaMutatorGeneric{std::forward(lambdas)...}.mutate(ir); + } else { + return LambdaMutator{std::forward(lambdas)...}.mutate(ir); + } +} + /** A helper function for mutator-like things to mutate regions */ template std::pair mutate_region(Mutator *mutator, const Region &bounds, Args &&...args) { diff --git a/src/IROperator.cpp b/src/IROperator.cpp index ec832404e0eb..f1d2254abb27 100644 --- a/src/IROperator.cpp +++ b/src/IROperator.cpp @@ -1052,40 +1052,34 @@ Expr strided_ramp_base(const Expr &e, int stride) { namespace { // Replace a specified list of intrinsics with their first arg. -class RemoveIntrinsics : public IRMutator { - using IRMutator::visit; - const std::initializer_list &ops; - - Expr visit(const Call *op) override { - if (op->is_intrinsic(ops)) { - return mutate(op->args[0]); - } else { - return IRMutator::visit(op); - } - } - -public: - RemoveIntrinsics(const std::initializer_list &ops) - : ops(ops) { - } -}; +template +T remove_intrinsics(const T &e, const std::initializer_list &ops) { + return mutate_with( + e, + [&](auto *self, const Call *op) { + if (op->is_intrinsic(ops)) { + return self->mutate(op->args[0]); + } + return self->visit_base(op); + }); +} } // namespace Expr remove_likelies(const Expr &e) { - return RemoveIntrinsics({Call::likely, Call::likely_if_innermost}).mutate(e); + return remove_intrinsics(e, {Call::likely, Call::likely_if_innermost}); } Stmt remove_likelies(const Stmt &s) { - return RemoveIntrinsics({Call::likely, Call::likely_if_innermost}).mutate(s); + return remove_intrinsics(s, {Call::likely, Call::likely_if_innermost}); } Expr remove_promises(const Expr &e) { - return RemoveIntrinsics({Call::promise_clamped, Call::unsafe_promise_clamped}).mutate(e); + return remove_intrinsics(e, {Call::promise_clamped, Call::unsafe_promise_clamped}); } Stmt remove_promises(const Stmt &s) { - return RemoveIntrinsics({Call::promise_clamped, Call::unsafe_promise_clamped}).mutate(s); + return remove_intrinsics(s, {Call::promise_clamped, Call::unsafe_promise_clamped}); } Expr unwrap_tags(const Expr &e) { diff --git a/src/IRVisitor.h b/src/IRVisitor.h index 4be3b82a1454..7a64a6b16d51 100644 --- a/src/IRVisitor.h +++ b/src/IRVisitor.h @@ -35,7 +35,6 @@ class IRVisitor { virtual void visit(const StringImm *); virtual void visit(const Cast *); virtual void visit(const Reinterpret *); - virtual void visit(const Variable *); virtual void visit(const Add *); virtual void visit(const Sub *); virtual void visit(const Mul *); @@ -56,29 +55,209 @@ class IRVisitor { virtual void visit(const Load *); virtual void visit(const Ramp *); virtual void visit(const Broadcast *); - virtual void visit(const Call *); virtual void visit(const Let *); virtual void visit(const LetStmt *); virtual void visit(const AssertStmt *); virtual void visit(const ProducerConsumer *); - virtual void visit(const For *); virtual void visit(const Store *); virtual void visit(const Provide *); virtual void visit(const Allocate *); virtual void visit(const Free *); virtual void visit(const Realize *); virtual void visit(const Block *); + virtual void visit(const Fork *); virtual void visit(const IfThenElse *); virtual void visit(const Evaluate *); + virtual void visit(const Call *); + virtual void visit(const Variable *); + virtual void visit(const For *); + virtual void visit(const Acquire *); virtual void visit(const Shuffle *); - virtual void visit(const VectorReduce *); virtual void visit(const Prefetch *); - virtual void visit(const Fork *); - virtual void visit(const Acquire *); - virtual void visit(const Atomic *); virtual void visit(const HoistedStorage *); + virtual void visit(const Atomic *); + virtual void visit(const VectorReduce *); }; +/** A lambda-based IR visitor that accepts multiple lambdas for different + * node types. */ +template +struct LambdaVisitor final : IRVisitor { + explicit LambdaVisitor(Lambdas... lambdas) + : handlers(std::move(lambdas)...) { + } + + /** Public helper to call the base visitor from lambdas. */ + template + void visit_base(const T *op) { + IRVisitor::visit(op); + } + +private: + LambdaOverloads handlers; + + template + auto visit_impl(const T *op) { + if constexpr (std::is_invocable_v) { + return handlers(this, op); + } else { + return this->visit_base(op); + } + } + +protected: + void visit(const IntImm *op) override { + this->visit_impl(op); + } + void visit(const UIntImm *op) override { + this->visit_impl(op); + } + void visit(const FloatImm *op) override { + this->visit_impl(op); + } + void visit(const StringImm *op) override { + this->visit_impl(op); + } + void visit(const Cast *op) override { + this->visit_impl(op); + } + void visit(const Reinterpret *op) override { + this->visit_impl(op); + } + void visit(const Add *op) override { + this->visit_impl(op); + } + void visit(const Sub *op) override { + this->visit_impl(op); + } + void visit(const Mul *op) override { + this->visit_impl(op); + } + void visit(const Div *op) override { + this->visit_impl(op); + } + void visit(const Mod *op) override { + this->visit_impl(op); + } + void visit(const Min *op) override { + this->visit_impl(op); + } + void visit(const Max *op) override { + this->visit_impl(op); + } + void visit(const EQ *op) override { + this->visit_impl(op); + } + void visit(const NE *op) override { + this->visit_impl(op); + } + void visit(const LT *op) override { + this->visit_impl(op); + } + void visit(const LE *op) override { + this->visit_impl(op); + } + void visit(const GT *op) override { + this->visit_impl(op); + } + void visit(const GE *op) override { + this->visit_impl(op); + } + void visit(const And *op) override { + this->visit_impl(op); + } + void visit(const Or *op) override { + this->visit_impl(op); + } + void visit(const Not *op) override { + this->visit_impl(op); + } + void visit(const Select *op) override { + this->visit_impl(op); + } + void visit(const Load *op) override { + this->visit_impl(op); + } + void visit(const Ramp *op) override { + this->visit_impl(op); + } + void visit(const Broadcast *op) override { + this->visit_impl(op); + } + void visit(const Let *op) override { + this->visit_impl(op); + } + void visit(const LetStmt *op) override { + this->visit_impl(op); + } + void visit(const AssertStmt *op) override { + this->visit_impl(op); + } + void visit(const ProducerConsumer *op) override { + this->visit_impl(op); + } + void visit(const Store *op) override { + this->visit_impl(op); + } + void visit(const Provide *op) override { + this->visit_impl(op); + } + void visit(const Allocate *op) override { + this->visit_impl(op); + } + void visit(const Free *op) override { + this->visit_impl(op); + } + void visit(const Realize *op) override { + this->visit_impl(op); + } + void visit(const Block *op) override { + this->visit_impl(op); + } + void visit(const Fork *op) override { + this->visit_impl(op); + } + void visit(const IfThenElse *op) override { + this->visit_impl(op); + } + void visit(const Evaluate *op) override { + this->visit_impl(op); + } + void visit(const Call *op) override { + this->visit_impl(op); + } + void visit(const Variable *op) override { + this->visit_impl(op); + } + void visit(const For *op) override { + this->visit_impl(op); + } + void visit(const Acquire *op) override { + this->visit_impl(op); + } + void visit(const Shuffle *op) override { + this->visit_impl(op); + } + void visit(const Prefetch *op) override { + this->visit_impl(op); + } + void visit(const HoistedStorage *op) override { + this->visit_impl(op); + } + void visit(const Atomic *op) override { + this->visit_impl(op); + } + void visit(const VectorReduce *op) override { + this->visit_impl(op); + } +}; + +template +void visit_with(const T &ir, Lambdas &&...lambdas) { + LambdaVisitor visitor{std::forward(lambdas)...}; + ir.accept(&visitor); +} + /** A base class for algorithms that walk recursively over the IR * without visiting the same node twice. This is for passes that are * capable of interpreting the IR as a DAG instead of a tree. */ @@ -109,7 +288,6 @@ class IRGraphVisitor : public IRVisitor { void visit(const StringImm *) override; void visit(const Cast *) override; void visit(const Reinterpret *) override; - void visit(const Variable *) override; void visit(const Add *) override; void visit(const Sub *) override; void visit(const Mul *) override; @@ -130,27 +308,28 @@ class IRGraphVisitor : public IRVisitor { void visit(const Load *) override; void visit(const Ramp *) override; void visit(const Broadcast *) override; - void visit(const Call *) override; void visit(const Let *) override; void visit(const LetStmt *) override; void visit(const AssertStmt *) override; void visit(const ProducerConsumer *) override; - void visit(const For *) override; void visit(const Store *) override; void visit(const Provide *) override; void visit(const Allocate *) override; void visit(const Free *) override; void visit(const Realize *) override; void visit(const Block *) override; + void visit(const Fork *) override; void visit(const IfThenElse *) override; void visit(const Evaluate *) override; + void visit(const Call *) override; + void visit(const Variable *) override; + void visit(const For *) override; + void visit(const Acquire *) override; void visit(const Shuffle *) override; - void visit(const VectorReduce *) override; void visit(const Prefetch *) override; - void visit(const Acquire *) override; - void visit(const Fork *) override; - void visit(const Atomic *) override; void visit(const HoistedStorage *) override; + void visit(const Atomic *) override; + void visit(const VectorReduce *) override; // @} }; diff --git a/src/InjectHostDevBufferCopies.cpp b/src/InjectHostDevBufferCopies.cpp index 39037a39be7e..f7751552b16b 100644 --- a/src/InjectHostDevBufferCopies.cpp +++ b/src/InjectHostDevBufferCopies.cpp @@ -590,28 +590,21 @@ class InjectBufferCopies : public IRMutator { } }; - class FreeAfterLastUse : public IRMutator { - Stmt last_use; - Stmt free_stmt; - - public: + Stmt inject_free_after_last_use(Stmt body, const Stmt &last_use, const Stmt &free_stmt) { bool success = false; - using IRMutator::mutate; - - Stmt mutate(const Stmt &s) override { - if (s.same_as(last_use)) { - internal_assert(!success); - success = true; - return Block::make(last_use, free_stmt); - } else { - return IRMutator::mutate(s); - } - } - - FreeAfterLastUse(Stmt s, Stmt f) - : last_use(std::move(s)), free_stmt(std::move(f)) { - } - }; + body = mutate_with( + body, + [&](auto *self, const Stmt &op) { + if (op.same_as(last_use)) { + internal_assert(!success); + success = true; + return Block::make(last_use, free_stmt); + } + return self->mutate_base(op); + }); + internal_assert(success); + return body; + } Stmt visit(const Allocate *op) override { FindBufferUsage finder(op->name, DeviceAPI::Host); @@ -650,9 +643,7 @@ class InjectBufferCopies : public IRMutator { body.accept(&last_use); if (last_use.last_use.defined()) { Stmt device_free = call_extern_and_assert("halide_device_and_host_free", {buffer}); - FreeAfterLastUse free_injecter(last_use.last_use, device_free); - body = free_injecter.mutate(body); - internal_assert(free_injecter.success); + body = inject_free_after_last_use(body, last_use.last_use, device_free); } Expr device_interface = make_device_interface_call(touching_device, op->memory_type); @@ -674,9 +665,7 @@ class InjectBufferCopies : public IRMutator { body.accept(&last_use); if (last_use.last_use.defined()) { Stmt device_free = call_extern_and_assert("halide_device_free", {buffer}); - FreeAfterLastUse free_injecter(last_use.last_use, device_free); - body = free_injecter.mutate(body); - internal_assert(free_injecter.success); + body = inject_free_after_last_use(body, last_use.last_use, device_free); } Expr condition = op->condition; diff --git a/src/StorageFlattening.cpp b/src/StorageFlattening.cpp index 11f5e8e20d24..e7f9c324663e 100644 --- a/src/StorageFlattening.cpp +++ b/src/StorageFlattening.cpp @@ -447,31 +447,19 @@ class HoistStorage : public IRMutator { vector hoisted_storages; map hoisted_storages_map; - class ExpandExpr : public IRMutator { - using IRMutator::visit; - const Scope &scope; - - Expr visit(const Variable *var) override { - if (const Expr *e = scope.find(var->name)) { - // Mutate the expression, so lets can get replaced recursively. - Expr expr = mutate(*e); - debug(4) << "Fully expanded " << var->name << " -> " << expr << "\n"; - return expr; - } else { - return var; - } - } - - public: - ExpandExpr(const Scope &s) - : scope(s) { - } - }; - // Perform all the substitutions in a scope - Expr expand_expr(const Expr &e, const Scope &scope) { - ExpandExpr ee(scope); - Expr result = ee.mutate(e); + static Expr expand_expr(const Expr &e, const Scope &scope) { + Expr result = mutate_with( + e, + [&](auto *self, const Variable *var) -> Expr { + if (const Expr *value = scope.find(var->name)) { + // Mutate the expression, so lets can get replaced recursively. + Expr expr = self->mutate(*value); + debug(4) << "Fully expanded " << var->name << " -> " << expr << "\n"; + return expr; + } + return var; + }); debug(4) << "Expanded " << e << " into " << result << "\n"; return result; } diff --git a/src/Substitute.cpp b/src/Substitute.cpp index 6a7cba7fd589..e5e3eb59a298 100644 --- a/src/Substitute.cpp +++ b/src/Substitute.cpp @@ -126,35 +126,24 @@ Stmt substitute(const map &m, const Stmt &stmt) { namespace { -class SubstituteExpr : public IRMutator { -public: - Expr find, replacement; - - using IRMutator::mutate; - - Expr mutate(const Expr &e) override { +template +auto substitute_impl(const Expr &find, const Expr &replacement, const T &ir) { + return mutate_with(ir, [&](auto *self, const Expr &e) { if (equal(e, find)) { return replacement; - } else { - return IRMutator::mutate(e); } - } -}; + return self->mutate_base(e); + }); +} } // namespace Expr substitute(const Expr &find, const Expr &replacement, const Expr &expr) { - SubstituteExpr s; - s.find = find; - s.replacement = replacement; - return s.mutate(expr); + return substitute_impl(find, replacement, expr); } Stmt substitute(const Expr &find, const Expr &replacement, const Stmt &stmt) { - SubstituteExpr s; - s.find = find; - s.replacement = replacement; - return s.mutate(stmt); + return substitute_impl(find, replacement, stmt); } namespace {