Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2352,7 +2352,7 @@ struct OperationConverter {
LogicalResult legalizeUnresolvedMaterializations(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping);
DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Legalize an operation result that was marked as "erased".
LogicalResult
Expand Down Expand Up @@ -2454,10 +2454,12 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {

LogicalResult
OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
std::optional<DenseMap<Value, SmallVector<Value>>> inverseMapping;
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)) ||
failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl)))
return failure();
DenseMap<Value, SmallVector<Value>> inverseMapping =
rewriterImpl.mapping.getInverse();
if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
inverseMapping)))
return failure();

Expand All @@ -2483,15 +2485,11 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
if (result.getType() == newValue.getType())
continue;

// Compute the inverse mapping only if it is really needed.
if (!inverseMapping)
inverseMapping = rewriterImpl.mapping.getInverse();

// Legalize this result.
rewriter.setInsertionPoint(op);
if (failed(legalizeChangedResultType(
op, result, newValue, opReplacement->getConverter(), rewriter,
rewriterImpl, *inverseMapping)))
rewriterImpl, inverseMapping)))
return failure();
}
}
Expand All @@ -2503,6 +2501,8 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
ConversionPatternRewriterImpl &rewriterImpl) {
// Functor used to check if all users of a value will be dead after
// conversion.
// TODO: This should probably query the inverse mapping, same as in
Copy link
Member Author

@matthias-springer matthias-springer Aug 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is an unrelated bug here, will investigate separately, but putting a TODO.

// `legalizeChangedResultType`.
auto findLiveUser = [&](Value val) {
auto liveUserIt = llvm::find_if_not(val.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
Expand Down Expand Up @@ -2796,20 +2796,18 @@ static LogicalResult legalizeUnresolvedMaterialization(
LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
ConversionPatternRewriter &rewriter,
ConversionPatternRewriterImpl &rewriterImpl,
std::optional<DenseMap<Value, SmallVector<Value>>> &inverseMapping) {
inverseMapping = rewriterImpl.mapping.getInverse();

DenseMap<Value, SmallVector<Value>> &inverseMapping) {
// As an initial step, compute all of the inserted materializations that we
// expect to persist beyond the conversion process.
DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl,
*inverseMapping, necessaryMaterializations);
inverseMapping, necessaryMaterializations);

// Once computed, legalize any necessary materializations.
for (auto *mat : necessaryMaterializations) {
if (failed(legalizeUnresolvedMaterialization(
*mat, materializationOps, rewriter, rewriterImpl, *inverseMapping)))
*mat, materializationOps, rewriter, rewriterImpl, inverseMapping)))
return failure();
}
return success();
Expand Down