diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp index 68d87c803147..005c5e0daed3 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp @@ -12,15 +12,19 @@ #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/Liveness.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/RegionUtils.h" #define DEBUG_TYPE "iree-stream-automatic-reference-counting" @@ -35,18 +39,20 @@ namespace { // Local analysis //===----------------------------------------------------------------------===// -// Block-local analysis for timepoint coverage. -struct LocalTimepointCoverage { +// Analysis for timepoint coverage within the current recursive analysis scope +// (e.g., a function or nested region). +struct ScopedTimepointCoverage { // Must be provided if LLVM_DEBUG is enabled. AsmState *asmState; // A map of timepoint SSA values to indices within the coverage map. - // Values from other blocks are omitted. Order is that of the appearance - // in the block but is not guaranteed and the value should only be used - // for indexing into the map. + // This map accumulates timepoints from the current analysis scope, including + // those from parent blocks/regions when the object is passed down. + // Order is that of the appearance in the block but is not guaranteed and the + // value should only be used for indexing into the map. DenseMap timepoints; - LocalTimepointCoverage() = delete; - LocalTimepointCoverage(AsmState *asmState) : asmState(asmState) {} + ScopedTimepointCoverage() = delete; + ScopedTimepointCoverage(AsmState *asmState) : asmState(asmState) {} // A matrix of timepoints by index to bits indicating whether another // timepoint (column) covers the entry (row). @@ -177,7 +183,7 @@ struct LastUseSet { // Must be provided if LLVM_DEBUG is enabled. AsmState *asmState; // Timepoint coverage map. - LocalTimepointCoverage &coverage; + ScopedTimepointCoverage &coverage; // Resource base value to a set of signal timepoints from users. // Each set of timepoints is maintained as only those not covered by // others. Multiple values indicates a fork. @@ -189,7 +195,7 @@ struct LastUseSet { SmallVector baseResourceOrder; LastUseSet() = delete; - LastUseSet(AsmState *asmState, LocalTimepointCoverage &coverage) + LastUseSet(AsmState *asmState, ScopedTimepointCoverage &coverage) : asmState(asmState), coverage(coverage) {} // Calls |fn| for each base resource produced within the analysis scope with @@ -300,8 +306,8 @@ struct LastUseSet { } }; -// Returns the last defined SSA value in the block in |timepoints|. -// All timepoints must be in the same block. +// Returns the last defined SSA value in the block in |timepoints| (textual +// order within the block). All timepoints must be in the same block. static Value getLastTimepointInBlock(TimepointSet &timepoints) { if (timepoints.empty()) { return nullptr; @@ -348,6 +354,641 @@ static StringRef getFuncName(FunctionOpInterface funcOp) { } } +// Conservatively marks all resources touched by an operation and its nested +// regions as indeterminate. Used as fallback for control flow we cannot +// analyze precisely. +static void markAllResourcesIndeterminateInOpAndRegions( + Operation &op, LastUseSet &lastUseSet, + DenseSet &indeterminateResources) { + // Mark all currently tracked resources. + lastUseSet.forEachResource([&](Value resource, TimepointSet &timepoints) { + indeterminateResources.insert(resource); + }); + + // Walk nested operations to find all resources used within regions. + op.walk([&](Operation *nestedOp) { + for (auto operand : nestedOp->getOperands()) { + if (isa(operand.getType())) { + Value baseResource = lastUseSet.lookupResource(operand); + indeterminateResources.insert(baseResource); + } + } + for (auto result : nestedOp->getResults()) { + if (isa(result.getType())) { + Value baseResource = lastUseSet.lookupResource(result); + indeterminateResources.insert(baseResource); + } + } + }); +} + +// Forward declarations for mutual recursion. +static bool analyzeRegionBranchOp(RegionBranchOpInterface regionBranchOp, + AsmState *asmState, LastUseSet &lastUseSet, + ScopedTimepointCoverage &coverage, + DenseSet &indeterminateResources, + DenseSet &handledResources); + +// Analyzes operations in a block, handling both timeline ops and nested +// control flow recursively. +static bool analyzeBlockOps(Block &block, AsmState *asmState, + LastUseSet &lastUseSet, + ScopedTimepointCoverage &coverage, + DenseSet &indeterminateResources, + DenseSet &handledResources) { + for (Operation &op : block) { + // Special case ops that are not timeline-aware but interoperate. + if (auto immediateOp = dyn_cast(op)) { + coverage.add(immediateOp.getResultTimepoint()); + continue; + } else if (auto importOp = dyn_cast(op)) { + coverage.add(importOp.getResultTimepoint()); + continue; + } + + // Handle resource lifetime management ops. + if (auto retainOp = dyn_cast(op)) { + handledResources.insert(lastUseSet.lookupResource(retainOp.getOperand())); + continue; + } else if (auto releaseOp = dyn_cast(op)) { + handledResources.insert( + lastUseSet.lookupResource(releaseOp.getOperand())); + continue; + } + + // Timeline ops are handled via the standard analysis. + // IMPORTANT: Check TimelineOpInterface BEFORE RegionBranchOpInterface + // because stream.cmd.execute implements both, and should be handled as a + // timeline op. + auto timelineOp = dyn_cast(op); + if (!timelineOp) { + // Handle structured control flow recursively (scf.for, scf.if, etc.). + if (auto regionBranchOp = dyn_cast(op)) { + if (!analyzeRegionBranchOp(regionBranchOp, asmState, lastUseSet, + coverage, indeterminateResources, + handledResources)) { + // Unknown control flow - fallback to conservative marking. + LLVM_DEBUG( + llvm::dbgs() + << "[arc] failed to analyze nested RegionBranchOpInterface " + << op.getName() << "; using conservative fallback\n"); + markAllResourcesIndeterminateInOpAndRegions(op, lastUseSet, + indeterminateResources); + } + continue; + } + + // Non-timeline, non-control-flow ops: mark resources indeterminate. + for (auto operand : op.getOperands()) { + if (isa(operand.getType())) { + indeterminateResources.insert(lastUseSet.lookupResource(operand)); + } + } + for (auto result : op.getResults()) { + if (isa(result.getType())) { + indeterminateResources.insert(lastUseSet.lookupResource(result)); + } + } + continue; + } + + // Process timeline op. + Value resultTimepoint = timelineOp.getResultTimepoint(); + if (!resultTimepoint) { + continue; + } + + // Populate coverage. + auto awaitTimepoints = timelineOp.getAwaitTimepoints(); + if (awaitTimepoints.empty()) { + coverage.add(resultTimepoint); + } else { + for (Value awaitTimepoint : awaitTimepoints) { + coverage.add(awaitTimepoint, resultTimepoint); + } + } + + // Check for explicitly indeterminate allocas. + if (auto allocaOp = dyn_cast(op)) { + if (allocaOp.getIndeterminateLifetime()) { + indeterminateResources.insert(allocaOp.getResult()); + } + } + + // Track existing deallocations. + if (auto deallocaOp = dyn_cast(op)) { + handledResources.insert( + lastUseSet.lookupResource(deallocaOp.getOperand())); + } + + // Track resource consumption/production. + auto tiedOp = dyn_cast(op); + for (auto operand : op.getOperands()) { + if (isa(operand.getType())) { + lastUseSet.consume(operand, resultTimepoint); + } + } + for (auto result : op.getResults()) { + if (isa(result.getType())) { + Value operand = tiedOp ? tiedOp.getTiedResultOperand(result) : nullptr; + if (operand) { + lastUseSet.tie(operand, result, resultTimepoint); + } else { + lastUseSet.produce(result, resultTimepoint); + // Mark non-alloca producers as indeterminate (#20817). + if (!isa(op)) { + indeterminateResources.insert(result); + } + } + } + } + } + return true; +} + +// Inserts deallocations for all resources tracked in the LastUseSet. +// This is called after analysis is complete to insert deallocation operations. +static void insertDeallocations(LastUseSet &lastUseSet, AsmState *asmState, + DenseSet &indeterminateResources, + DenseSet &handledResources) { + // Insert deallocations for all resources that we successfully analyzed. + lastUseSet.forEachResource([&](Value resource, TimepointSet &timepoints) { + assert(!timepoints.empty() && "all resources should have a timepoint"); + + // Skip anything we could not analyze. + Value baseResource = lastUseSet.lookupResource(resource); + if (indeterminateResources.contains(baseResource)) { + LLVM_DEBUG({ + llvm::dbgs() << "[arc] skipping resource "; + baseResource.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " marked as indeterminate\n"; + }); + return; + } else if (handledResources.contains(baseResource)) { + LLVM_DEBUG({ + llvm::dbgs() << "[arc] skipping resource "; + baseResource.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " marked as already handled\n"; + }); + return; + } + + // Finds the last timepoint in the set (if >1) in SSA dominance order. + Value lastTimepoint = getLastTimepointInBlock(timepoints); + assert(lastTimepoint && "must have at least one timepoint"); + OpBuilder builder(lastTimepoint.getContext()); + builder.setInsertionPointAfterValue(lastTimepoint); + + // Try to grab a resource size or insert a query. + // In almost all cases that this analysis can run we will have a + // size-aware op that provides it. + auto timepointsLoc = + getFusedLocFromTimepoints(resource.getLoc(), timepoints); + Value resourceSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( + timepointsLoc, resource, builder); + + // Lookup the affinity of the resource. + // This should likely be a global affinity analysis but since we are + // currently only processing locally we can assume this is only used when + // we have an op local with an affinity assigned. + IREE::Stream::AffinityAttr resourceAffinity; + if (auto *definingOp = resource.getDefiningOp()) { + resourceAffinity = IREE::Stream::AffinityAttr::lookup(definingOp); + } + UnitAttr preferOrigin = + resourceAffinity ? UnitAttr{} : builder.getUnitAttr(); + + if (timepoints.size() == 1) { + // Single last user; the resource can have a deallocation directly + // inserted as we have tracked both allocation and now deallocation to + // single code points. + LLVM_DEBUG({ + llvm::dbgs() << "[arc] inserting deallocation for "; + resource.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " after timepoint "; + lastTimepoint.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " directly\n"; + }); + auto deallocaOp = IREE::Stream::ResourceDeallocaOp::create( + builder, timepointsLoc, + builder.getType(), resource, + resourceSize, preferOrigin, lastTimepoint, resourceAffinity); + lastTimepoint.replaceAllUsesExcept(deallocaOp.getResultTimepoint(), + deallocaOp); + } else if (timepoints.size() > 1) { + // Multiple last users (fork); the resource still has a tracked + // allocation and deallocation but there are multiple code points where + // the deallocation may need to be inserted. + // + // Since this current analysis is local we can rely on SSA dominance to + // find the last SSA value and insert a join on all timepoints there to + // perform the deallocation, though this will cause extended lifetimes + // in cases where scheduled timeline operations complete out of order. + // We won't have correctness issues as all timepoints will be waited on. + LLVM_DEBUG({ + llvm::dbgs() << "[arc] inserting forked deallocation for "; + resource.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " after last SSA timepoint "; + lastTimepoint.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " as a join\n"; + }); + auto joinOp = IREE::Stream::TimepointJoinOp::create( + builder, timepointsLoc, + builder.getType(), + llvm::map_to_vector(timepoints, + [](Value timepoint) { return timepoint; })); + auto deallocaOp = IREE::Stream::ResourceDeallocaOp::create( + builder, timepointsLoc, + builder.getType(), resource, + resourceSize, preferOrigin, joinOp.getResultTimepoint(), + resourceAffinity); + lastTimepoint.replaceAllUsesExcept(deallocaOp.getResultTimepoint(), + joinOp); + } + }); +} + +// Collects all timepoint results from an operation and creates a join if there +// are multiple. Returns the single timepoint or joined timepoint, or nullopt if +// no timepoint results exist. +static std::optional getOrJoinTimepointResults(Operation *op) { + SmallVector resultTimepoints; + for (Value result : op->getResults()) { + if (isa(result.getType())) { + resultTimepoints.push_back(result); + } + } + + if (resultTimepoints.empty()) { + return std::nullopt; + } + + if (resultTimepoints.size() == 1) { + return resultTimepoints[0]; + } + + // Multiple timepoint results - create a join. + OpBuilder builder(op->getContext()); + builder.setInsertionPointAfter(op); + auto joinOp = IREE::Stream::TimepointJoinOp::create( + builder, op->getLoc(), builder.getType(), + resultTimepoints); + LLVM_DEBUG({ + llvm::dbgs() << "[arc] created join of " << resultTimepoints.size() + << " timepoint results\n"; + }); + return joinOp.getResultTimepoint(); +} + +// Extends the lifetime of captured resources (defined above a region but used +// within it) to a specified result timepoint. +// +// NOTE: This is a conservative over-approximation. Resources captured in a +// control flow region have their lifetimes extended to the region's result +// timepoint even if they are only used in one branch (scf.if) or only in the +// first iteration (scf.for). More precise per-iteration/per-branch tracking +// would require backward dataflow analysis which is not implemented. +static void extendCapturedResourceLifetimes(Region ®ion, + Value resultTimepoint, + LastUseSet &lastUseSet, + AsmState *asmState) { + SetVector capturedValues; + getUsedValuesDefinedAbove(region, region, capturedValues); + for (Value captured : capturedValues) { + if (isa(captured.getType())) { + lastUseSet.consume(captured, resultTimepoint); + LLVM_DEBUG({ + llvm::dbgs() + << "[arc] captured resource lifetime extended to result\n"; + }); + } + } +} + +// Analyzes scf.for loop with captured resource tracking. +static bool analyzeForLoop(scf::ForOp forOp, AsmState *asmState, + LastUseSet &lastUseSet, + ScopedTimepointCoverage &coverage, + DenseSet &indeterminateResources, + DenseSet &handledResources) { + // scf.for doesn't implement TimelineOpInterface, but its result may be a + // timepoint. + std::optional loopResultTimepointOpt = + getOrJoinTimepointResults(forOp); + if (!loopResultTimepointOpt) { + // No timepoint result, cannot track lifetimes through this loop. + return false; + } + Value loopResultTimepoint = *loopResultTimepointOpt; + + // Register the loop result timepoint in the coverage map so that subsequent + // consume() calls can recognize and prune timepoints dominated by this one. + coverage.add(loopResultTimepoint); + + LLVM_DEBUG( + { llvm::dbgs() << "[arc] recognized scf.for with timepoint result\n"; }); + + // Step 1: Find captured resources (defined outside, used inside). + // These need their lifetimes extended to the loop result in the parent block. + extendCapturedResourceLifetimes(forOp.getRegion(), loopResultTimepoint, + lastUseSet, asmState); + + // Step 2: Recursively analyze the loop body for local allocations. + // Resources allocated and used entirely within the loop body should be + // deallocated inside the loop body, UNLESS they are yielded out. + for (Block &block : forOp.getRegion()) { + // Create a fresh LastUseSet for analysis within this block. + // This allows us to track and deallocate resources local to the loop body. + LastUseSet blockLastUseSet(asmState, coverage); + + LLVM_DEBUG(llvm::dbgs() << "[arc] analyzing loop body block\n"); + + // Recursively analyze operations in this block. + if (!analyzeBlockOps(block, asmState, blockLastUseSet, coverage, + indeterminateResources, handledResources)) { + LLVM_DEBUG(llvm::dbgs() << "[arc] failed to analyze loop body block\n"); + return false; + } + + // Check if any resources are yielded out of this block. + // Resources that escape via yield should NOT be deallocated locally. + // Instead, they must be registered in the parent scope so deallocations + // happen after the SCF operation completes. + auto *terminator = block.getTerminator(); + if (auto yieldOp = dyn_cast(terminator)) { + for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands())) { + if (isa(operand.getType())) { + // Mark as handled to prevent local deallocation. + handledResources.insert(operand); + LLVM_DEBUG({ + llvm::dbgs() << "[arc] resource "; + operand.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() + << " yielded from loop body; preventing local deallocation\n"; + }); + + // Register the corresponding scf.for result in the parent scope. + // This ensures the resource gets deallocated after the loop even if + // the result is not used (dropped). + Value forResult = forOp.getResult(index); + lastUseSet.produce(forResult, loopResultTimepoint); + LLVM_DEBUG({ + llvm::dbgs() << "[arc] registered loop result "; + forResult.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " in parent scope for deallocation\n"; + }); + } + } + } + + // Insert deallocations for resources local to this block (excluding yielded + // ones). + insertDeallocations(blockLastUseSet, asmState, indeterminateResources, + handledResources); + } + + return true; +} + +// Analyzes scf.if conditional with captured resource tracking. +static bool analyzeIfOp(scf::IfOp ifOp, AsmState *asmState, + LastUseSet &lastUseSet, + ScopedTimepointCoverage &coverage, + DenseSet &indeterminateResources, + DenseSet &handledResources) { + // scf.if doesn't implement TimelineOpInterface, but its result may be a + // timepoint. + std::optional ifResultTpOpt = getOrJoinTimepointResults(ifOp); + if (!ifResultTpOpt) { + // No timepoint result, cannot track lifetimes through this conditional. + return false; + } + Value ifResultTp = *ifResultTpOpt; + + // Register the if result timepoint in the coverage map. + coverage.add(ifResultTp); + + LLVM_DEBUG( + { llvm::dbgs() << "[arc] recognized scf.if with timepoint result\n"; }); + + // Step 1: Find captured resources in both branches. + // These need their lifetimes extended to the if result in the parent block. + extendCapturedResourceLifetimes(ifOp.getThenRegion(), ifResultTp, lastUseSet, + asmState); + if (!ifOp.getElseRegion().empty()) { + extendCapturedResourceLifetimes(ifOp.getElseRegion(), ifResultTp, + lastUseSet, asmState); + } + + // Step 2: Recursively analyze each branch for local allocations. + // Resources allocated and used entirely within a branch should be + // deallocated inside that branch, UNLESS they are yielded out. + // + // Track which if results have been registered to avoid duplicates. + // Both branches may yield resources that map to the same result index. + DenseSet registeredIfResults; + auto analyzeRegion = [&](Region ®ion) -> bool { + for (Block &block : region) { + // Create a fresh LastUseSet for analysis within this block. + LastUseSet blockLastUseSet(asmState, coverage); + + LLVM_DEBUG(llvm::dbgs() + << "[arc] analyzing conditional branch block\n"); + + // Recursively analyze operations in this block. + if (!analyzeBlockOps(block, asmState, blockLastUseSet, coverage, + indeterminateResources, handledResources)) { + LLVM_DEBUG(llvm::dbgs() + << "[arc] failed to analyze conditional branch block\n"); + return false; + } + + // Check if any resources are yielded out of this block. + // Resources that escape via yield should NOT be deallocated locally. + // Instead, they must be registered in the parent scope so deallocations + // happen after the SCF operation completes. + auto *terminator = block.getTerminator(); + if (auto yieldOp = dyn_cast(terminator)) { + for (auto [index, operand] : llvm::enumerate(yieldOp.getOperands())) { + if (isa(operand.getType())) { + // Mark as handled to prevent local deallocation. + handledResources.insert(operand); + LLVM_DEBUG({ + llvm::dbgs() << "[arc] resource "; + operand.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " yielded from conditional branch; preventing " + "local deallocation\n"; + }); + + // Register the corresponding scf.if result in the parent scope. + // Skip if already registered (e.g., from the other branch). + Value ifResult = ifOp.getResult(index); + if (!registeredIfResults.contains(ifResult)) { + registeredIfResults.insert(ifResult); + lastUseSet.produce(ifResult, ifResultTp); + LLVM_DEBUG({ + llvm::dbgs() << "[arc] registered if result "; + ifResult.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " in parent scope for deallocation\n"; + }); + } + } + } + } + + // Insert deallocations for resources local to this block (excluding + // yielded ones). + insertDeallocations(blockLastUseSet, asmState, indeterminateResources, + handledResources); + } + return true; + }; + + if (!analyzeRegion(ifOp.getThenRegion())) { + return false; + } + if (!ifOp.getElseRegion().empty() && !analyzeRegion(ifOp.getElseRegion())) { + return false; + } + + return true; +} + +// Analyzes scf.while loop with captured resource tracking. +static bool analyzeWhileOp(scf::WhileOp whileOp, AsmState *asmState, + LastUseSet &lastUseSet, + ScopedTimepointCoverage &coverage, + DenseSet &indeterminateResources, + DenseSet &handledResources) { + // scf.while has two regions: "before" (condition) and "after" (loop body). + std::optional whileResultTimepointOpt = + getOrJoinTimepointResults(whileOp); + if (!whileResultTimepointOpt) { + // No timepoint result, cannot track lifetimes through this loop. + return false; + } + Value whileResultTimepoint = *whileResultTimepointOpt; + + // Register the while result timepoint in the coverage map. + coverage.add(whileResultTimepoint); + + LLVM_DEBUG({ + llvm::dbgs() << "[arc] recognized scf.while with timepoint result\n"; + }); + + // Step 1: Find captured resources (defined outside, used inside either + // region). + extendCapturedResourceLifetimes(whileOp.getBefore(), whileResultTimepoint, + lastUseSet, asmState); + extendCapturedResourceLifetimes(whileOp.getAfter(), whileResultTimepoint, + lastUseSet, asmState); + + // Step 2: Recursively analyze both regions for local allocations. + // For scf.while, we need to analyze both "before" and "after" regions. + auto analyzeRegion = [&](Region ®ion) -> bool { + for (Block &block : region) { + LastUseSet blockLastUseSet(asmState, coverage); + + LLVM_DEBUG(llvm::dbgs() << "[arc] analyzing while loop region block\n"); + + if (!analyzeBlockOps(block, asmState, blockLastUseSet, coverage, + indeterminateResources, handledResources)) { + LLVM_DEBUG(llvm::dbgs() + << "[arc] failed to analyze while loop block\n"); + return false; + } + + // Check for yielded resources. + // Note: scf.while has TWO yield-like operations: + // - scf.condition in "before" region: args become while results + // - scf.yield in "after" region: args go back to "before" block args + auto *terminator = block.getTerminator(); + if (auto conditionOp = dyn_cast(terminator)) { + // "before" region ends with scf.condition - args become while results. + for (auto [index, operand] : llvm::enumerate(conditionOp.getArgs())) { + if (isa(operand.getType())) { + handledResources.insert(operand); + LLVM_DEBUG({ + llvm::dbgs() << "[arc] resource "; + operand.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " yielded via scf.condition; " + "preventing local deallocation\n"; + }); + + // Register the corresponding scf.while result in the parent scope. + Value whileResult = whileOp.getResult(index); + lastUseSet.produce(whileResult, whileResultTimepoint); + LLVM_DEBUG({ + llvm::dbgs() << "[arc] registered while result "; + whileResult.printAsOperand(llvm::dbgs(), *asmState); + llvm::dbgs() << " in parent scope for deallocation\n"; + }); + } + } + } else if (auto yieldOp = dyn_cast(terminator)) { + // "after" region ends with scf.yield. + // These values go back to the "before" region, NOT to while results. + // Mark as handled to prevent local deallocation (no parent produce). + for (Value operand : yieldOp.getOperands()) { + if (isa(operand.getType())) { + handledResources.insert(operand); + LLVM_DEBUG({ + llvm::dbgs() << "[arc] resource yielded from while body back " + "to loop; preventing local deallocation\n"; + }); + } + } + } + + insertDeallocations(blockLastUseSet, asmState, indeterminateResources, + handledResources); + } + return true; + }; + + if (!analyzeRegion(whileOp.getBefore())) { + return false; + } + if (!analyzeRegion(whileOp.getAfter())) { + return false; + } + + return true; +} + +// Dispatches to pattern-specific handlers for RegionBranchOpInterface ops. +static bool analyzeRegionBranchOp(RegionBranchOpInterface regionBranchOp, + AsmState *asmState, LastUseSet &lastUseSet, + ScopedTimepointCoverage &coverage, + DenseSet &indeterminateResources, + DenseSet &handledResources) { + Operation *op = regionBranchOp.getOperation(); + + // Dispatch to pattern-specific handlers using TypeSwitch. + return llvm::TypeSwitch(op) + .Case([&](scf::ForOp forOp) { + return analyzeForLoop(forOp, asmState, lastUseSet, coverage, + indeterminateResources, handledResources); + }) + .Case([&](scf::IfOp ifOp) { + return analyzeIfOp(ifOp, asmState, lastUseSet, coverage, + indeterminateResources, handledResources); + }) + .Case([&](scf::WhileOp whileOp) { + return analyzeWhileOp(whileOp, asmState, lastUseSet, coverage, + indeterminateResources, handledResources); + }) + .Default([](Operation *) { + // Unknown RegionBranchOpInterface - cannot analyze. + // This includes scf.parallel and scf.reduce which we don't use in IREE + // today. If needed in the future, add explicit handlers for them. + // Returning false triggers the conservative fallback in analyzeBlockOps + // which marks all resources in the operation as indeterminate. + // TODO(#12345): Add handlers for scf.parallel/scf.reduce if needed. + return false; + }); +} + static void processFuncOp(FunctionOpInterface funcOp) { // Today we bail on unstructured control flow. Eventually we want to be able // to only support structured control flow in stream but today it can @@ -374,7 +1015,7 @@ static void processFuncOp(FunctionOpInterface funcOp) { DenseSet handledResources; for (auto [blockIndex, block] : llvm::enumerate(funcOp.getBlocks())) { LLVM_DEBUG(llvm::dbgs() << "[arc] processing ^bb" << blockIndex << ":\n"); - LocalTimepointCoverage coverage(asmState.get()); + ScopedTimepointCoverage coverage(asmState.get()); LastUseSet lastUseSet(asmState.get(), coverage); // Add timepoint arguments as valid predecessors. @@ -410,291 +1051,34 @@ static void processFuncOp(FunctionOpInterface funcOp) { } } + // Check for CallOpInterface before analyzing - we cannot handle calls yet. + // TODO(benvanik): global analysis is required to know if calls are + // side-effecting. We could annotate the calls as LLVM does so that we + // could do local analysis and only pay attention to the operands/results. + // util.call supports tied operands but does not have a way to associate + // timepoints and this pass may never be able to work without that + // information. The most common case of calls today is in external modules + // not using `stream.cmd.call` and those are rare. for (auto &op : block) { - // TODO(benvanik): global analysis is required to know if calls are - // side-effecting. We could annotate the calls as LLVM does so that we - // could do local analysis and only pay attention to the operands/results. - // util.call supports tied operands but does not have a way to associate - // timepoints and this pass may never be able to work without that - // information. The most common case of calls today is in external modules - // not using `stream.cmd.call` and those are rare. if (isa(op)) { LLVM_DEBUG(llvm::dbgs() << "[arc] skipping function @" << getFuncName(funcOp) << " as it contains call ops (" << op.getName() << ")\n"); return; } + } - // TODO(benvanik): broader analysis or at least constrained handling of - // scf. For now we bail if any ops from control flow dialects are found. - // TODO(benvanik): maybe only skip the block containing the ops. - if (op.getDialect()->getNamespace() == "scf") { - LLVM_DEBUG(llvm::dbgs() - << "[arc] skipping function @" << getFuncName(funcOp) - << " block ^bb" << blockIndex - << " as it contains control flow ops (" << op.getName() - << ")\n"); - return; - } - - // Special case ops that are not timeline-aware but interoperate. - if (auto immediateOp = dyn_cast(op)) { - coverage.add(immediateOp.getResultTimepoint()); - continue; - } else if (auto importOp = - dyn_cast(op)) { - coverage.add(importOp.getResultTimepoint()); - continue; - } - - // Bail on processing any particular resource which has lifetime - // management ops. They should not have been inserted yet and their - // presence likely indicates we've already run the pass on this input. - // We continue analysis for subsequent ops that may use different - // resources so that we can handle other resources. - if (auto retainOp = dyn_cast(op)) { - Value handledResource = - lastUseSet.lookupResource(retainOp.getOperand()); - LLVM_DEBUG({ - llvm::dbgs() << "[arc] existing retain of "; - handledResource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << "; marking as handled\n"; - }); - handledResources.insert(handledResource); - continue; - } else if (auto releaseOp = - dyn_cast(op)) { - Value handledResource = - lastUseSet.lookupResource(releaseOp.getOperand()); - LLVM_DEBUG({ - llvm::dbgs() << "[arc] existing release of "; - handledResource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << "; marking as handled\n"; - }); - handledResources.insert(handledResource); - continue; - } - - // Analysis currently only works if all ops producing and consuming - // resources are timeline ops. Any resource accessed by non-timeline ops - // gets marked as indeterminate as the analysis does not know how they are - // used on the timeline. - auto timelineOp = dyn_cast(op); - if (!timelineOp) { - for (auto operand : op.getOperands()) { - if (isa(operand.getType())) { - Value baseResource = lastUseSet.lookupResource(operand); - LLVM_DEBUG({ - llvm::dbgs() << "[arc] non-timeline use of operand "; - baseResource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << "; marking as indeterminate\n"; - }); - indeterminateResources.insert(baseResource); - } - } - for (auto result : op.getResults()) { - if (isa(result.getType())) { - Value baseResource = lastUseSet.lookupResource(result); - LLVM_DEBUG({ - llvm::dbgs() << "[arc] non-timeline use of result "; - baseResource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << "; marking as indeterminate\n"; - }); - indeterminateResources.insert(baseResource); - } - } - continue; - } - - // Consumer-only timepoint ops (like stream.timepoint.await) block - // propagation. - Value resultTimepoint = timelineOp.getResultTimepoint(); - if (!resultTimepoint) { - LLVM_DEBUG({ - llvm::dbgs() << "[arc] terminating timeline use by " << op.getName() - << "; stopping propagation\n"; - }); - continue; - } - - // Populate coverage map for the declared timeline operation. - auto awaitTimepoints = timelineOp.getAwaitTimepoints(); - if (awaitTimepoints.empty()) { - coverage.add(resultTimepoint); - } else { - for (Value awaitTimepoint : timelineOp.getAwaitTimepoints()) { - coverage.add(awaitTimepoint, resultTimepoint); - } - } - - // Alloca ops may have been assigned as indeterminate when created. - if (auto allocaOp = dyn_cast(op)) { - if (allocaOp.getIndeterminateLifetime()) { - Value allocaResource = allocaOp.getResult(); - LLVM_DEBUG({ - llvm::dbgs() << "[arc] alloca producer explicitly states lifetime " - "is indeterminate for "; - allocaResource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << "; marking as indeterminate\n"; - }); - indeterminateResources.insert(allocaResource); - } - } - - // If a resource has a deallocation on it already then we cannot insert - // another. This can arise when the pass is run twice or when an earlier - // pass explicitly inserts deallocations to ensure they happen where they - // want instead of relying on this analysis. - if (auto deallocaOp = dyn_cast(op)) { - Value handledResource = - lastUseSet.lookupResource(deallocaOp.getOperand()); - LLVM_DEBUG({ - llvm::dbgs() << "[arc] existing deallocation of "; - handledResource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << "; marking as handled\n"; - }); - handledResources.insert(handledResource); - } - - // Track resources consumed/produced as part of the timeline operation. - // This gives us the last timepoint(s) using the resources (not the last - // SSA user). There may be multiple last timepoints if a fork occurs. - // - // Example: - // %r0, %t0 = alloca - // %t1 = exec %t0 => %r0 - // %t2 = exec %t0 => %r0 - // The last timepoints of %r0 would be %t1 and %t2, indicating that after - // %t1 and %t2 have both been reached (joined) the resource is no longer - // live. - auto tiedOp = dyn_cast(op); - for (auto operand : op.getOperands()) { - if (isa(operand.getType())) { - lastUseSet.consume(operand, timelineOp.getResultTimepoint()); - } - } - for (auto result : op.getResults()) { - if (isa(result.getType())) { - Value operand = - tiedOp ? tiedOp.getTiedResultOperand(result) : nullptr; - if (operand) { - lastUseSet.tie(operand, result, timelineOp.getResultTimepoint()); - } else { - lastUseSet.produce(result, timelineOp.getResultTimepoint()); - - // TODO(#20817): parameter loads (and potentially other ops) may - // cause far too many joins right now and will hit runtime errors - // and performance issues. Until we have a way to partition joins we - // need to avoid those. Parameter loads and other sources are things - // we likely want to handle with retain/release instead of - // deallocating anyway. Custom calls will also need a way to - // indicate whether they are alloca-like or not so we just exclude - // everything except alloca here. - if (!isa(op)) { - LLVM_DEBUG({ - llvm::dbgs() << "[arc] non-alloca producer " - << op.getName().getStringRef() << " of "; - result.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << "; marking as indeterminate (#20817)\n"; - }); - indeterminateResources.insert(result); - } - } - } - } + // Delegate block analysis to the shared helper. + if (!analyzeBlockOps(block, asmState.get(), lastUseSet, coverage, + indeterminateResources, handledResources)) { + LLVM_DEBUG(llvm::dbgs() << "[arc] failed to analyze function block\n"); + // The existing indeterminateResources/handledResources will prevent + // deallocations for unanalyzable parts. } // Insert deallocations for all resources that we successfully analyzed. - lastUseSet.forEachResource([&](Value resource, TimepointSet &timepoints) { - assert(!timepoints.empty() && "all resources should have a timepoint"); - - // Skip anything we could not analyze. - Value baseResource = lastUseSet.lookupResource(resource); - if (indeterminateResources.contains(baseResource)) { - LLVM_DEBUG({ - llvm::dbgs() << "[arc] skipping resource "; - baseResource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << " marked as indeterminate\n"; - }); - return; - } else if (handledResources.contains(baseResource)) { - LLVM_DEBUG({ - llvm::dbgs() << "[arc] skipping resource "; - baseResource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << " marked as already handled\n"; - }); - return; - } - - // Finds the last timepoint in the set (if >1) in SSA dominance order. - Value lastTimepoint = getLastTimepointInBlock(timepoints); - assert(lastTimepoint && "must have at least one timepoint"); - OpBuilder builder(lastTimepoint.getContext()); - builder.setInsertionPointAfterValue(lastTimepoint); - - // Try to grab a resource size or insert a query. - // In almost all cases that this analysis can run we will have a - // size-aware op that provides it. - auto timepointsLoc = - getFusedLocFromTimepoints(resource.getLoc(), timepoints); - Value resourceSize = IREE::Util::SizeAwareTypeInterface::queryValueSize( - timepointsLoc, resource, builder); - - // Lookup the affinity of the resource. - // This should likely be a global affinity analysis but since we are - // currently only processing locally we can assume this is only used when - // we have an op local with an affinity assigned. - IREE::Stream::AffinityAttr resourceAffinity; - if (auto *definingOp = resource.getDefiningOp()) { - resourceAffinity = IREE::Stream::AffinityAttr::lookup(definingOp); - } - UnitAttr preferOrigin = - resourceAffinity ? UnitAttr{} : builder.getUnitAttr(); - - if (timepoints.size() == 1) { - // Single last user; the resource can have a deallocation directly - // inserted as we have tracked both allocation and now deallocation to - // single code points. - LLVM_DEBUG({ - llvm::dbgs() << "[arc] inserting deallocation for "; - resource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << " after timepoint "; - lastTimepoint.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << " directly\n"; - }); - auto deallocaOp = IREE::Stream::ResourceDeallocaOp::create( - builder, timepointsLoc, - builder.getType(), resource, - resourceSize, preferOrigin, lastTimepoint, resourceAffinity); - lastTimepoint.replaceAllUsesExcept(deallocaOp.getResultTimepoint(), - deallocaOp); - } else if (timepoints.size() > 1) { - // Multiple last users (fork); the resource still has a tracked - // allocation and deallocation but there are multiple code points where - // the deallocation may need to be inserted. - // - // Since this current analysis is local we can rely on SSA dominance to - // find the last SSA value and insert a join on all timepoints there to - // perform the deallocation, though this will cause extended lifetimes - // in cases where scheduled timeline operations complete out of order. - // We won't have correctness issues as all timepoints will be waited on. - LLVM_DEBUG({ - llvm::dbgs() << "[arc] inserting forked deallocation for "; - resource.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << " after last SSA timepoint "; - lastTimepoint.printAsOperand(llvm::dbgs(), *asmState); - llvm::dbgs() << " as a join\n"; - }); - Value joinedTimepoint = IREE::Stream::TimepointJoinOp::join( - timepointsLoc, llvm::to_vector(timepoints), builder); - auto deallocaOp = IREE::Stream::ResourceDeallocaOp::create( - builder, timepointsLoc, - builder.getType(), resource, - resourceSize, preferOrigin, joinedTimepoint, resourceAffinity); - lastTimepoint.replaceAllUsesExcept(deallocaOp.getResultTimepoint(), - joinedTimepoint.getDefiningOp()); - } - }); + insertDeallocations(lastUseSet, asmState.get(), indeterminateResources, + handledResources); } LLVM_DEBUG(llvm::dbgs() << "\n"); diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel index 40f34e3119b6..c36bb57955ce 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/BUILD.bazel @@ -21,6 +21,7 @@ iree_lit_test_suite( "annotate_dispatch_arguments.mlir", "annotate_dispatch_assumptions.mlir", "automatic_reference_counting.mlir", + "automatic_reference_counting_scf.mlir", "clone_to_consumers.mlir", "convert_to_stream.mlir", "dump_statistics.mlir", diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt index d6f192d59718..09c45c587c0a 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/CMakeLists.txt @@ -19,6 +19,7 @@ iree_lit_test_suite( "annotate_dispatch_arguments.mlir" "annotate_dispatch_assumptions.mlir" "automatic_reference_counting.mlir" + "automatic_reference_counting_scf.mlir" "clone_to_consumers.mlir" "convert_to_stream.mlir" "dump_statistics.mlir" diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/automatic_reference_counting.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/automatic_reference_counting.mlir index 10799cfd5f08..2215d9b4b8d7 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/automatic_reference_counting.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/automatic_reference_counting.mlir @@ -31,9 +31,9 @@ util.func private @insertDeallocaWithAffinity(%input_timepoint: !stream.timepoin util.func private @insertDeallocaOneUserOneUse(%input_timepoint: !stream.timepoint, %size: index) -> !stream.timepoint { // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca %resource, %alloca_timepoint = stream.resource.alloca uninitialized await(%input_timepoint) => !stream.resource{%size} => !stream.timepoint - // CHECK: %[[EXECUTE_TIMEPOINT:.+]] = stream.cmd.execute - %execute_timepoint = stream.cmd.execute await(%alloca_timepoint) => with(%resource as %capture : !stream.resource{%size}) { - } => !stream.timepoint + // CHECK: %[[EXECUTE_TIMEPOINT:.+]] = stream.test.timeline_op + %execute_timepoint = stream.test.timeline_op await(%alloca_timepoint) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca origin await(%[[EXECUTE_TIMEPOINT]]) => %[[RESOURCE]] // CHECK: util.return %[[DEALLOCA_TIMEPOINT]] util.return %execute_timepoint : !stream.timepoint @@ -49,9 +49,9 @@ util.func private @insertDeallocaOneUserOneUse(%input_timepoint: !stream.timepoi util.func private @insertDeallocaOneUserMultiUse(%input_timepoint: !stream.timepoint, %size: index) -> !stream.timepoint { // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca %resource, %alloca_timepoint = stream.resource.alloca uninitialized await(%input_timepoint) => !stream.resource{%size} => !stream.timepoint - // CHECK: %[[EXECUTE_TIMEPOINT:.+]] = stream.cmd.execute - %execute_timepoint = stream.cmd.execute await(%alloca_timepoint) => with(%resource as %capture0 : !stream.resource{%size}, %resource as %capture1 : !stream.resource{%size}) { - } => !stream.timepoint + // CHECK: %[[EXECUTE_TIMEPOINT:.+]] = stream.test.timeline_op + %execute_timepoint = stream.test.timeline_op await(%alloca_timepoint) => + with(%resource, %resource) : (!stream.resource{%size}, !stream.resource{%size}) -> () => !stream.timepoint // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca origin await(%[[EXECUTE_TIMEPOINT]]) => %[[RESOURCE]] // CHECK-NOT: stream.resource.dealloca // CHECK: util.return %[[DEALLOCA_TIMEPOINT]] @@ -67,12 +67,12 @@ util.func private @insertDeallocaOneUserMultiUse(%input_timepoint: !stream.timep util.func private @insertDeallocaMultiUserSequence(%input_timepoint: !stream.timepoint, %size: index) -> !stream.timepoint { // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca %resource, %alloca_timepoint = stream.resource.alloca uninitialized await(%input_timepoint) => !stream.resource{%size} => !stream.timepoint - // CHECK: %[[EXECUTE0_TIMEPOINT:.+]] = stream.cmd.execute - %execute0_timepoint = stream.cmd.execute await(%alloca_timepoint) => with(%resource as %capture : !stream.resource{%size}) { - } => !stream.timepoint - // CHECK: %[[EXECUTE1_TIMEPOINT:.+]] = stream.cmd.execute - %execute1_timepoint = stream.cmd.execute await(%execute0_timepoint) => with(%resource as %capture : !stream.resource{%size}) { - } => !stream.timepoint + // CHECK: %[[EXECUTE0_TIMEPOINT:.+]] = stream.test.timeline_op + %execute0_timepoint = stream.test.timeline_op await(%alloca_timepoint) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + // CHECK: %[[EXECUTE1_TIMEPOINT:.+]] = stream.test.timeline_op + %execute1_timepoint = stream.test.timeline_op await(%execute0_timepoint) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca origin await(%[[EXECUTE1_TIMEPOINT]]) => %[[RESOURCE]] // Note: needs cleanup in ElideTimepointsPass. // CHECK: %[[EXECUTE_JOIN_TIMEPOINT:.+]] = stream.timepoint.join max(%[[EXECUTE0_TIMEPOINT]], %[[DEALLOCA_TIMEPOINT]]) @@ -90,17 +90,17 @@ util.func private @insertDeallocaMultiUserSequence(%input_timepoint: !stream.tim util.func private @insertDeallocaMultiUserFork(%input_timepoint: !stream.timepoint, %size: index) -> (!stream.timepoint, !stream.timepoint) { // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca %resource, %alloca_timepoint = stream.resource.alloca uninitialized await(%input_timepoint) => !stream.resource{%size} => !stream.timepoint - // CHECK: %[[EXECUTE0_TIMEPOINT:.+]] = stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) - %execute0_timepoint = stream.cmd.execute await(%alloca_timepoint) => with(%resource as %capture : !stream.resource{%size}) { - } => !stream.timepoint + // CHECK: %[[EXECUTE0_TIMEPOINT:.+]] = stream.test.timeline_op await(%[[ALLOCA_TIMEPOINT]]) + %execute0_timepoint = stream.test.timeline_op await(%alloca_timepoint) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint // Note: this is here to force another timepoint user earlier than the last // deallocation; this exposes potential SSA ordering issues. - // CHECK: %[[OTHER_TIMEPOINT:.+]] = stream.cmd.execute await(%[[EXECUTE0_TIMEPOINT]]) - %other_timepoint = stream.cmd.execute await(%execute0_timepoint) => with() { - } => !stream.timepoint - // CHECK: %[[EXECUTE1_TIMEPOINT:.+]] = stream.cmd.execute await(%[[ALLOCA_TIMEPOINT]]) - %execute1_timepoint = stream.cmd.execute await(%alloca_timepoint) => with(%resource as %capture : !stream.resource{%size}) { - } => !stream.timepoint + // CHECK: %[[OTHER_TIMEPOINT:.+]] = stream.test.timeline_op await(%[[EXECUTE0_TIMEPOINT]]) + %other_timepoint = stream.test.timeline_op await(%execute0_timepoint) => + with() : () -> () => !stream.timepoint + // CHECK: %[[EXECUTE1_TIMEPOINT:.+]] = stream.test.timeline_op await(%[[ALLOCA_TIMEPOINT]]) + %execute1_timepoint = stream.test.timeline_op await(%alloca_timepoint) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint // CHECK: %[[DEALLOCA_JOIN_TIMEPOINT:.+]] = stream.timepoint.join max(%[[EXECUTE0_TIMEPOINT]], %[[EXECUTE1_TIMEPOINT]]) // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca origin await(%[[DEALLOCA_JOIN_TIMEPOINT]]) => %[[RESOURCE]] // Note: the dealloca adds an additional synchronization point. @@ -117,9 +117,9 @@ util.func private @insertDeallocaMultiUserFork(%input_timepoint: !stream.timepoi util.func private @ignoreHandledResources(%input_timepoint: !stream.timepoint, %size: index) -> !stream.timepoint { // CHECK: stream.resource.alloca %resource, %alloca_timepoint = stream.resource.alloca uninitialized await(%input_timepoint) => !stream.resource{%size} => !stream.timepoint - // CHECK: stream.cmd.execute - %execute_timepoint = stream.cmd.execute await(%alloca_timepoint) => with(%resource as %capture : !stream.resource{%size}) { - } => !stream.timepoint + // CHECK: stream.test.timeline_op + %execute_timepoint = stream.test.timeline_op await(%alloca_timepoint) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint // CHECK: %[[DEALLOCA_TIMEPOINT:.+]] = stream.resource.dealloca // CHECK-NOT: stream.resource.dealloca %dealloca_timepoint = stream.resource.dealloca await(%execute_timepoint) => %resource : !stream.resource{%size} => !stream.timepoint @@ -133,8 +133,8 @@ util.func private @ignoreHandledResources(%input_timepoint: !stream.timepoint, % // CHECK-LABEL: @ignoreLiveIn util.func private @ignoreLiveIn(%input_timepoint: !stream.timepoint, %resource: !stream.resource, %size: index) -> !stream.timepoint { - %execute_timepoint = stream.cmd.execute await(%input_timepoint) => with(%resource as %capture : !stream.resource{%size}) { - } => !stream.timepoint + %execute_timepoint = stream.test.timeline_op await(%input_timepoint) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint // CHECK-NOT: stream.resource.dealloca util.return %execute_timepoint : !stream.timepoint } @@ -238,23 +238,6 @@ util.func private @some_func() -> () { // ----- -// TODO(benvanik): scf is something we should support even in this local pass -// in constrained scenarios (only for resources not used in regions, etc). -// For now using SCF will cause the entire parent block to be skipped. - -// CHECK-LABEL: @ignoreSCF -util.func private @ignoreSCF(%input_timepoint: !stream.timepoint, %size: index) -> !stream.timepoint { - %resource, %alloca_timepoint = stream.resource.alloca uninitialized on(#hal.device.promise<@device>) await(%input_timepoint) => !stream.resource{%size} => !stream.timepoint - // CHECK-NOT: stream.resource.dealloca - %cond = arith.constant 1 : i1 - scf.if %cond { - scf.yield - } - util.return %alloca_timepoint : !stream.timepoint -} - -// ----- - // Tests that resources loaded from globals are treated as indeterminate. util.global private @resource : !stream.resource @@ -264,9 +247,9 @@ util.global private @timepoint : !stream.timepoint util.func private @ignoreGlobalLoad(%size: index) -> !stream.timepoint { %resource = util.global.load @resource : !stream.resource %load_timepoint = util.global.load @timepoint : !stream.timepoint - // CHECK: stream.cmd.execute - %execute_timepoint = stream.cmd.execute await(%load_timepoint) => with(%resource as %capture : !stream.resource{%size}) { - } => !stream.timepoint + // CHECK: stream.test.timeline_op + %execute_timepoint = stream.test.timeline_op await(%load_timepoint) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint // CHECK-NOT: stream.resource.dealloca util.return %execute_timepoint : !stream.timepoint } diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/automatic_reference_counting_scf.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/automatic_reference_counting_scf.mlir new file mode 100644 index 000000000000..4bb645804167 --- /dev/null +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/automatic_reference_counting_scf.mlir @@ -0,0 +1,779 @@ +// RUN: iree-opt --split-input-file --iree-stream-automatic-reference-counting %s | FileCheck %s + +// Tests that resources allocated outside a loop and captured inside have their +// lifetime extended (NOT marked indeterminate). + +// CHECK-LABEL: @loop_captured_resource +// CHECK-SAME: (%[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @loop_captured_resource(%input_tp: !stream.timepoint, %size: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // Loop captures and uses resource (cmd-level pattern). + // CHECK: %[[LOOP_RESULT:.+]] = scf.for + %loop_result = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg = %alloca_tp) -> !stream.timepoint { + // CHECK: stream.test.timeline_op await(%{{.+}}) + %cmd_tp = stream.test.timeline_op await(%arg) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %cmd_tp : !stream.timepoint + } + + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[ALLOCA_TP]], %[[LOOP_RESULT]]) + // CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca origin await(%[[JOIN]]) => %[[RESOURCE]] + // CHECK-NOT: marked indeterminate + // CHECK: util.return %[[DEALLOCA_TP]] + util.return %loop_result : !stream.timepoint +} + +// ----- + +// Tests that resources allocated INSIDE a loop that never escape can be +// deallocated inside the loop body (local lifetime). + +// CHECK-LABEL: @loop_local_resource +util.func private @loop_local_resource(%input_tp: !stream.timepoint, %size: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: scf.for + %loop_result = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg = %input_tp) -> !stream.timepoint { + // CHECK: %[[LOCAL_RESOURCE:.+]], %[[LOCAL_ALLOCA_TP:.+]] = stream.resource.alloca + %local_resource, %local_alloca_tp = stream.resource.alloca uninitialized await(%arg) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[CMD_TP:.+]] = stream.test.timeline_op + %cmd_tp = stream.test.timeline_op await(%local_alloca_tp) => + with(%local_resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[LOCAL_ALLOCA_TP]], %[[CMD_TP]]) + // CHECK: %[[LOCAL_DEALLOCA_TP:.+]] = stream.resource.dealloca origin await(%[[JOIN]]) => %[[LOCAL_RESOURCE]] + // CHECK: scf.yield %[[LOCAL_DEALLOCA_TP]] + // Local resource deallocated inside loop body (never escapes). + scf.yield %cmd_tp : !stream.timepoint + } + + util.return %loop_result : !stream.timepoint +} + +// ----- + +// Tests scf.if with captured resource in both branches. + +// CHECK-LABEL: @if_captured_resource +// CHECK-SAME: ({{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @if_captured_resource(%cond: i1, %input_tp: !stream.timepoint, %size: index) -> !stream.timepoint { + %c1 = arith.constant 1 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[IF_RESULT:.+]] = scf.if + %if_result = scf.if %cond -> !stream.timepoint { + // CHECK: stream.test.timeline_op await(%[[ALLOCA_TP]]) + %then_tp = stream.test.timeline_op await(%alloca_tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %then_tp : !stream.timepoint + } else { + // CHECK: stream.test.timeline_op await(%[[ALLOCA_TP]]) + %else_tp = stream.test.timeline_op await(%alloca_tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %else_tp : !stream.timepoint + } + + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[ALLOCA_TP]], %[[IF_RESULT]]) + // CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca origin await(%[[JOIN]]) => %[[RESOURCE]] + // CHECK-NOT: marked indeterminate + // CHECK: util.return %[[DEALLOCA_TP]] + util.return %if_result : !stream.timepoint +} + +// ----- + +// Tests scf.if with local resource in then-branch that doesn't escape. + +// CHECK-LABEL: @if_local_resource +// CHECK-SAME: ({{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @if_local_resource(%cond: i1, %input_tp: !stream.timepoint, %size: index) -> !stream.timepoint { + %c1 = arith.constant 1 : index + + // CHECK: scf.if + %if_result = scf.if %cond -> !stream.timepoint { + // CHECK: %[[LOCAL_RESOURCE:.+]], %[[LOCAL_ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %local_resource, %local_alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[THEN_TP:.+]] = stream.test.timeline_op await(%[[LOCAL_ALLOCA_TP]]) + %then_tp = stream.test.timeline_op await(%local_alloca_tp) => + with(%local_resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + // CHECK: %[[LOCAL_DEALLOCA_TP:.+]] = stream.resource.dealloca origin await(%[[THEN_TP]]) => %[[LOCAL_RESOURCE]] + // CHECK: scf.yield %[[LOCAL_DEALLOCA_TP]] + // Local resource deallocated inside then-branch (coverage analysis eliminates redundant join). + scf.yield %then_tp : !stream.timepoint + } else { + scf.yield %input_tp : !stream.timepoint + } + + util.return %if_result : !stream.timepoint +} + +// ----- + +// Tests nested control flow: scf.if inside scf.for with captured resource. + +// CHECK-LABEL: @nested_if_in_loop +// CHECK-SAME: ({{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @nested_if_in_loop(%cond: i1, %input_tp: !stream.timepoint, %size: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[LOOP_RESULT:.+]] = scf.for + %loop_result = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg = %alloca_tp) -> !stream.timepoint { + // CHECK: scf.if + %if_tp = scf.if %cond -> !stream.timepoint { + // CHECK: stream.test.timeline_op + %then_tp = stream.test.timeline_op await(%arg) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %then_tp : !stream.timepoint + } else { + // CHECK: stream.test.timeline_op + %else_tp = stream.test.timeline_op await(%arg) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %else_tp : !stream.timepoint + } + scf.yield %if_tp : !stream.timepoint + } + + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[ALLOCA_TP]], %[[LOOP_RESULT]]) + // CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca origin await(%[[JOIN]]) => %[[RESOURCE]] + // Captured resource through nested if-in-loop should NOT be indeterminate. + // CHECK-NOT: marked indeterminate + // CHECK: util.return %[[DEALLOCA_TP]] + util.return %loop_result : !stream.timepoint +} + +// ----- + +// Tests nested control flow: scf.for inside scf.if with captured resource. + +// CHECK-LABEL: @nested_loop_in_if +// CHECK-SAME: ({{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @nested_loop_in_if(%cond: i1, %input_tp: !stream.timepoint, %size: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[IF_RESULT:.+]] = scf.if + %if_result = scf.if %cond -> !stream.timepoint { + // CHECK: scf.for + %loop_result = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg = %alloca_tp) -> !stream.timepoint { + // CHECK: stream.test.timeline_op + %cmd_tp = stream.test.timeline_op await(%arg) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %cmd_tp : !stream.timepoint + } + scf.yield %loop_result : !stream.timepoint + } else { + scf.yield %alloca_tp : !stream.timepoint + } + + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[ALLOCA_TP]], %[[IF_RESULT]]) + // CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca origin await(%[[JOIN]]) => %[[RESOURCE]] + // Captured resource through nested loop-in-if should NOT be indeterminate. + // CHECK-NOT: marked indeterminate + // CHECK: util.return %[[DEALLOCA_TP]] + util.return %if_result : !stream.timepoint +} + +// ----- + +// Tests multiple captured resources in a loop. + +// CHECK-LABEL: @loop_multiple_captured +// CHECK-SAME: (%[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE1:.+]]: index, %[[SIZE2:.+]]: index) +util.func private @loop_multiple_captured(%input_tp: !stream.timepoint, %size1: index, %size2: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[RESOURCE1:.+]], %[[ALLOCA_TP1:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE1]]} + %resource1, %alloca_tp1 = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size1} => !stream.timepoint + + // CHECK: %[[RESOURCE2:.+]], %[[ALLOCA_TP2:.+]] = stream.resource.alloca uninitialized await(%[[ALLOCA_TP1]]) => !stream.resource{%[[SIZE2]]} + %resource2, %alloca_tp2 = stream.resource.alloca uninitialized await(%alloca_tp1) => !stream.resource{%size2} => !stream.timepoint + + // CHECK: %[[LOOP_RESULT:.+]] = scf.for + %loop_result = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg = %alloca_tp2) -> !stream.timepoint { + // CHECK: stream.test.timeline_op + %cmd_tp = stream.test.timeline_op await(%arg) => + with(%resource1, %resource2) : (!stream.resource{%size1}, !stream.resource{%size2}) -> () => !stream.timepoint + scf.yield %cmd_tp : !stream.timepoint + } + + // CHECK: %[[JOIN1:.+]] = stream.timepoint.join max(%[[ALLOCA_TP1]], %[[LOOP_RESULT]]) + // CHECK: %[[DEALLOCA1:.+]] = stream.resource.dealloca origin await(%[[JOIN1]]) => %[[RESOURCE1]] + // CHECK: %[[JOIN2:.+]] = stream.timepoint.join max(%[[ALLOCA_TP2]], %[[DEALLOCA1]]) + // CHECK: %[[DEALLOCA2:.+]] = stream.resource.dealloca origin await(%[[JOIN2]]) => %[[RESOURCE2]] + // Both captured resources should NOT be indeterminate. + // CHECK-NOT: marked indeterminate + // CHECK: util.return %[[DEALLOCA2]] + util.return %loop_result : !stream.timepoint +} + +// ----- + +// Tests scf.for with iter_args carrying a resource (rare case). + +// CHECK-LABEL: @loop_iter_args_resource +util.func private @loop_iter_args_resource(%input_tp: !stream.timepoint, %initial_resource: !stream.resource, %size: index) -> (!stream.resource, !stream.timepoint) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[RESULT:.+]]:2 = scf.for + %result_resource, %result_tp = scf.for %i = %c0 to %c10 step %c1 + iter_args(%iter_resource = %initial_resource, %iter_tp = %input_tp) -> (!stream.resource, !stream.timepoint) { + // CHECK: stream.test.timeline_op + %cmd_tp = stream.test.timeline_op await(%iter_tp) => + with(%iter_resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %iter_resource, %cmd_tp : !stream.resource, !stream.timepoint + } + + // Loop-carried resource via iter_args should be aliased correctly. + // CHECK-NOT: marked indeterminate + // CHECK: util.return %[[RESULT]]#0, %[[RESULT]]#1 + util.return %result_resource, %result_tp : !stream.resource, !stream.timepoint +} + +// ----- + +// Tests deeply nested control flow (3 levels). + +// CHECK-LABEL: @deeply_nested +// CHECK-SAME: ({{.+}}: i1, {{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @deeply_nested(%cond1: i1, %cond2: i1, %input_tp: !stream.timepoint, %size: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c5 = arith.constant 5 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[IF1_RESULT:.+]] = scf.if + %if1_result = scf.if %cond1 -> !stream.timepoint { + // CHECK: scf.for + %loop_result = scf.for %i = %c0 to %c5 step %c1 iter_args(%arg = %alloca_tp) -> !stream.timepoint { + // CHECK: scf.if + %if2_result = scf.if %cond2 -> !stream.timepoint { + // CHECK: stream.test.timeline_op + %cmd_tp = stream.test.timeline_op await(%arg) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %cmd_tp : !stream.timepoint + } else { + scf.yield %arg : !stream.timepoint + } + scf.yield %if2_result : !stream.timepoint + } + scf.yield %loop_result : !stream.timepoint + } else { + scf.yield %alloca_tp : !stream.timepoint + } + + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[ALLOCA_TP]], %[[IF1_RESULT]]) + // CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca origin await(%[[JOIN]]) => %[[RESOURCE]] + // Deeply nested captured resource should NOT be indeterminate. + // CHECK-NOT: marked indeterminate + // CHECK: util.return %[[DEALLOCA_TP]] + util.return %if1_result : !stream.timepoint +} + +// ----- + +// Tests that a resource allocated INSIDE a loop and yielded OUT should NOT be +// deallocated inside the loop body (use-after-free bug fix). + +// CHECK-LABEL: @loop_local_resource_yielded +// CHECK-SAME: ({{.+}}: !stream.timepoint, %[[SIZE:.+]]: index, {{.+}}: !stream.resource) +util.func private @loop_local_resource_yielded(%input_tp: !stream.timepoint, %size: index, %init_resource: !stream.resource) -> (!stream.resource, !stream.timepoint) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[LOOP_RESULT:.+]]:2 = scf.for + %loop_resource, %loop_tp = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg_res = %init_resource, %arg_tp = %input_tp) -> (!stream.resource, !stream.timepoint) { + // CHECK: %[[LOCAL_RESOURCE:.+]], %[[LOCAL_ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%{{.+}}) => !stream.resource{%[[SIZE]]} + %local_resource, %local_alloca_tp = stream.resource.alloca uninitialized await(%arg_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[CMD_TP:.+]] = stream.test.timeline_op await(%[[LOCAL_ALLOCA_TP]]) + %cmd_tp = stream.test.timeline_op await(%local_alloca_tp) => + with(%local_resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + // Resource is yielded out - should NOT be deallocated inside loop. + // CHECK-NOT: stream.resource.dealloca + // CHECK: scf.yield %[[LOCAL_RESOURCE]], %[[CMD_TP]] + scf.yield %local_resource, %cmd_tp : !stream.resource, !stream.timepoint + } + + // The yielded resource should be available here for use. + // CHECK: util.return %[[LOOP_RESULT]]#0, %[[LOOP_RESULT]]#1 + util.return %loop_resource, %loop_tp : !stream.resource, !stream.timepoint +} + +// ----- + +// Tests that a resource allocated INSIDE an if-branch and yielded OUT should +// NOT be deallocated inside the branch (use-after-free bug fix). + +// CHECK-LABEL: @if_local_resource_yielded +// CHECK-SAME: ({{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index, {{.+}}: !stream.resource) +util.func private @if_local_resource_yielded(%cond: i1, %input_tp: !stream.timepoint, %size: index, %else_resource: !stream.resource) -> (!stream.resource, !stream.timepoint) { + // CHECK: %[[IF_RESULT:.+]]:2 = scf.if + %if_resource, %if_tp = scf.if %cond -> (!stream.resource, !stream.timepoint) { + // CHECK: %[[LOCAL_RESOURCE:.+]], %[[LOCAL_ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %local_resource, %local_alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[CMD_TP:.+]] = stream.test.timeline_op await(%[[LOCAL_ALLOCA_TP]]) + %cmd_tp = stream.test.timeline_op await(%local_alloca_tp) => + with(%local_resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + // Resource is yielded out - should NOT be deallocated inside branch. + // CHECK-NOT: stream.resource.dealloca + // CHECK: scf.yield %[[LOCAL_RESOURCE]], %[[CMD_TP]] + scf.yield %local_resource, %cmd_tp : !stream.resource, !stream.timepoint + } else { + // Else branch yields a different resource. + scf.yield %else_resource, %input_tp : !stream.resource, !stream.timepoint + } + + // The yielded resource should be available here for use. + // CHECK: util.return %[[IF_RESULT]]#0, %[[IF_RESULT]]#1 + util.return %if_resource, %if_tp : !stream.resource, !stream.timepoint +} + +// ----- + +// Tests that when scf.for returns MULTIPLE timepoints, the pass creates a join +// and uses it for tracking captured resource lifetimes. + +// CHECK-LABEL: @loop_multiple_timepoint_results +// CHECK-SAME: (%[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @loop_multiple_timepoint_results(%input_tp: !stream.timepoint, %size: index) -> (!stream.timepoint, !stream.timepoint) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // Loop returns TWO timepoints and captures a resource. + // CHECK: %[[LOOP_RESULTS:.+]]:2 = scf.for + %loop_tp1, %loop_tp2 = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg1 = %alloca_tp, %arg2 = %alloca_tp) -> (!stream.timepoint, !stream.timepoint) { + // CHECK: stream.test.timeline_op await(%{{.+}}) + %cmd_tp1 = stream.test.timeline_op await(%arg1) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + // Second command also uses captured resource. + // CHECK: stream.test.timeline_op await(%{{.+}}) + %cmd_tp2 = stream.test.timeline_op await(%arg2) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + scf.yield %cmd_tp1, %cmd_tp2 : !stream.timepoint, !stream.timepoint + } + + // The pass should create a JOIN of the two loop result timepoints. + // CHECK: %[[LOOP_JOIN:.+]] = stream.timepoint.join max(%[[LOOP_RESULTS]]#0, %[[LOOP_RESULTS]]#1) + + // The captured resource needs to await BOTH the alloca and loop execution. + // The pass creates another join combining alloca_tp with the loop join. + // CHECK: %[[FINAL_JOIN:.+]] = stream.timepoint.join max(%[[ALLOCA_TP]], %[[LOOP_JOIN]]) + + // The captured resource is deallocated awaiting the final join. + // CHECK: stream.resource.dealloca origin await(%[[FINAL_JOIN]]) => %[[RESOURCE]] + + // CHECK: util.return %[[LOOP_RESULTS]]#0, %[[LOOP_RESULTS]]#1 + util.return %loop_tp1, %loop_tp2 : !stream.timepoint, !stream.timepoint +} + +// ----- + +// Tests scf.while with captured resource tracking. + +// CHECK-LABEL: @while_captured_resource +// CHECK-SAME: (%[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index, {{.+}}: index) +util.func private @while_captured_resource(%input_tp: !stream.timepoint, %size: index, %bound: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[WHILE_RESULT:.+]]:2 = scf.while + %while_result:2 = scf.while (%iter = %c0, %tp = %alloca_tp) : (index, !stream.timepoint) -> (index, !stream.timepoint) { + %cond = arith.cmpi slt, %iter, %bound : index + scf.condition(%cond) %iter, %tp : index, !stream.timepoint + } do { + ^bb0(%iter: index, %tp: !stream.timepoint): + // CHECK: stream.test.timeline_op await(%{{.+}}) + %cmd_tp = stream.test.timeline_op await(%tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + %next_iter = arith.addi %iter, %c1 : index + scf.yield %next_iter, %cmd_tp : index, !stream.timepoint + } + + // The captured resource needs to await both alloca and while execution. + // The pass creates a join of the alloca timepoint and while result timepoint. + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[ALLOCA_TP]], %[[WHILE_RESULT]]#1) + + // CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca origin await(%[[JOIN]]) => %[[RESOURCE]] + // CHECK-NOT: marked indeterminate + // CHECK: util.return %[[DEALLOCA_TP]] + util.return %while_result#1 : !stream.timepoint +} + +// ----- + +// Tests that a resource yielded from one branch but not another is correctly +// handled (resource available from both branches, but only allocated in one). + +// CHECK-LABEL: @if_resource_from_one_branch +// CHECK-SAME: ({{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index, {{.+}}: !stream.resource) +util.func private @if_resource_from_one_branch(%cond: i1, %input_tp: !stream.timepoint, %size: index, %fallback_resource: !stream.resource) -> (!stream.resource, !stream.timepoint) { + // CHECK: %[[IF_RESULT:.+]]:2 = scf.if + %if_resource, %if_tp = scf.if %cond -> (!stream.resource, !stream.timepoint) { + // Then-branch allocates a new resource and yields it. + // CHECK: %[[LOCAL_RESOURCE:.+]], %[[LOCAL_ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %local_resource, %local_alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[CMD_TP:.+]] = stream.test.timeline_op await(%[[LOCAL_ALLOCA_TP]]) + %cmd_tp = stream.test.timeline_op await(%local_alloca_tp) => + with(%local_resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + // Resource yielded from then-branch - should NOT be deallocated. + // CHECK-NOT: stream.resource.dealloca + // CHECK: scf.yield %[[LOCAL_RESOURCE]], %[[CMD_TP]] + scf.yield %local_resource, %cmd_tp : !stream.resource, !stream.timepoint + } else { + // Else-branch yields the fallback resource (defined outside). + scf.yield %fallback_resource, %input_tp : !stream.resource, !stream.timepoint + } + + // The yielded resource should be available for use. + // CHECK: util.return %[[IF_RESULT]]#0, %[[IF_RESULT]]#1 + util.return %if_resource, %if_tp : !stream.resource, !stream.timepoint +} + +// ----- + +// Tests deeply nested SCF operations (scf.if inside scf.while inside scf.for). + +// CHECK-LABEL: @deeply_nested_scf +// CHECK-SAME: ({{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index, {{.+}}: index) +util.func private @deeply_nested_scf(%cond: i1, %input_tp: !stream.timepoint, %size: index, %bound: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // Outer loop: scf.for + // CHECK: %[[FOR_RESULT:.+]] = scf.for + %for_result = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg_tp = %alloca_tp) -> !stream.timepoint { + // Middle loop: scf.while + // CHECK: %[[WHILE_RESULT:.+]]:2 = scf.while + %while_result:2 = scf.while (%iter = %c0, %tp = %arg_tp) : (index, !stream.timepoint) -> (index, !stream.timepoint) { + %cond_check = arith.cmpi slt, %iter, %bound : index + scf.condition(%cond_check) %iter, %tp : index, !stream.timepoint + } do { + ^bb0(%iter: index, %tp: !stream.timepoint): + // Inner conditional: scf.if + // CHECK: %[[IF_RESULT:.+]] = scf.if + %if_result = scf.if %cond -> !stream.timepoint { + // CHECK: stream.test.timeline_op await(%{{.+}}) + %cmd_tp = stream.test.timeline_op await(%tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %cmd_tp : !stream.timepoint + } else { + scf.yield %tp : !stream.timepoint + } + // CHECK: %[[NEXT_ITER:.+]] = arith.addi + %next_iter = arith.addi %iter, %c1 : index + // CHECK: scf.yield %[[NEXT_ITER]], %[[IF_RESULT]] + scf.yield %next_iter, %if_result : index, !stream.timepoint + } + // CHECK: scf.yield %[[WHILE_RESULT]]#1 + scf.yield %while_result#1 : !stream.timepoint + } + + // Resource captured through 3 levels of nesting should be tracked correctly. + // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[ALLOCA_TP]], %[[FOR_RESULT]]) + // CHECK: %[[DEALLOCA_TP:.+]] = stream.resource.dealloca origin await(%[[JOIN]]) => %[[RESOURCE]] + // CHECK-NOT: marked indeterminate + // CHECK: util.return %[[DEALLOCA_TP]] + util.return %for_result : !stream.timepoint +} + +// ----- + +// Tests that timepoint coverage correctly spans parent and nested scf.for regions. +// A loop body that joins a parent-scope timepoint with an iter_arg timepoint +// requires the coverage analysis to track timepoints across scope boundaries. +// This guards against incorrectly localizing coverage per-block. + +// CHECK-LABEL: @cross_scope_for_await_parent +// CHECK-SAME: (%[[PARENT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @cross_scope_for_await_parent(%parent_tp: !stream.timepoint, %size: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // Allocate resource in parent scope. + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%parent_tp) => !stream.resource{%size} => !stream.timepoint + + // Nested loop awaits on parent_tp and alloca_tp. + // If coverage were per-block, covers(parent_tp, nested_tp) would fail. + // CHECK: %[[LOOP_RESULT:.+]] = scf.for %{{.+}} = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ITER:.+]] = %[[ALLOCA_TP]]) + %loop_result = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg = %alloca_tp) -> !stream.timepoint { + // Join parent timepoint with iter_arg - tests cross-scope coverage tracking. + // CHECK: %[[JOINED:.+]] = stream.timepoint.join max(%[[PARENT_TP]], %[[ITER]]) + %joined_tp = stream.timepoint.join max(%parent_tp, %arg) => !stream.timepoint + // CHECK: %[[CMD_TP:.+]] = stream.test.timeline_op await(%[[JOINED]]) => with(%[[RESOURCE]]) : (!stream.resource{%[[SIZE]]}) + %cmd_tp = stream.test.timeline_op await(%joined_tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + // CHECK: scf.yield %[[CMD_TP]] + scf.yield %cmd_tp : !stream.timepoint + } + + // Coverage must correctly track that parent_tp is covered by loop_result. + // CHECK: %[[JOINED_TP:.+]] = stream.timepoint.join max(%[[ALLOCA_TP]], %[[LOOP_RESULT]]) + // CHECK: %[[DEALLOCA:.+]] = stream.resource.dealloca origin await(%[[JOINED_TP]]) => %[[RESOURCE]] + // CHECK: util.return %[[DEALLOCA]] + util.return %loop_result : !stream.timepoint +} + +// ----- + +// Tests that timepoint coverage correctly spans parent and nested scf.if regions. +// Both if-branches join a parent-scope timepoint with an alloca timepoint, +// requiring the coverage analysis to track timepoints across scope boundaries. + +// CHECK-LABEL: @cross_scope_if_await_parent +// CHECK-SAME: (%[[PARENT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index, {{.+}}: i1) +util.func private @cross_scope_if_await_parent(%parent_tp: !stream.timepoint, %size: index, %cond: i1) -> !stream.timepoint { + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%parent_tp) => !stream.resource{%size} => !stream.timepoint + + // Nested if awaits on parent_tp. + // CHECK: %[[IF_RESULT:.+]] = scf.if + %if_result = scf.if %cond -> !stream.timepoint { + // Then branch joins parent with alloca timepoint. + // CHECK: %[[THEN_JOINED:.+]] = stream.timepoint.join max(%[[PARENT_TP]], %[[ALLOCA_TP]]) + %then_joined_tp = stream.timepoint.join max(%parent_tp, %alloca_tp) => !stream.timepoint + // CHECK: %[[THEN_TP:.+]] = stream.test.timeline_op await(%[[THEN_JOINED]]) => with(%[[RESOURCE]]) : (!stream.resource{%[[SIZE]]}) + %then_tp = stream.test.timeline_op await(%then_joined_tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + // CHECK: scf.yield %[[THEN_TP]] + scf.yield %then_tp : !stream.timepoint + } else { + // Else branch joins parent with alloca timepoint. + // CHECK: %[[ELSE_JOINED:.+]] = stream.timepoint.join max(%[[PARENT_TP]], %[[ALLOCA_TP]]) + %else_joined_tp = stream.timepoint.join max(%parent_tp, %alloca_tp) => !stream.timepoint + // CHECK: %[[ELSE_TP:.+]] = stream.test.timeline_op await(%[[ELSE_JOINED]]) => with(%[[RESOURCE]]) : (!stream.resource{%[[SIZE]]}) + %else_tp = stream.test.timeline_op await(%else_joined_tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + // CHECK: scf.yield %[[ELSE_TP]] + scf.yield %else_tp : !stream.timepoint + } + + // Pass creates a final join combining alloca and if result. + // CHECK: %{{.+}} = stream.timepoint.join max(%[[ALLOCA_TP]], %[[IF_RESULT]]) + // CHECK: %[[DEALLOCA:.+]] = stream.resource.dealloca origin await(%{{.+}}) => %[[RESOURCE]] + // CHECK: util.return %[[DEALLOCA]] + util.return %if_result : !stream.timepoint +} + +// ----- + +// Tests that timepoint coverage correctly spans parent and nested scf.while regions. +// The while loop body joins a parent-scope timepoint with a loop-carried timepoint, +// requiring the coverage analysis to track timepoints across scope boundaries. + +// CHECK-LABEL: @cross_scope_while_await_parent +// CHECK-SAME: (%[[PARENT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index, {{.+}}: index) +util.func private @cross_scope_while_await_parent(%parent_tp: !stream.timepoint, %size: index, %limit: index) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%parent_tp) => !stream.resource{%size} => !stream.timepoint + + // While loop with parent timepoint in condition and body. + // CHECK: %[[WHILE_RESULT:.+]]:2 = scf.while + %while_result:2 = scf.while (%iter = %c0, %tp = %alloca_tp) : (index, !stream.timepoint) -> (index, !stream.timepoint) { + %cond = arith.cmpi slt, %iter, %limit : index + scf.condition(%cond) %iter, %tp : index, !stream.timepoint + } do { + ^bb0(%iter: index, %tp: !stream.timepoint): + // Body joins parent_tp and loop-carried tp. + // CHECK: %[[JOINED:.+]] = stream.timepoint.join max(%[[PARENT_TP]], %{{.+}}) + %joined_tp = stream.timepoint.join max(%parent_tp, %tp) => !stream.timepoint + // CHECK: stream.test.timeline_op await(%[[JOINED]]) => with(%[[RESOURCE]]) : (!stream.resource{%[[SIZE]]}) + %cmd_tp = stream.test.timeline_op await(%joined_tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + %next_iter = arith.addi %iter, %c1 : index + scf.yield %next_iter, %cmd_tp : index, !stream.timepoint + } + + // Pass creates a final join combining alloca and while result. + // CHECK: %{{.+}} = stream.timepoint.join max(%[[ALLOCA_TP]], %[[WHILE_RESULT]]#1) + // CHECK: %[[DEALLOCA:.+]] = stream.resource.dealloca origin await(%{{.+}}) => %[[RESOURCE]] + // CHECK: util.return %[[DEALLOCA]] + util.return %while_result#1 : !stream.timepoint +} + +// ----- + +// Tests the conservative fallback when scf.for has NO timepoint result. +// When the SCF op doesn't yield a timepoint, we cannot track resource lifetimes +// through it, so captured resources are marked indeterminate (no deallocation). +// This tests the "return false" path in analyzeForLoop when getOrJoinTimepointResults +// returns nullopt. + +// CHECK-LABEL: @for_no_timepoint_result_conservative +// CHECK-SAME: (%[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @for_no_timepoint_result_conservative(%input_tp: !stream.timepoint, %size: index) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // This loop yields ONLY an index, no timepoint. + // The pass cannot track resource lifetimes through this loop. + // CHECK: %[[LOOP_RESULT:.+]] = scf.for + %loop_result = scf.for %i = %c0 to %c10 step %c1 iter_args(%sum = %c0) -> index { + // Use the captured resource inside the loop. + // CHECK: stream.test.timeline_op await(%[[ALLOCA_TP]]) => with(%[[RESOURCE]]) : (!stream.resource{%[[SIZE]]}) + %cmd_tp = stream.test.timeline_op await(%alloca_tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + %next_sum = arith.addi %sum, %c1 : index + scf.yield %next_sum : index + } + + // The captured resource should NOT have a deallocation inserted because + // the pass could not analyze it (no timepoint result from loop). + // CHECK-NOT: stream.resource.dealloca + // CHECK: util.return %[[LOOP_RESULT]] + util.return %loop_result : index +} + +// ----- + +// Tests the conservative fallback when scf.if has NO timepoint result. +// Similar to the for loop case - when the if doesn't yield a timepoint, +// captured resources are marked indeterminate. + +// CHECK-LABEL: @if_no_timepoint_result_conservative +// CHECK-SAME: ({{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index) +util.func private @if_no_timepoint_result_conservative(%cond: i1, %input_tp: !stream.timepoint, %size: index) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + // CHECK: %[[RESOURCE:.+]], %[[ALLOCA_TP:.+]] = stream.resource.alloca uninitialized await(%[[INPUT_TP]]) => !stream.resource{%[[SIZE]]} + %resource, %alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // This if yields ONLY an index, no timepoint. + // CHECK: %[[IF_RESULT:.+]] = scf.if + %if_result = scf.if %cond -> index { + // CHECK: stream.test.timeline_op await(%[[ALLOCA_TP]]) => with(%[[RESOURCE]]) : (!stream.resource{%[[SIZE]]}) + %then_tp = stream.test.timeline_op await(%alloca_tp) => + with(%resource) : (!stream.resource{%size}) -> () => !stream.timepoint + scf.yield %c1 : index + } else { + scf.yield %c0 : index + } + + // The captured resource should NOT have a deallocation inserted. + // CHECK-NOT: stream.resource.dealloca + // CHECK: util.return %[[IF_RESULT]] + util.return %if_result : index +} + +// ----- + +// Tests that a resource allocated inside a loop, yielded out, but NOT returned +// from the function is still properly deallocated. This is the "yielded then +// dropped" scenario. + +// CHECK-LABEL: @loop_yielded_then_dropped +// CHECK-SAME: (%[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index, {{.+}}: !stream.resource) +util.func private @loop_yielded_then_dropped(%input_tp: !stream.timepoint, %size: index, %init_resource: !stream.resource) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + + // CHECK: %[[LOOP_RESULT:.+]]:2 = scf.for + %loop_resource, %loop_tp = scf.for %i = %c0 to %c10 step %c1 iter_args(%arg_res = %init_resource, %arg_tp = %input_tp) -> (!stream.resource, !stream.timepoint) { + // Allocate inside the loop (replacing the iter_arg resource each iteration). + // CHECK: %[[LOCAL_RESOURCE:.+]], %[[LOCAL_ALLOCA_TP:.+]] = stream.resource.alloca + %local_resource, %local_alloca_tp = stream.resource.alloca uninitialized await(%arg_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[CMD_TP:.+]] = stream.test.timeline_op await(%[[LOCAL_ALLOCA_TP]]) + %cmd_tp = stream.test.timeline_op await(%local_alloca_tp) => + with(%local_resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + // Resource is yielded out of the loop (new resource replaces arg_res). + scf.yield %local_resource, %cmd_tp : !stream.resource, !stream.timepoint + } + + // The yielded resource is NOT returned - it's dropped here. + // The pass MUST insert a deallocation for %loop_resource after the loop. + // CHECK: stream.resource.dealloca {{.*}}await(%[[LOOP_RESULT]]#1) => %[[LOOP_RESULT]]#0 + // CHECK: util.return + util.return %loop_tp : !stream.timepoint +} + +// ----- + +// Tests that a resource allocated inside an if-branch, yielded out, but NOT +// returned from the function is still properly deallocated. + +// CHECK-LABEL: @if_yielded_then_dropped +// CHECK-SAME: ({{.+}}: i1, %[[INPUT_TP:.+]]: !stream.timepoint, %[[SIZE:.+]]: index, {{.+}}: !stream.resource) +util.func private @if_yielded_then_dropped(%cond: i1, %input_tp: !stream.timepoint, %size: index, %else_resource: !stream.resource) -> !stream.timepoint { + // CHECK: %[[IF_RESULT:.+]]:2 = scf.if + %if_resource, %if_tp = scf.if %cond -> (!stream.resource, !stream.timepoint) { + // Allocate inside the then-branch. + // CHECK: %[[LOCAL_RESOURCE:.+]], %[[LOCAL_ALLOCA_TP:.+]] = stream.resource.alloca + %local_resource, %local_alloca_tp = stream.resource.alloca uninitialized await(%input_tp) => !stream.resource{%size} => !stream.timepoint + + // CHECK: %[[CMD_TP:.+]] = stream.test.timeline_op await(%[[LOCAL_ALLOCA_TP]]) + %cmd_tp = stream.test.timeline_op await(%local_alloca_tp) => + with(%local_resource) : (!stream.resource{%size}) -> () => !stream.timepoint + + // Resource is yielded out of the if. + scf.yield %local_resource, %cmd_tp : !stream.resource, !stream.timepoint + } else { + // Else-branch yields a different resource (defined outside, so indeterminate). + scf.yield %else_resource, %input_tp : !stream.resource, !stream.timepoint + } + + // The yielded resource is NOT returned - it's dropped here. + // The pass MUST insert a deallocation for %if_resource after the if. + // CHECK: stream.resource.dealloca {{.*}}await(%[[IF_RESULT]]#1) => %[[IF_RESULT]]#0 + // CHECK: util.return + util.return %if_tp : !stream.timepoint +}