Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -306,31 +306,40 @@ struct LastUseSet {
}
};

// Returns |timepoints| as a vector sorted by definition order within the block
// (textual order). All timepoints must be in the same block.
//
// The comparator defines a total order so that the result does not depend on
// the iteration order of |timepoints| or on how llvm::sort arranges elements
// that compare equivalent:
//  - block arguments sort before op results, ordered by argument number;
//  - op results sort by the position of their defining op in the block;
//  - results of the same defining op sort by result number.
static SmallVector<Value> getSortedTimepointsInBlock(TimepointSet &timepoints) {
  auto sorted = llvm::to_vector_of<Value>(timepoints);
  llvm::sort(sorted, [](Value a, Value b) {
    Operation *opA = a.getDefiningOp();
    Operation *opB = b.getDefiningOp();
    if (!opA && !opB) {
      // Both are block arguments, compare by argument number.
      return cast<BlockArgument>(a).getArgNumber() <
             cast<BlockArgument>(b).getArgNumber();
    }
    if (!opA) {
      return true; // |a| is a block argument and sorts before op result |b|.
    }
    if (!opB) {
      return false; // |b| is a block argument and sorts before op result |a|.
    }
    if (opA == opB) {
      // Distinct results of one op: isBeforeInBlock is false in both
      // directions, so break the tie by result number to stay deterministic.
      return cast<OpResult>(a).getResultNumber() <
             cast<OpResult>(b).getResultNumber();
    }
    return opA->isBeforeInBlock(opB);
  });
  return sorted;
}

// Returns the last defined SSA value in the block in |timepoints| (textual
// order within the block). All timepoints must be in the same block.
//
// Returns a null Value when |timepoints| is empty. Ordering is delegated to
// getSortedTimepointsInBlock so that "last" here agrees with the order used
// when the same set is materialized as join operands.
static Value getLastTimepointInBlock(TimepointSet &timepoints) {
  if (timepoints.empty()) {
    return nullptr;
  }
  if (timepoints.size() == 1) {
    // Single element: no need to build and sort a temporary vector.
    return *timepoints.begin();
  }
  SmallVector<Value> sorted = getSortedTimepointsInBlock(timepoints);
  return sorted.back();
}

// Returns a FusedLoc with the location of all |timepoints| and the base |loc|.
Expand Down Expand Up @@ -595,8 +604,7 @@ static void insertDeallocations(LastUseSet &lastUseSet, AsmState *asmState,
auto joinOp = IREE::Stream::TimepointJoinOp::create(
builder, timepointsLoc,
builder.getType<IREE::Stream::TimepointType>(),
llvm::map_to_vector(timepoints,
[](Value timepoint) { return timepoint; }));
getSortedTimepointsInBlock(timepoints));
auto deallocaOp = IREE::Stream::ResourceDeallocaOp::create(
builder, timepointsLoc,
builder.getType<IREE::Stream::TimepointType>(), resource,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,26 +98,36 @@ class HoistIntoGlobalsPass
file.close();
}

// Maps original values to newly materialized values.
HoistedValueMap hoistedMap;

// Walk all operations in the program and hoist any escapes from
// const-expr values into globals. Note that we must walk the const-exprs
// in topological order so that corresponding initializers will be created
// in order without depending on globals that have not been initialized
// yet.
OpBuilder builder(&getContext());
for (auto funcOp : getOperation().getOps<FunctionOpInterface>()) {
// Ignore initializers.
if (isa<IREE::Util::InitializerOpInterface>(funcOp.getOperation())) {
continue;
}

// Maps original values to newly materialized globals (per-function).
HoistedValueMap hoistedMap;

// Operation order for deterministic sorting (per-function).
llvm::DenseMap<Operation *, unsigned> opOrder;
unsigned orderIdx = 0;

auto walkRes = funcOp.walk<WalkOrder::PreOrder>([&](Operation *iterOp) {
// We only want to look at const-expr ops (non roots) since they may
// have interesting escapes. Early exit here for efficiency.
auto *iterInfo = constExprs.lookup(iterOp);
if (!iterInfo) {
return WalkResult::advance();
}

// Record operation order for deterministic sorting. Since we walk in
// PreOrder, producers are visited before their users.
opOrder[iterOp] = orderIdx++;
for (Value constExprResult : iterOp->getResults()) {
auto *resultInfo = constExprs.lookup(constExprResult);
assert(resultInfo && "must have const-expr info");
Expand All @@ -126,7 +136,7 @@ class HoistIntoGlobalsPass
continue;
}
if (failed(hoistConstExpr(constExprResult, hoistedMap, moduleSymbols,
constExprs))) {
constExprs, opOrder))) {
return WalkResult::interrupt();
}
}
Expand All @@ -135,35 +145,42 @@ class HoistIntoGlobalsPass
if (walkRes.wasInterrupted()) {
return signalPassFailure();
}
}

// Apply any remaining RAUW cleanups. We have to do these at the cleanup
// phase since modifying the source program can invalidate the analysis.
// Up to this point, we have only been cloning.
OpBuilder builder(&getContext());
for (auto [originalValue, globalOp] : hoistedMap) {
builder.setInsertionPointAfterValue(originalValue);
auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder);
if (!originalValue.getDefiningOp()
->getParentOfType<IREE::Util::InitializerOpInterface>()) {
loadOp.setGlobalImmutable(true);
}
Value loadedValue = loadOp.getLoadedGlobalValue();
// Call user hook to cast back to the original type.
if (auto hoistableType = dyn_cast<IREE::Util::HoistableTypeInterface>(
originalValue.getType())) {
loadedValue = hoistableType.decodeStorageType(
builder, loadedValue.getLoc(), originalValue.getType(),
loadedValue);
}
if (loadedValue.getType() != originalValue.getType()) {
getOperation().emitError()
<< "Unresolved conflict between casted global of type "
<< loadedValue.getType() << " and original type "
<< originalValue.getType();
return signalPassFailure();
// Apply RAUW cleanups for this function. We do this after cloning to
// avoid invalidating the analysis during the walk.
// Sort the hoisted values by program order for deterministic output.
using HoistedValue = std::pair<Value, GlobalOp>;
auto sortedHoisted = llvm::to_vector_of<HoistedValue>(hoistedMap);
llvm::sort(sortedHoisted,
[&opOrder](const HoistedValue &lhs, const HoistedValue &rhs) {
return opOrder[lhs.first.getDefiningOp()] <
opOrder[rhs.first.getDefiningOp()];
});

for (auto [originalValue, globalOp] : sortedHoisted) {
builder.setInsertionPointAfterValue(originalValue);
auto loadOp = globalOp.createLoadOp(globalOp->getLoc(), builder);
if (!originalValue.getDefiningOp()
->getParentOfType<IREE::Util::InitializerOpInterface>()) {
loadOp.setGlobalImmutable(true);
}
Value loadedValue = loadOp.getLoadedGlobalValue();
// Call user hook to cast back to the original type.
if (auto hoistableType = dyn_cast<IREE::Util::HoistableTypeInterface>(
originalValue.getType())) {
loadedValue = hoistableType.decodeStorageType(
builder, loadedValue.getLoc(), originalValue.getType(),
loadedValue);
}
if (loadedValue.getType() != originalValue.getType()) {
getOperation().emitError()
<< "Unresolved conflict between casted global of type "
<< loadedValue.getType() << " and original type "
<< originalValue.getType();
return signalPassFailure();
}
originalValue.replaceAllUsesWith(loadedValue);
}
originalValue.replaceAllUsesWith(loadedValue);
}
cleanupDeadOps(constExprs);
}
Expand All @@ -177,9 +194,11 @@ class HoistIntoGlobalsPass
return op;
}

LogicalResult hoistConstExpr(Value originalValue, HoistedValueMap &hoistedMap,
SymbolTable &moduleSymbols,
const ConstExprAnalysis &constExprs) {
LogicalResult
hoistConstExpr(Value originalValue, HoistedValueMap &hoistedMap,
SymbolTable &moduleSymbols,
const ConstExprAnalysis &constExprs,
const llvm::DenseMap<Operation *, unsigned> &opOrder) {
IREE::Util::GlobalOp existingGlobal = hoistedMap.lookup(originalValue);
if (existingGlobal) {
return success();
Expand All @@ -202,7 +221,7 @@ class HoistIntoGlobalsPass
if (failed(cloneConstExprInto(initializerOp.getLoc(), moduleBuilder,
initializerBuilder, originalValue,
dialectAttrs, hoistedMap, moduleSymbols,
constExprs))) {
constExprs, opOrder))) {
return failure();
}

Expand All @@ -218,7 +237,8 @@ class HoistIntoGlobalsPass
cloneProducerTreeInto(OpBuilder &initializerBuilder,
const ConstExprAnalysis::ConstValueInfo *producerInfo,
HoistedValueMap &hoistedMap, IRMapping &cloneMapping,
const ConstExprAnalysis &constExprs) {
const ConstExprAnalysis &constExprs,
const llvm::DenseMap<Operation *, unsigned> &opOrder) {
if (cloneMapping.contains(producerInfo->constValue)) {
return;
}
Expand All @@ -243,10 +263,20 @@ class HoistIntoGlobalsPass
return;
}

// Materialize all producers recursively.
for (auto *producerInfo : producerInfo->producers) {
cloneProducerTreeInto(initializerBuilder, producerInfo, hoistedMap,
cloneMapping, constExprs);
// Materialize all producers recursively. Sort producers by their program
// order for deterministic output.
auto sortedProducers =
llvm::to_vector_of<ConstExprAnalysis::ConstValueInfo *>(
producerInfo->producers);
llvm::sort(sortedProducers,
[&opOrder](ConstExprAnalysis::ConstValueInfo *lhs,
ConstExprAnalysis::ConstValueInfo *rhs) {
return opOrder.lookup(lhs->constValue.getDefiningOp()) <
opOrder.lookup(rhs->constValue.getDefiningOp());
});
for (ConstExprAnalysis::ConstValueInfo *prodInfo : sortedProducers) {
cloneProducerTreeInto(initializerBuilder, prodInfo, hoistedMap,
cloneMapping, constExprs, opOrder);
}

// And clone the requested op.
Expand All @@ -264,13 +294,13 @@ class HoistIntoGlobalsPass
// Clones the const expr tree rooted at `constExprValue` into the given
// initializer, noting any new hoisted value mappings that result. At
// a minimum, a mapping will be created for the requested value.
LogicalResult cloneConstExprInto(Location loc, OpBuilder &moduleBuilder,
OpBuilder &initializerBuilder,
Value constExprValue,
NamedAttrList dialectAttrs,
HoistedValueMap &hoistedMap,
SymbolTable &moduleSymbols,
const ConstExprAnalysis &constExprs) {
LogicalResult
cloneConstExprInto(Location loc, OpBuilder &moduleBuilder,
OpBuilder &initializerBuilder, Value constExprValue,
NamedAttrList dialectAttrs, HoistedValueMap &hoistedMap,
SymbolTable &moduleSymbols,
const ConstExprAnalysis &constExprs,
const llvm::DenseMap<Operation *, unsigned> &opOrder) {
// Do a depth first traversal of the producers, emitting them in a valid
// def-use order.
Operation *rootOp = constExprValue.getDefiningOp();
Expand All @@ -281,7 +311,7 @@ class HoistIntoGlobalsPass
// Clone the whole tree as needed.
IRMapping cloneMapping;
cloneProducerTreeInto(initializerBuilder, rootInfo, hoistedMap,
cloneMapping, constExprs);
cloneMapping, constExprs, opOrder);

// And for each result, create a global and store into it.
for (Value origResult : rootOp->getResults()) {
Expand Down
Loading