diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp index 005c5e0daed3..9e5377c5e673 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/AutomaticReferenceCounting.cpp @@ -306,31 +306,40 @@ struct LastUseSet { } }; +// Returns the timepoints sorted by their order in the block (textual order). +// All timepoints must be in the same block. +static SmallVector getSortedTimepointsInBlock(TimepointSet &timepoints) { + auto sorted = llvm::to_vector_of(timepoints); + llvm::sort(sorted, [](Value a, Value b) { + Operation *opA = a.getDefiningOp(); + Operation *opB = b.getDefiningOp(); + if (!opA && !opB) { + // Both are block arguments, compare by argument number. + return cast(a).getArgNumber() < + cast(b).getArgNumber(); + } + if (!opA) { + return true; // Block argument comes before operation. + } + if (!opB) { + return false; // Operation comes before block argument. + } + return opA->isBeforeInBlock(opB); + }); + return sorted; +} + // Returns the last defined SSA value in the block in |timepoints| (textual // order within the block). All timepoints must be in the same block. static Value getLastTimepointInBlock(TimepointSet &timepoints) { if (timepoints.empty()) { return nullptr; - } else if (timepoints.size() == 1) { - return *timepoints.begin(); } - Value lastTimepoint; - for (auto timepoint : timepoints) { - if (!lastTimepoint) { - lastTimepoint = timepoint; - } else { - auto *timepointOp = timepoint.getDefiningOp(); - auto *lastTimepointOp = lastTimepoint.getDefiningOp(); - if (!timepointOp) { - continue; // block arg - } else if (!lastTimepointOp) { - lastTimepoint = timepoint; // last found was a block arg, this isn't - } else if (lastTimepointOp->isBeforeInBlock(timepointOp)) { - lastTimepoint = timepoint; - } - } + if (timepoints.size() == 1) { + return *timepoints.begin(); } - return lastTimepoint; + SmallVector sorted = getSortedTimepointsInBlock(timepoints); + return sorted.back(); } // Returns a FusedLoc with the location of all |timepoints| and the base |loc|. @@ -595,8 +604,7 @@ static void insertDeallocations(LastUseSet &lastUseSet, AsmState *asmState, auto joinOp = IREE::Stream::TimepointJoinOp::create( builder, timepointsLoc, builder.getType(), - llvm::map_to_vector(timepoints, - [](Value timepoint) { return timepoint; })); + getSortedTimepointsInBlock(timepoints)); auto deallocaOp = IREE::Stream::ResourceDeallocaOp::create( builder, timepointsLoc, builder.getType(), resource, diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp index e19758e4d9bb..007552e38259 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp @@ -98,19 +98,25 @@ class HoistIntoGlobalsPass file.close(); } - // Maps original values to newly materialized values. - HoistedValueMap hoistedMap; - // Walk all operations in the program and hoist any escapes from // const-expr values into globals. Note that we must walk the const-exprs // in topological order so that corresponding initializers will be created // in order without depending on globals that have not been initialized // yet. + OpBuilder builder(&getContext()); for (auto funcOp : getOperation().getOps()) { // Ignore initializers. if (isa(funcOp.getOperation())) { continue; } + + // Maps original values to newly materialized globals (per-function). + HoistedValueMap hoistedMap; + + // Operation order for deterministic sorting (per-function). + llvm::DenseMap opOrder; + unsigned orderIdx = 0; + auto walkRes = funcOp.walk([&](Operation *iterOp) { // We only want to look at const-expr ops (non roots) since they may // have interesting escapes. Early exit here for efficiency. @@ -118,6 +124,10 @@ class HoistIntoGlobalsPass if (!iterInfo) { return WalkResult::advance(); } + + // Record operation order for deterministic sorting. Since we walk in + // PreOrder, producers are visited before their users. + opOrder[iterOp] = orderIdx++; for (Value constExprResult : iterOp->getResults()) { auto *resultInfo = constExprs.lookup(constExprResult); assert(resultInfo && "must have const-expr info"); @@ -126,7 +136,7 @@ class HoistIntoGlobalsPass continue; } if (failed(hoistConstExpr(constExprResult, hoistedMap, moduleSymbols, - constExprs))) { + constExprs, opOrder))) { return WalkResult::interrupt(); } } @@ -135,35 +145,42 @@ class HoistIntoGlobalsPass if (walkRes.wasInterrupted()) { return signalPassFailure(); } - } - // Apply any remaining RAUW cleanups. We have to do these at the cleanup - // phase since modifying the source program can invalidate the analysis. - // Up to this point, we have only been cloning. - OpBuilder builder(&getContext()); - for (auto [originalValue, globalOp] : hoistedMap) { - builder.setInsertionPointAfterValue(originalValue); - auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder); - if (!originalValue.getDefiningOp() - ->getParentOfType()) { - loadOp.setGlobalImmutable(true); - } - Value loadedValue = loadOp.getLoadedGlobalValue(); - // Call user hook to cast back to the original type. - if (auto hoistableType = dyn_cast( - originalValue.getType())) { - loadedValue = hoistableType.decodeStorageType( - builder, loadedValue.getLoc(), originalValue.getType(), - loadedValue); - } - if (loadedValue.getType() != originalValue.getType()) { - getOperation().emitError() - << "Unresolved conflict between casted global of type " - << loadedValue.getType() << " and original type " - << originalValue.getType(); - return signalPassFailure(); + // Apply RAUW cleanups for this function. We do this after cloning to + // avoid invalidating the analysis during the walk. + // Sort the hoisted values by program order for deterministic output. + using HoistedValue = std::pair; + auto sortedHoisted = llvm::to_vector_of(hoistedMap); + llvm::sort(sortedHoisted, + [&opOrder](const HoistedValue &lhs, const HoistedValue &rhs) { + return opOrder[lhs.first.getDefiningOp()] < + opOrder[rhs.first.getDefiningOp()]; + }); + + for (auto [originalValue, globalOp] : sortedHoisted) { + builder.setInsertionPointAfterValue(originalValue); + auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder); + if (!originalValue.getDefiningOp() + ->getParentOfType()) { + loadOp.setGlobalImmutable(true); + } + Value loadedValue = loadOp.getLoadedGlobalValue(); + // Call user hook to cast back to the original type. + if (auto hoistableType = dyn_cast( + originalValue.getType())) { + loadedValue = hoistableType.decodeStorageType( + builder, loadedValue.getLoc(), originalValue.getType(), + loadedValue); + } + if (loadedValue.getType() != originalValue.getType()) { + getOperation().emitError() + << "Unresolved conflict between casted global of type " + << loadedValue.getType() << " and original type " + << originalValue.getType(); + return signalPassFailure(); + } + originalValue.replaceAllUsesWith(loadedValue); } - originalValue.replaceAllUsesWith(loadedValue); } cleanupDeadOps(constExprs); } @@ -177,9 +194,11 @@ class HoistIntoGlobalsPass return op; } - LogicalResult hoistConstExpr(Value originalValue, HoistedValueMap &hoistedMap, - SymbolTable &moduleSymbols, - const ConstExprAnalysis &constExprs) { + LogicalResult + hoistConstExpr(Value originalValue, HoistedValueMap &hoistedMap, + SymbolTable &moduleSymbols, + const ConstExprAnalysis &constExprs, + const llvm::DenseMap &opOrder) { IREE::Util::GlobalOp existingGlobal = hoistedMap.lookup(originalValue); if (existingGlobal) { return success(); @@ -202,7 +221,7 @@ class HoistIntoGlobalsPass if (failed(cloneConstExprInto(initializerOp.getLoc(), moduleBuilder, initializerBuilder, originalValue, dialectAttrs, hoistedMap, moduleSymbols, - constExprs))) { + constExprs, opOrder))) { return failure(); } @@ -218,7 +237,8 @@ class HoistIntoGlobalsPass cloneProducerTreeInto(OpBuilder &initializerBuilder, const ConstExprAnalysis::ConstValueInfo *producerInfo, HoistedValueMap &hoistedMap, IRMapping &cloneMapping, - const ConstExprAnalysis &constExprs) { + const ConstExprAnalysis &constExprs, + const llvm::DenseMap &opOrder) { if (cloneMapping.contains(producerInfo->constValue)) { return; } @@ -243,10 +263,20 @@ class HoistIntoGlobalsPass return; } - // Materialize all producers recursively. - for (auto *producerInfo : producerInfo->producers) { - cloneProducerTreeInto(initializerBuilder, producerInfo, hoistedMap, - cloneMapping, constExprs); + // Materialize all producers recursively. Sort producers by their program + // order for deterministic output. + auto sortedProducers = + llvm::to_vector_of( + producerInfo->producers); + llvm::sort(sortedProducers, + [&opOrder](ConstExprAnalysis::ConstValueInfo *lhs, + ConstExprAnalysis::ConstValueInfo *rhs) { + return opOrder.lookup(lhs->constValue.getDefiningOp()) < + opOrder.lookup(rhs->constValue.getDefiningOp()); + }); + for (ConstExprAnalysis::ConstValueInfo *prodInfo : sortedProducers) { + cloneProducerTreeInto(initializerBuilder, prodInfo, hoistedMap, + cloneMapping, constExprs, opOrder); } // And clone the requested op. @@ -264,13 +294,13 @@ class HoistIntoGlobalsPass // Clones the const expr tree rooted at `constExprValue` into the given // initializer, noting any new hoisted value mappings that result. At // a minimum, a mapping will be created for the requested value. - LogicalResult cloneConstExprInto(Location loc, OpBuilder &moduleBuilder, - OpBuilder &initializerBuilder, - Value constExprValue, - NamedAttrList dialectAttrs, - HoistedValueMap &hoistedMap, - SymbolTable &moduleSymbols, - const ConstExprAnalysis &constExprs) { + LogicalResult + cloneConstExprInto(Location loc, OpBuilder &moduleBuilder, + OpBuilder &initializerBuilder, Value constExprValue, + NamedAttrList dialectAttrs, HoistedValueMap &hoistedMap, + SymbolTable &moduleSymbols, + const ConstExprAnalysis &constExprs, + const llvm::DenseMap &opOrder) { // Do a depth first traversal of the producers, emitting them in a valid // def-use order. Operation *rootOp = constExprValue.getDefiningOp(); @@ -281,7 +311,7 @@ class HoistIntoGlobalsPass // Clone the whole tree as needed. IRMapping cloneMapping; cloneProducerTreeInto(initializerBuilder, rootInfo, hoistedMap, - cloneMapping, constExprs); + cloneMapping, constExprs, opOrder); // And for each result, create a global and store into it. for (Value origResult : rootOp->getResults()) {